2626
2727from angelslim .compressor .quant .core .quant_func import weight_dequant
2828
# Weight-name suffixes whose tensors get INT8-quantized: MLP projections,
# attention/MLA projections, and the indexer projection matrices.
SUFFIX_TO_QUANT = [
    ".gate_and_up_proj.weight",
    ".gate_proj.weight",
    ".up_proj.weight",
    ".down_proj.weight",
    ".q_a_proj.weight",
    ".q_b_proj.weight",
    ".kv_a_proj_with_mqa.weight",
    ".kv_b_proj.weight",
    ".qkv_proj.weight",
    ".q_proj.weight",
    ".k_proj.weight",
    ".v_proj.weight",
    ".o_proj.weight",
    ".indexer.wq_b.weight",
    ".indexer.wk.weight",
]
46+
2947
3048def process_worker (
31- worker_id , safetensor_files , fp8_path , int8_path , weight_map , return_dict
49+ worker_id ,
50+ safetensor_files ,
51+ input_path ,
52+ int8_path ,
53+ weight_map ,
54+ return_dict ,
55+ input_type = "bf16" ,
3256):
3357 """
3458 Process worker.
@@ -51,18 +75,19 @@ def process_worker(
5175 keys = set (f .keys ())
5276 for weight_name in keys :
5377 weight = f .get_tensor (weight_name )
54- scale_inv_name = f"{ weight_name } _scale_inv"
55- if scale_inv_name in weight_map :
78+ if any (weight_name .endswith (suffix ) for suffix in SUFFIX_TO_QUANT ):
5679 quant_count += 1
57- # 1. fp8 dequant to bf16
58- scale_inv = get_tensor_from_file (
59- rank , scale_inv_name , weight_map , fp8_path
60- )
61- weight_bf16 = weight_dequant (weight , scale_inv )
62- # 2. bf16 quant to int8
80+ if input_type == "fp8" :
81+ scale_inv_name = f"{ weight_name } _scale_inv"
82+ scale_inv = get_tensor_from_file (
83+ rank , scale_inv_name , weight_map , input_path
84+ )
85+ weight_bf16 = weight_dequant (weight , scale_inv )
86+ else :
87+ weight_bf16 = weight
6388 int8_weight , scale_inv = weight_quant (weight_bf16 )
6489 new_state_dict [weight_name ] = int8_weight
65- new_scale_name = scale_inv_name . replace ( "_scale_inv" , " _scale")
90+ new_scale_name = f" { weight_name } _scale"
6691 new_state_dict [new_scale_name ] = scale_inv
6792 new_weight_map [weight_name ] = file_name
6893 new_weight_map [new_scale_name ] = file_name
@@ -78,7 +103,7 @@ def process_worker(
78103
79104
# Helper function to get tensor from the correct file
def get_tensor_from_file(rank, tensor_name, weight_map, input_path):
    """Retrieve a single tensor from the mmap'd safetensors shard holding it.

    Args:
        rank (int): CUDA device index the tensor is loaded onto.
        tensor_name (str): Name of the tensor to fetch.
        weight_map (dict): Maps tensor names to their safetensors file names.
        input_path (str): Directory containing the safetensors shards.

    Returns:
        The requested tensor, materialized on ``cuda:{rank}``.
    """
    torch.cuda.set_device(rank)
    shard_name = weight_map[tensor_name]
    shard_path = os.path.join(input_path, shard_name)

    with safe_open(shard_path, framework="pt", device=f"cuda:{rank}") as f:
        return f.get_tensor(tensor_name)
@@ -119,7 +144,7 @@ def weight_quant(tensor: torch.Tensor):
119144 return quantized .to (torch .int8 ), scale .to (torch .float32 )
120145
121146
122- def main (fp8_path , int8_path , num_workers ):
147+ def main (input_path , int8_path , num_workers ):
123148 """
124149 Run the FP8-to-INT8 per-channel quantization pipeline.
125150
@@ -130,7 +155,7 @@ def main(fp8_path, int8_path, num_workers):
130155 4. Saves quantized safetensors and updates model index.
131156
132157 Args:
133- fp8_path (str): Path to directory containing FP8 safetensors.
158+ input_path (str): Path to directory containing FP8 safetensors.
134159 int8_path (str): Output directory to save INT8 safetensors.
135160 num_workers (int): Number of processing workers
136161 """
@@ -139,10 +164,10 @@ def main(fp8_path, int8_path, num_workers):
139164 model_index_file = os .path .join (int8_path , "model.safetensors.index.json" )
140165 config_file = os .path .join (int8_path , "config.json" )
141166
142- for fname in os .listdir (fp8_path ):
167+ for fname in os .listdir (input_path ):
143168 if fname .endswith (".safetensors" ):
144169 continue
145- src = os .path .join (fp8_path , fname )
170+ src = os .path .join (input_path , fname )
146171 dst = os .path .join (int8_path , fname )
147172 if os .path .isdir (src ):
148173 print (f"cp -r { src } { dst } " )
@@ -154,7 +179,11 @@ def main(fp8_path, int8_path, num_workers):
154179 # modify config.json and save it
155180 config = json .load (open (config_file ))
156181 # delete quantization_config
157- config .pop ("quantization_config" , None )
182+ quant_config = config .pop ("quantization_config" , None )
183+ input_type = "bf16"
184+ if quant_config is not None :
185+ input_type = quant_config .get ("quant_method" , input_type )
186+ print ("input_type" , input_type )
158187 config ["quantization_config" ] = {
159188 "config_groups" : {
160189 "group_0" : {
@@ -200,9 +229,8 @@ def main(fp8_path, int8_path, num_workers):
200229 with open (model_index_file , "r" ) as f :
201230 model_index = json .load (f )
202231 weight_map = model_index ["weight_map" ]
203- scale_count = len ([key for key in weight_map .keys () if key .endswith ("_scale_inv" )])
204232
205- safetensor_files = list (glob (os .path .join (fp8_path , "*.safetensors" )))
233+ safetensor_files = list (glob (os .path .join (input_path , "*.safetensors" )))
206234 safetensor_files .sort ()
207235 quant_count = 0
208236 new_weight_map = {}
@@ -216,7 +244,15 @@ def main(fp8_path, int8_path, num_workers):
216244 for i in range (num_workers ):
217245 p = mp .Process (
218246 target = process_worker ,
219- args = (i , file_subsets [i ], fp8_path , int8_path , weight_map , return_dict ),
247+ args = (
248+ i ,
249+ file_subsets [i ],
250+ input_path ,
251+ int8_path ,
252+ weight_map ,
253+ return_dict ,
254+ input_type ,
255+ ),
220256 )
221257 p .start ()
222258 processes .append (p )
@@ -227,7 +263,6 @@ def main(fp8_path, int8_path, num_workers):
227263 qc , wm = return_dict [i ]
228264 quant_count += qc
229265 new_weight_map .update (wm )
230- assert quant_count == scale_count
231266 print (f"{ quant_count } weights are quantized." )
232267
233268 # modify model.safetensors.index.json
@@ -241,10 +276,10 @@ def main(fp8_path, int8_path, num_workers):
241276
if __name__ == "__main__":
    # CLI entry point: quantize safetensors under --input-path to INT8
    # and write the converted checkpoint to --output-int8-path.
    parser = ArgumentParser()
    parser.add_argument("--input-path", type=str, required=True)
    parser.add_argument("--output-int8-path", type=str, required=True)
    parser.add_argument("--num-workers", type=int, default=32)

    cli_args = parser.parse_args()
    main(cli_args.input_path, cli_args.output_int8_path, cli_args.num_workers)
    print("done")
0 commit comments