Skip to content

Commit f6bb816

Browse files
committed
Remove broken quantization_config logic (#4654)
1 parent 32fe2f7 commit f6bb816

File tree

2 files changed

+7
-121
lines changed

2 files changed

+7
-121
lines changed

paddlex/inference/models/common/vlm/transformers/configuration_utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -823,11 +823,6 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
823823
)
824824
to_remove = []
825825
for key, value in kwargs.items():
826-
if key == "quantization_config" and isinstance(value, Dict):
827-
for q_key in value:
828-
setattr(config.quantization_config, q_key, value[q_key])
829-
to_remove.append(key)
830-
continue
831826
if hasattr(config, key):
832827
setattr(config, key, value)
833828
if key != "dtype":
@@ -889,11 +884,6 @@ def to_diff_dict(self, saving_file=False) -> Dict[str, Any]:
889884

890885
# only serialize values that differ from the default config
891886
for key, value in config_dict.items():
892-
if key == "quantization_config":
893-
quantization_diff_dict = self.quantization_config.to_diff_dict()
894-
if len(quantization_diff_dict) > 0:
895-
serializable_config_dict[key] = quantization_diff_dict
896-
continue
897887
if (
898888
key not in default_config_dict
899889
or key == "paddlenlp_version"
@@ -942,16 +932,6 @@ def to_dict(self, saving_file=False) -> Dict[str, Any]:
942932
if key in self._unsavable_keys:
943933
output.pop(key)
944934

945-
if hasattr(self, "quantization_config"):
946-
output["quantization_config"] = (
947-
self.quantization_config.to_dict()
948-
if not isinstance(self.quantization_config, dict)
949-
else self.quantization_config
950-
)
951-
952-
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
953-
_ = output.pop("_pre_quantization_dtype", None)
954-
955935
return output
956936

957937
def update(self, config_dict: Dict[str, Any]):

paddlex/inference/models/common/vlm/transformers/model_utils.py

Lines changed: 7 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,6 @@ def load_state_dict(
258258
tensor_parallel_split_mapping,
259259
fliter_dict_keys,
260260
"expected",
261-
quantization_linear_list=None,
262-
quantization_config=None,
263261
dtype=None,
264262
return_numpy=False,
265263
convert_from_hf=convert_from_hf,
@@ -631,34 +629,6 @@ def set_inference_config(cls, config, predictor_args, **kwargs):
631629
config.weightonly_group_size = predictor_args.weightonly_group_size
632630
config.weight_block_size = predictor_args.weight_block_size
633631
config.moe_quant_type = predictor_args.moe_quant_type
634-
if config.quantization_config.quant_method is not None:
635-
predictor_args.weight_block_size = (
636-
config.quantization_config.weight_block_size
637-
)
638-
config.weight_block_size = predictor_args.weight_block_size
639-
640-
if config.quantization_config.quant_type is not None:
641-
if predictor_args.mode == "dynamic":
642-
predictor_args.quant_type = config.quantization_config.quant_type
643-
config.quant_type = config.quantization_config.quant_type
644-
if "c8" in config.quant_type:
645-
predictor_args.cachekv_int8_type = "static"
646-
if predictor_args.mode == "dynamic":
647-
config.cachekv_int8_type = "static"
648-
649-
if predictor_args.mode == "dynamic":
650-
ptq_multicards_num = 0
651-
if os.path.exists(config.model_name_or_path):
652-
prefix = "act_scales_"
653-
for filename in os.listdir(config.model_name_or_path):
654-
if filename.startswith(prefix):
655-
ptq_multicards_num += 1
656-
657-
logging.info(
658-
f"PTQ from {ptq_multicards_num} cards, so we will not split"
659-
)
660-
if ptq_multicards_num > 1:
661-
config.single_card_ptq = False
662632

663633
if predictor_args.block_attn:
664634
config.block_size = predictor_args.block_size
@@ -1323,45 +1293,6 @@ def _load_pretrained_model(
13231293
".".join([prefix, s]) for s in quantization_linear_list
13241294
]
13251295

1326-
# Weight quantization if not yet quantized & update loaded_keys
1327-
if (
1328-
hasattr(config, "quantization_config")
1329-
and config.quantization_config.is_weight_quantize()
1330-
):
1331-
try:
1332-
from ..quantization.quantization_utils import (
1333-
convert_to_quantize_state_dict,
1334-
update_loaded_state_dict_keys,
1335-
)
1336-
except ImportError:
1337-
raise ImportError(
1338-
"Quantization features require `paddlepaddle >= 2.5.2`"
1339-
)
1340-
if state_dict is not None:
1341-
state_dict = convert_to_quantize_state_dict(
1342-
state_dict,
1343-
quantization_linear_list,
1344-
config.quantization_config,
1345-
dtype,
1346-
)
1347-
loaded_keys = [k for k in state_dict.keys()]
1348-
else:
1349-
loaded_keys = update_loaded_state_dict_keys(
1350-
loaded_keys, quantization_linear_list, config.quantization_config
1351-
)
1352-
if keep_in_fp32_modules is None:
1353-
keep_in_fp32_modules = (
1354-
["quant_scale"]
1355-
if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
1356-
else None
1357-
)
1358-
else:
1359-
keep_in_fp32_modules = (
1360-
keep_in_fp32_modules + ["quant_scale"]
1361-
if config.quantization_config.weight_quantize_algo in ["nf4", "fp4"]
1362-
else keep_in_fp32_modules
1363-
)
1364-
13651296
missing_keys = list(set(expected_keys) - set(loaded_keys))
13661297
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
13671298

@@ -1525,27 +1456,12 @@ def _fuse_or_split_keys(
15251456
ignore_mismatched_sizes,
15261457
)
15271458

1528-
if (
1529-
hasattr(config, "quantization_config")
1530-
and config.quantization_config.is_weight_quantize()
1531-
):
1532-
error_msgs = _load_state_dict_into_meta_model(
1533-
model_to_load,
1534-
state_dict,
1535-
loaded_keys,
1536-
start_prefix,
1537-
expected_keys,
1538-
dtype=dtype,
1539-
is_safetensors=is_safetensors,
1540-
keep_in_fp32_modules=keep_in_fp32_modules,
1541-
)
1542-
else:
1543-
error_msgs = _load_state_dict_into_model(
1544-
model_to_load,
1545-
state_dict,
1546-
start_prefix,
1547-
convert_from_hf=convert_from_hf,
1548-
)
1459+
error_msgs = _load_state_dict_into_model(
1460+
model_to_load,
1461+
state_dict,
1462+
start_prefix,
1463+
convert_from_hf=convert_from_hf,
1464+
)
15491465
else:
15501466
# Sharded checkpoint or whole but low_cpu_mem_usage==True
15511467

@@ -1600,8 +1516,6 @@ def _fuse_or_split_keys(
16001516
if k[-1] in tp_actions:
16011517
fuse_actions.pop(k[-1], None)
16021518

1603-
if config.quantization_config.is_weight_quantize():
1604-
filter_dict_keys = None
16051519
try:
16061520
transpose_weight_keys = model.get_transpose_weight_keys()
16071521
except NotImplementedError:
@@ -1630,14 +1544,6 @@ def _fuse_or_split_keys(
16301544
missing_keys = list(set(missing_keys) - set(new_keys))
16311545
unexpected_keys = list(set(unexpected_keys) - set(fused_keys))
16321546

1633-
if config.quantization_config.is_weight_quantize():
1634-
state_dict = convert_to_quantize_state_dict(
1635-
state_dict,
1636-
quantization_linear_list,
1637-
config.quantization_config,
1638-
dtype,
1639-
)
1640-
16411547
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
16421548
# matching the weights in the model.
16431549
mismatched_keys += _find_mismatched_keys(
@@ -1664,7 +1570,7 @@ def _fuse_or_split_keys(
16641570
)
16651571
logging.info("Converted state_dict to Tensor Parallel Format")
16661572

1667-
if low_cpu_mem_usage or config.quantization_config.is_weight_quantize():
1573+
if low_cpu_mem_usage:
16681574
new_error_msgs = _load_state_dict_into_meta_model(
16691575
model_to_load,
16701576
state_dict,

0 commit comments

Comments (0)