@@ -258,8 +258,6 @@ def load_state_dict(
258258 tensor_parallel_split_mapping ,
259259 fliter_dict_keys ,
260260 "expected" ,
261- quantization_linear_list = None ,
262- quantization_config = None ,
263261 dtype = None ,
264262 return_numpy = False ,
265263 convert_from_hf = convert_from_hf ,
@@ -631,34 +629,6 @@ def set_inference_config(cls, config, predictor_args, **kwargs):
631629 config .weightonly_group_size = predictor_args .weightonly_group_size
632630 config .weight_block_size = predictor_args .weight_block_size
633631 config .moe_quant_type = predictor_args .moe_quant_type
634- if config .quantization_config .quant_method is not None :
635- predictor_args .weight_block_size = (
636- config .quantization_config .weight_block_size
637- )
638- config .weight_block_size = predictor_args .weight_block_size
639-
640- if config .quantization_config .quant_type is not None :
641- if predictor_args .mode == "dynamic" :
642- predictor_args .quant_type = config .quantization_config .quant_type
643- config .quant_type = config .quantization_config .quant_type
644- if "c8" in config .quant_type :
645- predictor_args .cachekv_int8_type = "static"
646- if predictor_args .mode == "dynamic" :
647- config .cachekv_int8_type = "static"
648-
649- if predictor_args .mode == "dynamic" :
650- ptq_multicards_num = 0
651- if os .path .exists (config .model_name_or_path ):
652- prefix = "act_scales_"
653- for filename in os .listdir (config .model_name_or_path ):
654- if filename .startswith (prefix ):
655- ptq_multicards_num += 1
656-
657- logging .info (
658- f"PTQ from { ptq_multicards_num } cards, so we will not split"
659- )
660- if ptq_multicards_num > 1 :
661- config .single_card_ptq = False
662632
663633 if predictor_args .block_attn :
664634 config .block_size = predictor_args .block_size
@@ -1323,45 +1293,6 @@ def _load_pretrained_model(
13231293 "." .join ([prefix , s ]) for s in quantization_linear_list
13241294 ]
13251295
1326- # Weight quantization if not yet quantized & update loaded_keys
1327- if (
1328- hasattr (config , "quantization_config" )
1329- and config .quantization_config .is_weight_quantize ()
1330- ):
1331- try :
1332- from ..quantization .quantization_utils import (
1333- convert_to_quantize_state_dict ,
1334- update_loaded_state_dict_keys ,
1335- )
1336- except ImportError :
1337- raise ImportError (
1338- "Quantization features require `paddlepaddle >= 2.5.2`"
1339- )
1340- if state_dict is not None :
1341- state_dict = convert_to_quantize_state_dict (
1342- state_dict ,
1343- quantization_linear_list ,
1344- config .quantization_config ,
1345- dtype ,
1346- )
1347- loaded_keys = [k for k in state_dict .keys ()]
1348- else :
1349- loaded_keys = update_loaded_state_dict_keys (
1350- loaded_keys , quantization_linear_list , config .quantization_config
1351- )
1352- if keep_in_fp32_modules is None :
1353- keep_in_fp32_modules = (
1354- ["quant_scale" ]
1355- if config .quantization_config .weight_quantize_algo in ["nf4" , "fp4" ]
1356- else None
1357- )
1358- else :
1359- keep_in_fp32_modules = (
1360- keep_in_fp32_modules + ["quant_scale" ]
1361- if config .quantization_config .weight_quantize_algo in ["nf4" , "fp4" ]
1362- else keep_in_fp32_modules
1363- )
1364-
13651296 missing_keys = list (set (expected_keys ) - set (loaded_keys ))
13661297 unexpected_keys = list (set (loaded_keys ) - set (expected_keys ))
13671298
@@ -1525,27 +1456,12 @@ def _fuse_or_split_keys(
15251456 ignore_mismatched_sizes ,
15261457 )
15271458
1528- if (
1529- hasattr (config , "quantization_config" )
1530- and config .quantization_config .is_weight_quantize ()
1531- ):
1532- error_msgs = _load_state_dict_into_meta_model (
1533- model_to_load ,
1534- state_dict ,
1535- loaded_keys ,
1536- start_prefix ,
1537- expected_keys ,
1538- dtype = dtype ,
1539- is_safetensors = is_safetensors ,
1540- keep_in_fp32_modules = keep_in_fp32_modules ,
1541- )
1542- else :
1543- error_msgs = _load_state_dict_into_model (
1544- model_to_load ,
1545- state_dict ,
1546- start_prefix ,
1547- convert_from_hf = convert_from_hf ,
1548- )
1459+ error_msgs = _load_state_dict_into_model (
1460+ model_to_load ,
1461+ state_dict ,
1462+ start_prefix ,
1463+ convert_from_hf = convert_from_hf ,
1464+ )
15491465 else :
15501466 # Sharded checkpoint or whole but low_cpu_mem_usage==True
15511467
@@ -1600,8 +1516,6 @@ def _fuse_or_split_keys(
16001516 if k [- 1 ] in tp_actions :
16011517 fuse_actions .pop (k [- 1 ], None )
16021518
1603- if config .quantization_config .is_weight_quantize ():
1604- filter_dict_keys = None
16051519 try :
16061520 transpose_weight_keys = model .get_transpose_weight_keys ()
16071521 except NotImplementedError :
@@ -1630,14 +1544,6 @@ def _fuse_or_split_keys(
16301544 missing_keys = list (set (missing_keys ) - set (new_keys ))
16311545 unexpected_keys = list (set (unexpected_keys ) - set (fused_keys ))
16321546
1633- if config .quantization_config .is_weight_quantize ():
1634- state_dict = convert_to_quantize_state_dict (
1635- state_dict ,
1636- quantization_linear_list ,
1637- config .quantization_config ,
1638- dtype ,
1639- )
1640-
16411547 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
16421548 # matching the weights in the model.
16431549 mismatched_keys += _find_mismatched_keys (
@@ -1664,7 +1570,7 @@ def _fuse_or_split_keys(
16641570 )
16651571 logging .info ("Converted state_dict to Tensor Parallel Format" )
16661572
1667- if low_cpu_mem_usage or config . quantization_config . is_weight_quantize () :
1573+ if low_cpu_mem_usage :
16681574 new_error_msgs = _load_state_dict_into_meta_model (
16691575 model_to_load ,
16701576 state_dict ,
0 commit comments