@@ -172,12 +172,6 @@ def forward_loop(model):
             dist.ReduceOp.AVG,
             group=tp_group,
         )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=tp_group,
-        )
 
     dist.destroy_process_group()
 
@@ -191,6 +185,9 @@ def forward_loop(model):
 
     model = mtq.quantize(model, config, forward_loop)
 
+    # Sanity check
+    forward_loop(model)
+
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
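Note (illustrative, not part of the diff): the "# Sanity check" added above simply re-runs the calibration loop on the already-quantized model to confirm it still executes. A minimal standalone sketch of that pattern, with a toy model and data as placeholders and a CUDA device assumed as in the test:

    # Sketch only: quantize, then immediately re-run a forward pass as a sanity check.
    import torch
    import torch.nn as nn
    import modelopt.torch.quantization as mtq

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16)).cuda()
    calib_data = torch.randn(8, 16, device="cuda")  # placeholder calibration batch

    def forward_loop(m):
        m(calib_data)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

    # Sanity check: the quantized model should still run on the same data.
    forward_loop(model)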
@@ -226,105 +223,46 @@ def forward_loop(model):
 
 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
 def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
-    # Print rank information for debugging
-    world_rank = dist.get_rank()
-    world_size = dist.get_world_size()
-
-    print("\n=== RANK INFORMATION ===")
-    print(f"World Rank: {world_rank}, World Size: {world_size}")
-
-    # Get group information with actual ranks
-    def get_group_ranks(group):
-        if group is None:
-            return None
-        ranks = []
-        ranks = [
-            i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group)
-        ]
-        return ranks
-
-    if dp_group is not None:
-        dp_rank = dist.get_rank(group=dp_group)
-        dp_size = dist.get_world_size(group=dp_group)
-        print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}")
-
-    if tp_group is not None:
-        tp_rank = dist.get_rank(group=tp_group)
-        tp_size = dist.get_world_size(group=tp_group)
-        print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}")
-
-    print("=== END RANK INFO ===\n")
-
-    # Print a summary of all ranks
-    print("=== ALL RANKS SUMMARY ===")
-    print(f"Total GPUs: {world_size}")
-    print(f"Current rank: {world_rank}")
-    if dp_group is not None:
-        print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each")
-    if tp_group is not None:
-        print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each")
-    print("=== END SUMMARY ===\n")
-
-    calib_data = model.get_dummy_input().cuda()
-    # data should be same across each TP rank
-    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+    # Calib data should be same across each DP rank
+    dp_rank = dist.get_rank(group=dp_group)
+    calib_data = model.get_dummy_input(seed=dp_rank).cuda()
 
     def forward_loop(model):
         model(calib_data)
 
     model = mtq.quantize(model, config, forward_loop)
 
     def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
-        world_rank = dist.get_rank()
-        print(f"\n--- Rank {world_rank}: Reducing {attr} ---")
-        from megatron.core.parallel_state import (
-            _CONTEXT_PARALLEL_GLOBAL_RANKS,
-            _DATA_PARALLEL_GLOBAL_RANKS,
-            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP,
-            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS,
-        )
-
-        print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}")
-        print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}")
-        print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}")
-        print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}")
         quantizer_attr = getattr(quantizer, attr).clone()
-        print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr)
-        print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr))
 
         # Perform all-reduce operations
-        if tp_group is not None:
-            tp_rank = dist.get_rank(group=tp_group)
-            print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})")
-            dist.all_reduce(quantizer_attr, op=op, group=tp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
 
-        if dp_group is not None:
-            dp_rank = dist.get_rank(group=dp_group)
-            print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})")
-            dist.all_reduce(quantizer_attr, op=op, group=dp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
 
-        print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr)
-        print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr))
-        print(f"--- End Rank {world_rank} ---\n")
-
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+        assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)
 
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
         _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
 
-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
+    # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
+    # Channel-wise (INT8) only expects same amax across row parallel ranks
+    # Block-wise quantization does not expect same amax across row and column parallel ranks
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
+        if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc1.weight_quantizer:
+                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+        else:
+            _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
+
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
+        if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc2.weight_quantizer:
+                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+        else:
+            _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
 
     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
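Note (illustrative, not part of the diff): the rewritten `_reduce_quantizer_attr` above is a reduce-and-compare check. If calibration already synchronized a quantizer attribute across the tensor- and data-parallel groups, an extra all-reduce over those groups must leave it unchanged. A minimal sketch of that idiom, where `quantizer`, `tp_group`, and `dp_group` stand in for whatever the test provides:

    import torch
    import torch.distributed as dist

    def assert_attr_synced(quantizer, attr, op, tp_group, dp_group):
        # Clone so the reduction does not overwrite the quantizer's own value.
        reduced = getattr(quantizer, attr).clone()
        dist.all_reduce(reduced, op=op, group=tp_group)  # reduce across TP ranks
        dist.all_reduce(reduced, op=op, group=dp_group)  # then across DP ranks
        # If the attribute was already in sync, reducing it is a no-op.
        assert torch.allclose(reduced, getattr(quantizer, attr)), getattr(quantizer, attr)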
@@ -333,11 +271,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
             "act_scale",
             dist.ReduceOp.AVG,
         )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-        )
 
 
 def auto_quantize_helper(model):
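Note (illustrative, not part of the diff): the calibration-data change earlier in this diff seeds the dummy input by data-parallel rank. A sketch of that idea, assuming the usual Megatron layout where ranks in the same tensor-parallel group share a DP rank index; the function name, shapes, and generator-based sampling are placeholders, not the test's `get_dummy_input`:

    import torch
    import torch.distributed as dist

    def make_calib_data(dp_group, batch_size=4, hidden=32):
        # Same seed for all ranks that share a DP rank index (i.e. TP peers),
        # different data across DP ranks.
        dp_rank = dist.get_rank(group=dp_group)
        gen = torch.Generator().manual_seed(dp_rank)
        return torch.randn(batch_size, hidden, generator=gen).cuda()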