@@ -172,12 +172,6 @@ def forward_loop(model):
             dist.ReduceOp.AVG,
             group=tp_group,
         )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=tp_group,
-        )

     dist.destroy_process_group()

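For reference, the `_reduce_quantizer_attr` helper exercised throughout these hunks checks that a calibrated quantizer attribute is already synchronized across a process group: it clones the attribute, all-reduces the clone, and asserts the original value was unchanged by the reduction. A minimal sketch of that pattern, assuming an initialized `torch.distributed` backend; the function name `check_synced` is hypothetical:

```python
import torch
import torch.distributed as dist

def check_synced(quantizer, attr, op, group=None):
    # Clone so the quantizer's own value is untouched by the reduction.
    reduced = getattr(quantizer, attr).clone()
    # MAX is used for amax, AVG for AWQ-lite act_scale in the tests above.
    dist.all_reduce(reduced, op=op, group=group)
    # If calibration synced the attribute correctly, reducing is a no-op.
    assert torch.allclose(reduced, getattr(quantizer, attr))
```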
@@ -191,6 +185,9 @@ def forward_loop(model):

     model = mtq.quantize(model, config, forward_loop)

+    # Sanity check
+    forward_loop(model)
+
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)
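The `forward_loop(model)` call added here acts as a post-quantization smoke test: `mtq.quantize` runs the loop once for calibration, and invoking it again confirms the quantized model still executes a forward pass. A hedged sketch of the overall pattern; the config choice and `calib_data` are placeholders standing in for the test's own values:

```python
import modelopt.torch.quantization as mtq

def forward_loop(model):
    model(calib_data)  # feed representative data through the model

# quantize() calibrates quantizer amax/scales by running forward_loop.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# Sanity check: the quantized model should still run forward.
forward_loop(model)
```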
@@ -226,105 +223,46 @@ def forward_loop(model):

 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
 def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
-    # Print rank information for debugging
-    world_rank = dist.get_rank()
-    world_size = dist.get_world_size()
-
-    print("\n=== RANK INFORMATION ===")
-    print(f"World Rank: {world_rank}, World Size: {world_size}")
-
-    # Get group information with actual ranks
-    def get_group_ranks(group):
-        if group is None:
-            return None
-        ranks = []
-        ranks = [
-            i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group)
-        ]
-        return ranks
-
-    if dp_group is not None:
-        dp_rank = dist.get_rank(group=dp_group)
-        dp_size = dist.get_world_size(group=dp_group)
-        print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}")
-
-    if tp_group is not None:
-        tp_rank = dist.get_rank(group=tp_group)
-        tp_size = dist.get_world_size(group=tp_group)
-        print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}")
-
-    print("=== END RANK INFO ===\n")
-
-    # Print a summary of all ranks
-    print("=== ALL RANKS SUMMARY ===")
-    print(f"Total GPUs: {world_size}")
-    print(f"Current rank: {world_rank}")
-    if dp_group is not None:
-        print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each")
-    if tp_group is not None:
-        print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each")
-    print("=== END SUMMARY ===\n")
-
-    calib_data = model.get_dummy_input().cuda()
-    # data should be same across each TP rank
-    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+    # Calib data should match across TP ranks but differ across DP ranks
+    dp_rank = dist.get_rank(group=dp_group)
+    calib_data = model.get_dummy_input(seed=dp_rank).cuda()

     def forward_loop(model):
         model(calib_data)

     model = mtq.quantize(model, config, forward_loop)

     def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
-        world_rank = dist.get_rank()
-        print(f"\n--- Rank {world_rank}: Reducing {attr} ---")
-        from megatron.core.parallel_state import (
-            _CONTEXT_PARALLEL_GLOBAL_RANKS,
-            _DATA_PARALLEL_GLOBAL_RANKS,
-            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP,
-            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS,
-        )
-
-        print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}")
-        print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}")
-        print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}")
-        print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}")
         quantizer_attr = getattr(quantizer, attr).clone()
-        print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr)
-        print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr))

         # Perform all-reduce operations
-        if tp_group is not None:
-            tp_rank = dist.get_rank(group=tp_group)
-            print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})")
-            dist.all_reduce(quantizer_attr, op=op, group=tp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=tp_group)

-        if dp_group is not None:
-            dp_rank = dist.get_rank(group=dp_group)
-            print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})")
-            dist.all_reduce(quantizer_attr, op=op, group=dp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=dp_group)

-        print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr)
-        print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr))
-        print(f"--- End Rank {world_rank} ---\n")
-
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+        assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)

     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
         _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)

-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
+    # Per-tensor quantization (FP8/NVFP4) expects the same amax across both
+    # column- and row-parallel ranks; channel-wise (INT8) expects it only
+    # across row-parallel ranks; block-wise quantization expects neither
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
+        if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc1.weight_quantizer:
+                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+        else:
+            _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
+
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
+        if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc2.weight_quantizer:
+                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+        else:
+            _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)

     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:
@@ -333,11 +271,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
             "act_scale",
             dist.ReduceOp.AVG,
         )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-        )


 def auto_quantize_helper(model):
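The config gating above follows from sharding geometry: in Megatron-style MLPs, fc1 is column-parallel (output channels split across TP ranks) and fc2 is row-parallel (input channels split). A toy single-process sketch of why per-tensor amax is comparable under both shardings while per-channel amax is comparable only for the row-parallel layer; the shapes and chunking are illustrative assumptions, not the test's actual model:

```python
import torch

w = torch.randn(8, 8)                  # full weight of a Linear layer [out, in]
col_shards = torch.chunk(w, 2, dim=0)  # column-parallel: output channels split
row_shards = torch.chunk(w, 2, dim=1)  # row-parallel: input channels split

# Per-tensor (FP8/NVFP4): MAX over shard amaxes recovers the full-tensor amax
# under either sharding, so synced ranks should agree on one value.
assert max(s.abs().max() for s in col_shards) == w.abs().max()
assert max(s.abs().max() for s in row_shards) == w.abs().max()

# Channel-wise (INT8): row-parallel shards keep every output channel, so their
# per-channel amax vectors MAX-reduce to the full per-channel amax.
full = w.abs().amax(dim=1)
stacked = torch.stack([s.abs().amax(dim=1) for s in row_shards])
assert torch.equal(stacked.amax(dim=0), full)

# Column-parallel shards hold disjoint output channels: their per-channel
# vectors describe different channels, so no cross-rank agreement is expected.
```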