@@ -119,7 +119,7 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
    mto.restore_from_modelopt_state(model_ref, state_dict)


- def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+ def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
    quantizer_attr = getattr(quantizer, attr).clone()
    print("quantizer.attr before reduce", getattr(quantizer, attr))
    dist.all_reduce(quantizer_attr, op=op, group=group)
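Note on the one-line change above: with `attr = str`, the builtin `str` type becomes the parameter's default value, so a call that omits `attr` would hand a type object to `getattr` and fail; `attr: str` is the intended type annotation and leaves the argument required. A minimal sketch of the difference (the helper names here are illustrative, not part of the test suite):

    def buggy(obj, attr=str):
        # attr defaults to the str *type*; buggy(obj) raises
        # "TypeError: attribute name must be string"
        return getattr(obj, attr)

    def fixed(obj, attr: str):
        # attr is only annotated as a string and must always be passed
        return getattr(obj, attr)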
@@ -225,9 +225,46 @@ def forward_loop(model):
@patch ("modelopt.torch.quantization.model_calib.awq_lite" , side_effect = _debug_awq_lite )
228
- def data_tensor_context_parallel_test_helper (
229
- model , config , dp_group , tp_group , cp_group , mock_awq_lite
230
- ):
228
+ def data_tensor_context_parallel_test_helper (model , config , dp_group , tp_group , mock_awq_lite ):
229
+ # Print rank information for debugging
230
+ world_rank = dist .get_rank ()
231
+ world_size = dist .get_world_size ()
232
+
233
+ print ("\n === RANK INFORMATION ===" )
234
+ print (f"World Rank: { world_rank } , World Size: { world_size } " )
235
+
236
+ # Get group information with actual ranks
+     def get_group_ranks(group):
+         if group is None:
+             return None
+         # torch.distributed reports the global ranks that make up a process group
+         return dist.get_process_group_ranks(group)
+
+     if dp_group is not None:
+         dp_rank = dist.get_rank(group=dp_group)
+         dp_size = dist.get_world_size(group=dp_group)
+         print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}")
+
+     if tp_group is not None:
+         tp_rank = dist.get_rank(group=tp_group)
+         tp_size = dist.get_world_size(group=tp_group)
+         print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}")
+
+     print("=== END RANK INFO ===\n")
+
+     # Print a summary of all ranks
+     print("=== ALL RANKS SUMMARY ===")
+     print(f"Total GPUs: {world_size}")
+     print(f"Current rank: {world_rank}")
+     if dp_group is not None:
+         print(f"DP groups: {world_size // dp_size} groups of {dp_size} ranks each")
+     if tp_group is not None:
+         print(f"TP groups: {world_size // tp_size} groups of {tp_size} ranks each")
+     print("=== END SUMMARY ===\n")
+
    calib_data = model.get_dummy_input().cuda()
    # data should be same across each TP rank
    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
@@ -238,14 +275,38 @@ def forward_loop(model):
    model = mtq.quantize(model, config, forward_loop)

    def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
+         world_rank = dist.get_rank()
+         print(f"\n--- Rank {world_rank}: Reducing {attr} ---")
+         from megatron.core.parallel_state import (
+             _CONTEXT_PARALLEL_GLOBAL_RANKS,
+             _DATA_PARALLEL_GLOBAL_RANKS,
+             _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP,
+             _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS,
+         )
+
+         print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}")
+         print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}")
+         print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}")
+         print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}")
        quantizer_attr = getattr(quantizer, attr).clone()
-         print("quantizer_attr before reduce", quantizer_attr)
-         print("quantizer.attr before reduce", getattr(quantizer, attr))
-         dist.all_reduce(quantizer_attr, op=op, group=dp_group)
-         dist.all_reduce(quantizer_attr, op=op, group=cp_group)
-         dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-         print("quantizer_attr after reduce", quantizer_attr)
-         print("quantizer.attr after reduce", getattr(quantizer, attr))
+         print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr)
+         print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr))
+
+         # Perform all-reduce operations
+         if tp_group is not None:
+             tp_rank = dist.get_rank(group=tp_group)
+             print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})")
+             dist.all_reduce(quantizer_attr, op=op, group=tp_group)
+
+         if dp_group is not None:
+             dp_rank = dist.get_rank(group=dp_group)
+             print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})")
+             dist.all_reduce(quantizer_attr, op=op, group=dp_group)
+
+         print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr)
+         print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr))
+         print(f"--- End Rank {world_rank} ---\n")
+
        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

    # Input quantizer amax
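For reference, the assertion pattern these helpers rely on reduces to the following sketch (assuming `torch.distributed` is already initialized; `assert_synced` is an illustrative name, not an API from the test suite): if every rank already holds the group-wide maximum, an all-reduce with `ReduceOp.MAX` leaves the value unchanged, so the `allclose` check passes only when the quantizer attribute is synchronized across the group.

    import torch
    import torch.distributed as dist

    def assert_synced(tensor: torch.Tensor, group=None, op=dist.ReduceOp.MAX):
        # Reduce a copy so the original stays untouched; a no-op reduction
        # means the local value was already the group-wide extremum.
        reduced = tensor.clone()
        dist.all_reduce(reduced, op=op, group=group)
        assert torch.allclose(reduced, tensor), "attribute not synchronized across ranks"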