@@ -195,7 +195,8 @@ def quantize_weights(
     quantize_config: BaseQuantizeConfig,
     ignored_layers: List[str] = [],
 ):
-    for name, linear in model.named_modules():
+    named_modules = list(model.named_modules())
+    for name, linear in tqdm.tqdm(named_modules, desc="Quantizing weights"):
         if (
             not isinstance(linear, torch.nn.Linear)
             or name in quantize_config.ignored_layers
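Note on the list() snapshot added above: `named_modules()` is a generator, so handing it straight to `tqdm` would give a bar with no total, and the later `replace_module` calls mutate the very tree the generator walks. A minimal runnable sketch of the pattern (the toy model is illustrative only):

import torch
import tqdm

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 8))

# Materializing into a list gives tqdm a len() for percentage display, and
# later module replacements can no longer invalidate the iteration order.
modules = list(model.named_modules())
for name, module in tqdm.tqdm(modules, desc="Quantizing weights"):
    if isinstance(module, torch.nn.Linear):
        pass  # build and swap in the quantized replacement here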
@@ -205,7 +206,7 @@ def quantize_weights(
         quant_linear = FP8DynamicLinear(quant_weight, quant_scale, linear.bias)
         replace_module(model, name, quant_linear)
         del linear
-        cleanup_memory()
+    cleanup_memory()


 def quantize_activations(
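For context, `quant_weight` and `quant_scale` come from the per-tensor FP8 cast just above this hunk, which the diff does not show. A hedged sketch of what that step typically looks like, assuming symmetric per-tensor scaling onto the E4M3 range; the helper name is hypothetical:

import torch

def per_tensor_fp8_quantize(weight: torch.Tensor):
    # Hypothetical sketch: scale maps the tensor's absolute max onto the
    # FP8 E4M3 representable range (roughly +-448).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return qweight, scale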
@@ -214,6 +215,7 @@ def quantize_activations(
     calibration_tokens,
     ignored_layers: List[str] = [],
 ):
+    # Replace the weight quantizer with a dynamic activation quantizer observer
     for name, dynamic_quant_linear in model.named_modules():
         if (
             not isinstance(dynamic_quant_linear, FP8DynamicLinear)
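The observer swapped in here (`FP8StaticLinearQuantizer`) is defined elsewhere in the file; as a rough mental model only, it behaves like the hypothetical module below, recording the largest per-tensor activation scale seen across calibration batches:

import torch

class ActivationScaleObserver(torch.nn.Module):
    """Hypothetical stand-in for FP8StaticLinearQuantizer (an assumption,
    not the file's actual class): runs the layer while tracking the max
    activation scale, which later becomes the static per-tensor scale."""

    def __init__(self, weight, weight_scale, bias):
        super().__init__()
        self.weight, self.weight_scale, self.bias = weight, weight_scale, bias
        self.act_scale = None

    def forward(self, x):
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = x.abs().max() / finfo.max
        # Running max over all calibration batches -> one static scale.
        self.act_scale = scale if self.act_scale is None else torch.maximum(self.act_scale, scale)
        w = self.weight.to(x.dtype) * self.weight_scale  # dequantize for the observer pass
        return torch.nn.functional.linear(x, w, self.bias)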
@@ -229,14 +231,14 @@ def quantize_activations(
         del dynamic_quant_linear
     cleanup_memory()

-    # Calibration.
-    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating") as pbar:
+    # Pass through calibration data to measure activation scales
+    with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
         for row_idx in range(calibration_tokens.shape[0]):
             model(calibration_tokens[row_idx].reshape(1, -1))
             cleanup_memory()
             pbar.update(1)

-    # Replace dynamic quantizer with StaticLinear for export
+    # Replace the dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
         if (
             not isinstance(quantizer, FP8StaticLinearQuantizer)
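`replace_module` is used throughout but never shown in this diff. A common implementation, sketched here as an assumption about the helper, resolves the dotted module path and rebinds the final attribute on the parent:

import torch

def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
    # Hypothetical sketch: split "model.layers.0.mlp.gate_proj" into parent
    # path and child name, then rebind the child on the parent module.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)

Rebinding via setattr keeps the rest of the module tree intact, which is why the loops above can snapshot `named_modules()` once and swap layers as they go.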