@@ -1,4 +1,5 @@
 import argparse
+import gc
 import re
 from typing import Tuple

@@ -10,6 +11,7 @@


 # HACK: override the dtype_byte_size function in transformers to support float8 types
+# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
 def new_dtype_byte_size(dtype):
     if dtype == torch.bool:
         return 1 / 8
@@ -23,6 +25,11 @@ def new_dtype_byte_size(dtype):
 transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size


+def cleanup_memory():
+    gc.collect()
+    torch.cuda.empty_cache()
+
+
 def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
     """Quantize a tensor using per-tensor static scaling factor.

@@ -33,7 +40,14 @@ def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]:
     # Calculate the scale as dtype max divided by absmax.
     # Since .abs() creates a new tensor, we use aminmax to get
     # the min and max first and then calculate the absmax.
-    min_val, max_val = tensor.aminmax()
+    if tensor.numel() == 0:
+        # Deal with empty tensors (triggered by empty MoE experts)
+        min_val, max_val = (
+            torch.tensor(0.0, dtype=tensor.dtype),
+            torch.tensor(1.0, dtype=tensor.dtype),
+        )
+    else:
+        min_val, max_val = tensor.aminmax()
     amax = min_val.abs().max(max_val.abs())
     scale = finfo.max / amax.clamp(min=1e-12)
     # scale and clamp the tensor to bring it to
@@ -145,68 +159,80 @@ def forward(self, x):
         return output


+def replace_module(model, name, new_module):
+    if "." in name:
+        parent_name = name.rsplit(".", 1)[0]
+        child_name = name[len(parent_name) + 1 :]
+        parent = model.model.get_submodule(parent_name)
+    else:
+        parent_name = ""
+        parent = model.model
+        child_name = name
+    setattr(parent, child_name, new_module)
+
+
 def quantize_weights(model):
     for name, linear in model.model.named_modules():
+        # if "gate" in name or not isinstance(linear, torch.nn.Linear):
         if not isinstance(linear, torch.nn.Linear):
             continue
         quant_weight, quant_scale = per_tensor_quantize(linear.weight)
         quant_linear = FP8DynamicLinear(quant_weight, quant_scale)
-        if "." in name:
-            parent_name = name.rsplit(".", 1)[0]
-            child_name = name[len(parent_name) + 1 :]
-            parent = model.model.get_submodule(parent_name)
-        else:
-            parent_name = ""
-            parent = model.model
-            child_name = name
-        setattr(parent, child_name, quant_linear)
+        replace_module(model, name, quant_linear)
+        del linear
+    cleanup_memory()


 def quantize_activations(model, calibration_tokens):
     # Replace layers with quantizer.
     for name, dynamic_quant_linear in model.model.named_modules():
+        # if "gate" in name or not isinstance(dynamic_quant_linear, FP8DynamicLinear):
         if not isinstance(dynamic_quant_linear, FP8DynamicLinear):
             continue
         quantizer = FP8StaticLinearQuantizer(
             dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale
         )
-        if "." in name:
-            parent_name = name.rsplit(".", 1)[0]
-            child_name = name[len(parent_name) + 1 :]
-            parent = model.model.get_submodule(parent_name)
-        else:
-            parent_name = ""
-            parent = model.model
-            child_name = name
-        setattr(parent, child_name, quantizer)
+        replace_module(model, name, quantizer)
+        del dynamic_quant_linear
+    cleanup_memory()

     # Calibration.
     for row_idx in range(calibration_tokens.shape[0]):
         _ = model(calibration_tokens[row_idx].reshape(1, -1))

     # Replace quantizer with StaticLayer.
     for name, quantizer in model.model.named_modules():
+        # if "gate" in name or not isinstance(quantizer, FP8StaticLinearQuantizer):
         if not isinstance(quantizer, FP8StaticLinearQuantizer):
             continue
         static_proj = FP8StaticLinear(
             quantizer.weight, quantizer.weight_scale, quantizer.act_scale
         )
-        if "." in name:
-            parent_name = name.rsplit(".", 1)[0]
-            child_name = name[len(parent_name) + 1 :]
-            parent = model.model.get_submodule(parent_name)
-        else:
-            parent_name = ""
-            parent = model.model
-            child_name = name
-        setattr(parent, child_name, static_proj)
+        replace_module(model, name, static_proj)
+        del quantizer
+    cleanup_memory()
+
+
+def save_quantized_model(model, activation_scheme, save_dir):
+    print(f"Saving the model to {save_dir}")
+    static_q_dict = {
+        "quantization_config": {
+            "quant_method": "fp8",
+            "activation_scheme": activation_scheme,
+        }
+    }
+    model.config.update(static_q_dict)
+    model.save_pretrained(save_dir)
+    tokenizer.save_pretrained(save_dir)


 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--model-id", type=str)
     parser.add_argument("--save-dir", type=str)
-    # parser.add_argument("--static-act", action="store_true")
+    parser.add_argument(
+        "--activation-scheme", type=str, default="static", choices=["static", "dynamic"]
+    )
     parser.add_argument("--num-samples", type=int, default=512)
     parser.add_argument("--max-seq-len", type=int, default=512)
     args = parser.parse_args()
@@ -240,22 +266,26 @@ def quantize_activations(model, calibration_tokens):
     model = AutoModelForCausalLM.from_pretrained(
         args.model_id, torch_dtype="auto", device_map="auto"
     )
+    print("Original model graph:\n", model)
     output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
-    print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n")
+    print("ORIGINAL OUTPUT:\n", tokenizer.decode(output[0]), "\n\n")

     # Quantize weights.
     quantize_weights(model)
+    print("Weight-quantized model graph:\n", model)
     output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
-    print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
+    print("WEIGHT QUANT OUTPUT:\n", tokenizer.decode(output[0]), "\n\n")

-    # Quantize activations.
-    quantize_activations(model, calibration_tokens=calibration_tokens)
-    output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
-    print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n")
-
-    # Save the model fully quantized
-    print(f"Saving the model to {args.save_dir}")
-    static_q_dict = {"quantization_config": {"quant_method": "fp8", "scheme": "static"}}
-    model.config.update(static_q_dict)
-    model.save_pretrained(args.save_dir)
-    tokenizer.save_pretrained(args.save_dir)
+    if args.activation_scheme == "dynamic":
+        print("Exporting model with static weights and dynamic activations")
+        save_quantized_model(model, args.activation_scheme, args.save_dir)
+    else:
+        assert args.activation_scheme == "static"
+        # Quantize activations.
+        quantize_activations(model, calibration_tokens=calibration_tokens)
+        print("Weight and activation quantized model graph:\n", model)
+        output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20)
+        print("ACT QUANT OUTPUT:\n", tokenizer.decode(output[0]), "\n\n")
+
+        print("Exporting model with static weights and static activations")
+        save_quantized_model(model, args.activation_scheme, args.save_dir)
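
For context, the FP8 scheme the script applies is plain per-tensor symmetric quantization: the scale maps the tensor's absolute maximum onto the float8 e4m3 maximum, the scaled values are clamped into the FP8 range, and the inverse scale is kept for dequantization. A minimal standalone sketch of that idea follows; the function name is illustrative and not part of the patch, and it assumes a PyTorch build with float8_e4m3fn support.

    import torch

    def fp8_per_tensor_sketch(tensor: torch.Tensor):
        # Per-tensor symmetric FP8 (e4m3) quantization, mirroring the
        # per_tensor_quantize logic touched by this diff.
        finfo = torch.finfo(torch.float8_e4m3fn)
        min_val, max_val = tensor.aminmax()
        amax = min_val.abs().max(max_val.abs())
        scale = finfo.max / amax.clamp(min=1e-12)  # map absmax onto the FP8 max
        qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max)
        qweight = qweight.to(torch.float8_e4m3fn)
        return qweight, scale.float().reciprocal()  # inverse scale used at dequant time

With the new flag, the script can be run with --activation-scheme dynamic to export right after weight quantization (no calibration pass), or with --activation-scheme static (the default) to also calibrate on the sampled tokens and bake per-tensor activation scales into the exported checkpoint.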