@@ -119,7 +119,7 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
119119    mto.restore_from_modelopt_state(model_ref, state_dict)
120120
121121
122- def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
122+ def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
123123     quantizer_attr = getattr(quantizer, attr).clone()
124124     print("quantizer.attr before reduce", getattr(quantizer, attr))
125125     dist.all_reduce(quantizer_attr, op=op, group=group)
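For reference, the check these helpers build toward follows one pattern: clone a quantizer attribute, all-reduce the clone with MAX over a process group, and assert the local value already equals the reduced one (i.e. calibration synchronized it across the group). A minimal sketch of that pattern, assuming an already-initialized default process group and an illustrative amax tensor rather than the real quantizer fixture:

import torch
import torch.distributed as dist

def assert_synced(amax: torch.Tensor, group=None, op=dist.ReduceOp.MAX):
    # Clone so the reduction does not overwrite the value under test.
    reduced = amax.clone()
    # If calibration already propagated the group-wide max, the reduce is a no-op.
    dist.all_reduce(reduced, op=op, group=group)
    assert torch.allclose(reduced, amax), "amax is not synchronized across the group"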
@@ -225,9 +225,46 @@ def forward_loop(model):
225225
226226
227227@patch ("modelopt.torch.quantization.model_calib.awq_lite" , side_effect = _debug_awq_lite )
228- def data_tensor_context_parallel_test_helper (
229- model , config , dp_group , tp_group , cp_group , mock_awq_lite
230- ):
228+ def data_tensor_context_parallel_test_helper (model , config , dp_group , tp_group , mock_awq_lite ):
229+ # Print rank information for debugging
230+ world_rank = dist .get_rank ()
231+ world_size = dist .get_world_size ()
232+
233+ print ("\n === RANK INFORMATION ===" )
234+ print (f"World Rank: { world_rank } , World Size: { world_size } " )
235+
236+ # Get group information with actual ranks
237+     def get_group_ranks(group):
238+         if group is None:
239+             return None
240+         # Map each rank of the group back to its global rank
241+         ranks = [
242+             dist.get_global_rank(group, r) for r in range(dist.get_world_size(group=group))
243+         ]
244+         return ranks
245+
246+     if dp_group is not None:
247+         dp_rank = dist.get_rank(group=dp_group)
248+         dp_size = dist.get_world_size(group=dp_group)
249+         print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}")
250+
251+     if tp_group is not None:
252+         tp_rank = dist.get_rank(group=tp_group)
253+         tp_size = dist.get_world_size(group=tp_group)
254+         print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}")
255+
256+     print("=== END RANK INFO ===\n")
257+
258+     # Print a summary of all ranks
259+     print("=== ALL RANKS SUMMARY ===")
260+     print(f"Total GPUs: {world_size}")
261+     print(f"Current rank: {world_rank}")
262+     if dp_group is not None:
263+         print(f"DP groups: {world_size // dp_size} groups of {dp_size} ranks each")
264+     if tp_group is not None:
265+         print(f"TP groups: {world_size // tp_size} groups of {tp_size} ranks each")
266+ print ("=== END SUMMARY ===\n " )
267+
231268     calib_data = model.get_dummy_input().cuda()
232269     # data should be same across each TP rank
233270     dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
@@ -238,14 +275,38 @@ def forward_loop(model):
238275 model = mtq .quantize (model , config , forward_loop )
239276
240277 def _reduce_quantizer_attr (quantizer , attr = str , op = dist .ReduceOp .MAX ):
278+         world_rank = dist.get_rank()
279+         print(f"\n--- Rank {world_rank}: Reducing {attr} ---")
280+         from megatron.core.parallel_state import (
281+             _CONTEXT_PARALLEL_GLOBAL_RANKS,
282+             _DATA_PARALLEL_GLOBAL_RANKS,
283+             _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP,
284+             _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS,
285+         )
286+
287+         print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}")
288+         print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}")
289+         print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}")
290+         print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}")
241291         quantizer_attr = getattr(quantizer, attr).clone()
242- print ("quantizer_attr before reduce" , quantizer_attr )
243- print ("quantizer.attr before reduce" , getattr (quantizer , attr ))
244- dist .all_reduce (quantizer_attr , op = op , group = dp_group )
245- dist .all_reduce (quantizer_attr , op = op , group = cp_group )
246- dist .all_reduce (quantizer_attr , op = op , group = tp_group )
247- print ("quantizer_attr after reduce" , quantizer_attr )
248- print ("quantizer.attr after reduce" , getattr (quantizer , attr ))
292+ print (f"Rank { world_rank } - quantizer_attr before reduce" , quantizer_attr )
293+ print (f"Rank { world_rank } - quantizer.attr before reduce" , getattr (quantizer , attr ))
294+
295+         # Perform all-reduce operations
296+         if tp_group is not None:
297+             tp_rank = dist.get_rank(group=tp_group)
298+             print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})")
299+             dist.all_reduce(quantizer_attr, op=op, group=tp_group)
300+
301+         if dp_group is not None:
302+             dp_rank = dist.get_rank(group=dp_group)
303+             print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})")
304+             dist.all_reduce(quantizer_attr, op=op, group=dp_group)
305+
306+         print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr)
307+         print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr))
308+         print(f"--- End Rank {world_rank} ---\n")
309+
249310         assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
250311
251312     # Input quantizer amax
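For readers running a helper like this outside the Megatron test harness: dp_group and tp_group are ordinary torch.distributed process groups. Below is a rough sketch of how a 4-GPU world could be partitioned into 2-way TP and 2-way DP groups with dist.new_group; the contiguous-TP / strided-DP rank layout is an assumption for illustration, while the actual tests take their groups from megatron.core.parallel_state:

import torch.distributed as dist

def build_tp_dp_groups(world_size: int = 4, tp_size: int = 2):
    rank = dist.get_rank()
    tp_group, dp_group = None, None
    # TP groups: contiguous ranks, e.g. [0, 1] and [2, 3].
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks=ranks)  # collective: every rank must call this
        if rank in ranks:
            tp_group = group
    # DP groups: strided ranks, e.g. [0, 2] and [1, 3].
    for offset in range(tp_size):
        ranks = list(range(offset, world_size, tp_size))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group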