     GenerationBlockInferenceModel,
     GenerationInferenceModel,
 )
-from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
+from paddlenlp.experimental.transformers.utils import (
+    EmptyActScale,
+    EmptyCacheScale,
+    EmptyWeightScale,
+    infererence_model_from_pretrained,
+)
 from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
 from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.llama.modeling import LlamaLMHead
@@ -346,7 +351,7 @@ def __init__(self, config: LlamaConfig):
         self.num_layers = config.num_hidden_layers
         self.epsilon = config.rms_norm_eps
         self.max_position_embeddings = config.max_position_embeddings
-        self.quant_type = config.quant_type
+        self.quant_type = config.get("quant_type", "")
 
         self.rope_theta = config.rope_theta
         self.use_neox = True
@@ -364,6 +369,8 @@ def __init__(self, config: LlamaConfig):
         self.smooth = config.quantization_config.smooth
         self.shift_smooth_all_linears = config.quantization_config.shift_smooth_all_linears
 
+        self.use_fake_parameter = config.get("use_fake_parameter", False)
+
         if self.use_weight_only:
             assert (
                 self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4"
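The use_fake_parameter flag added above lets the quantized inference model be built even when no PTQ calibration artifacts are available: later in set_state_dict, missing shift/smooth tensors and scale files are replaced with neutral placeholders. A minimal sketch of how the flag might be enabled through the config; the checkpoint name and the attribute-style assignment are illustrative, not taken from this diff:

from paddlenlp.transformers import LlamaConfig

config = LlamaConfig.from_pretrained("facebook/llama-7b")  # illustrative checkpoint name
config.quant_type = "a8w8"          # the quantized path this change targets
config.use_fake_parameter = True    # fall back to placeholder shifts, smooths and scales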
@@ -894,6 +901,30 @@ def set_state_dict(self, state_dict):
 
             if "a8w8" in self.quant_type:
                 if self.shift_smooth_all_linears:
+                    if self.use_fake_parameter:
+                        if "llama.layers.{}.self_attn.o_proj.shift_bias".format(idx) not in state_dict:
+                            state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)] = paddle.zeros(
+                                shape=[
+                                    (self.num_attention_heads // self.config.tensor_parallel_degree)
+                                    * (self.hidden_size // self.num_attention_heads)
+                                ],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)] = paddle.ones(
+                                shape=[
+                                    (self.num_attention_heads // self.config.tensor_parallel_degree)
+                                    * (self.hidden_size // self.num_attention_heads)
+                                ],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)] = paddle.zeros(
+                                shape=[self.intermediate_size // self.config.tensor_parallel_degree],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)] = paddle.ones(
+                                shape=[self.intermediate_size // self.config.tensor_parallel_degree],
+                                dtype=paddle.get_default_dtype(),
+                            )
                     self.transformer_block.linear_shifts[idx].set_value(
                         paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
                     )
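The placeholder shapes above follow the tensor-parallel split: each rank keeps num_attention_heads // tensor_parallel_degree heads of width hidden_size // num_attention_heads for o_proj, and intermediate_size // tensor_parallel_degree columns for down_proj. A quick sanity check with assumed LLaMA-7B-like sizes (the concrete numbers are illustrative only):

num_attention_heads, hidden_size, intermediate_size, mp = 32, 4096, 11008, 2
o_proj_len = (num_attention_heads // mp) * (hidden_size // num_attention_heads)
down_proj_len = intermediate_size // mp
print(o_proj_len, down_proj_len)  # 2048 5504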
@@ -908,6 +939,33 @@ def set_state_dict(self, state_dict):
                     )
 
                 if self.shift:
+                    if self.use_fake_parameter:
+                        if "llama.layers.{}.input_layernorm.bias".format(idx) not in state_dict:
+                            state_dict["llama.layers.{}.input_layernorm.bias".format(idx)] = paddle.zeros(
+                                shape=[self.hidden_size], dtype=paddle.get_default_dtype()
+                            )
+                            state_dict["llama.layers.{}.post_attention_layernorm.bias".format(idx)] = paddle.zeros(
+                                [self.hidden_size], dtype=paddle.get_default_dtype()
+                            )
+                            unfused_state_dict["self_attn.q_proj.bias"] = paddle.zeros(
+                                shape=[self.num_attention_heads * (self.hidden_size // self.num_attention_heads)],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            unfused_state_dict["self_attn.k_proj.bias"] = paddle.zeros(
+                                shape=[self.num_key_value_heads * (self.hidden_size // self.num_attention_heads)],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            unfused_state_dict["self_attn.v_proj.bias"] = paddle.zeros(
+                                shape=[self.num_key_value_heads * (self.hidden_size // self.num_attention_heads)],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            unfused_state_dict["mlp.gate_proj.bias"] = paddle.zeros(
+                                shape=[self.intermediate_size], dtype=paddle.get_default_dtype()
+                            )
+                            unfused_state_dict["mlp.up_proj.bias"] = paddle.zeros(
+                                shape=[self.intermediate_size], dtype=paddle.get_default_dtype()
+                            )
+
                     self.transformer_block.ln_biases[idx].set_value(
                         paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.bias".format(idx)])
                     )
@@ -948,6 +1006,14 @@ def set_state_dict(self, state_dict):
                     self.transformer_block.ffn1_biases[idx].set_value(paddle.to_tensor(concated_ffn1_bias))
 
                 if self.shift_smooth_all_linears:
+                    if self.use_fake_parameter:
+                        if "llama.layers.{}.self_attn.o_proj.bias".format(idx) not in state_dict:
+                            state_dict["llama.layers.{}.self_attn.o_proj.bias".format(idx)] = paddle.zeros(
+                                [self.hidden_size], dtype=paddle.get_default_dtype()
+                            )
+                            state_dict["llama.layers.{}.mlp.down_proj.layer.bias".format(idx)] = paddle.zeros(
+                                [self.hidden_size], dtype=paddle.get_default_dtype()
+                            )
                     self.transformer_block.linear_biases[idx].set_value(
                         paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.bias".format(idx)])
                     )
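As in the earlier blocks, the fabricated tensors are neutral values: zero biases and unit smooth weights act as identity elements for the additive shift and multiplicative smooth, so the quantized kernels still run end to end without any calibration data. A tiny illustration of that intuition (my own sketch, not code from this PR):

import paddle

x = paddle.randn([4, 8])
shift = paddle.zeros([8])   # fake shift_bias
smooth = paddle.ones([8])   # fake smooth_weight
assert bool(paddle.allclose((x + shift) * smooth, x))  # the fake parameters are a no-op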
@@ -981,41 +1047,64 @@ def set_state_dict(self, state_dict):
             weight_scale_map_dict = scale_map_dict["weight_scale"]
             cache_scale_map_dict = scale_map_dict["cachekv_scale"]
 
-            act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
-            weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
-            if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
-                act_scale_json_path = os.path.join(
-                    self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json"
+            if not self.use_fake_parameter:
+                act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
+                weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
+                if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
+                    act_scale_json_path = os.path.join(
+                        self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json"
+                    )
+                    weight_scale_json_path = os.path.join(
+                        self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json"
+                    )
+                act_scale_loader = ActScalesLoader(
+                    act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers
                 )
-                weight_scale_json_path = os.path.join(
-                    self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json"
+                weight_scales_loader = WeightScalesLoader(
+                    weight_scale_json_path,
+                    weight_scale_map_dict,
+                    num_of_layers=self.config.num_hidden_layers,
+                    concat_qkv=True,
+                    concat_ffn1=True,
+                )
+            else:
+                act_scale_loader = EmptyActScale(act_scale_map_dict, num_of_layers=self.config.num_hidden_layers)
+                weight_scales_loader = EmptyWeightScale(
+                    weight_scale_map_dict,
+                    num_of_layers=self.config.num_hidden_layers,
+                    num_head=self.num_attention_heads,
+                    dim_head=self.hidden_size // self.num_attention_heads,
+                    ffn_hidden_size=self.intermediate_size,
+                    num_key_value_heads=self.num_key_value_heads,
+                    mp_size=self.config.tensor_parallel_degree,
                 )
-            act_scale_loader = ActScalesLoader(
-                act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers
-            )
             self.transformer_block.act_scales = act_scale_loader.scale
 
-            weight_scales_loader = WeightScalesLoader(
-                weight_scale_json_path,
-                weight_scale_map_dict,
-                num_of_layers=self.config.num_hidden_layers,
-                concat_qkv=True,
-                concat_ffn1=True,
-            )
-
             if self.config.cachekv_int8_type == "static":
-                cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
-                if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
-                    cache_scale_json_path = os.path.join(
-                        self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json"
+                if not self.use_fake_parameter:
+                    cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
+                    if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
+                        cache_scale_json_path = os.path.join(
+                            self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json"
+                        )
+                    cache_scales_loader = CacheScaleLoader(
+                        cache_scale_json_path,
+                        cache_scale_map_dict,
+                        num_of_layers=self.config.num_hidden_layers,
+                        num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                        num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
                     )
-                cache_scales_loader = CacheScaleLoader(
-                    cache_scale_json_path,
-                    cache_scale_map_dict,
-                    num_of_layers=self.config.num_hidden_layers,
-                    num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
-                    num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
-                )
+                else:
+                    cache_scales_loader = EmptyCacheScale(
+                        cache_scale_map_dict,
+                        num_of_layers=self.config.num_hidden_layers,
+                        num_heads=self.num_attention_heads,
+                        dim_heads=self.hidden_size // self.num_attention_heads,
+                        is_channel_wise=False,
+                        num_key_value_heads=self.num_key_value_heads,
+                        mp_size=self.config.tensor_parallel_degree,
+                    )
+
                 for k, v in cache_scales_loader.scale.items():
                     for i_layer, weight_scale in enumerate(v):
                         weight_scale = weight_scale.astype("float32")
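The EmptyActScale, EmptyWeightScale, and EmptyCacheScale helpers imported at the top of this file are not part of the diff shown here; judging only from the call sites, they mirror the JSON-backed loaders but synthesize neutral scales instead of reading scale files, and expose the same .scale mapping consumed above. A rough, hypothetical sketch of the activation-scale variant (names and values are guesses, not the actual utils.py implementation):

import numpy as np

class EmptyActScaleSketch:
    """Stand-in illustration: one unit activation scale per mapped key and per layer."""

    def __init__(self, key_map_dict, num_of_layers=None):
        self.key_map = key_map_dict
        self.scale = {}
        for scale_type in self.key_map:
            # one scalar per transformer layer, all set to 1.0
            self.scale[scale_type] = np.ones([num_of_layers], dtype="float32")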