
Commit fa8f4c8

add print
Signed-off-by: Jennifer Chen <[email protected]>
1 parent 9f0691f commit fa8f4c8

File tree

2 files changed (+74, -15 lines)


tests/_test_utils/torch_quantization/quantize_common.py

Lines changed: 72 additions & 11 deletions
@@ -119,7 +119,7 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N
     mto.restore_from_modelopt_state(model_ref, state_dict)


-def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None):
+def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None):
     quantizer_attr = getattr(quantizer, attr).clone()
     print("quantizer.attr before reduce", getattr(quantizer, attr))
     dist.all_reduce(quantizer_attr, op=op, group=group)
@@ -225,9 +225,46 @@ def forward_loop(model):


 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
-def data_tensor_context_parallel_test_helper(
-    model, config, dp_group, tp_group, cp_group, mock_awq_lite
-):
+def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
+    # Print rank information for debugging
+    world_rank = dist.get_rank()
+    world_size = dist.get_world_size()
+
+    print("\n=== RANK INFORMATION ===")
+    print(f"World Rank: {world_rank}, World Size: {world_size}")
+
+    # Get group information with actual ranks
+    def get_group_ranks(group):
+        if group is None:
+            return None
+        ranks = []
+        ranks = [
+            i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group)
+        ]
+        return ranks
+
+    if dp_group is not None:
+        dp_rank = dist.get_rank(group=dp_group)
+        dp_size = dist.get_world_size(group=dp_group)
+        print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}")
+
+    if tp_group is not None:
+        tp_rank = dist.get_rank(group=tp_group)
+        tp_size = dist.get_world_size(group=tp_group)
+        print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}")
+
+    print("=== END RANK INFO ===\n")
+
+    # Print a summary of all ranks
+    print("=== ALL RANKS SUMMARY ===")
+    print(f"Total GPUs: {world_size}")
+    print(f"Current rank: {world_rank}")
+    if dp_group is not None:
+        print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each")
+    if tp_group is not None:
+        print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each")
+    print("=== END SUMMARY ===\n")
+
     calib_data = model.get_dummy_input().cuda()
     # data should be same across each TP rank
     dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
@@ -238,14 +275,38 @@ def forward_loop(model):
     model = mtq.quantize(model, config, forward_loop)

     def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
+        world_rank = dist.get_rank()
+        print(f"\n--- Rank {world_rank}: Reducing {attr} ---")
+        from megatron.core.parallel_state import (
+            _CONTEXT_PARALLEL_GLOBAL_RANKS,
+            _DATA_PARALLEL_GLOBAL_RANKS,
+            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP,
+            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS,
+        )
+
+        print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}")
+        print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}")
+        print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}")
+        print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}")
         quantizer_attr = getattr(quantizer, attr).clone()
-        print("quantizer_attr before reduce", quantizer_attr)
-        print("quantizer.attr before reduce", getattr(quantizer, attr))
-        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
-        dist.all_reduce(quantizer_attr, op=op, group=cp_group)
-        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
-        print("quantizer_attr after reduce", quantizer_attr)
-        print("quantizer.attr after reduce", getattr(quantizer, attr))
+        print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr)
+        print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr))
+
+        # Perform all-reduce operations
+        if tp_group is not None:
+            tp_rank = dist.get_rank(group=tp_group)
+            print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})")
+            dist.all_reduce(quantizer_attr, op=op, group=tp_group)
+
+        if dp_group is not None:
+            dp_rank = dist.get_rank(group=dp_group)
+            print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})")
+            dist.all_reduce(quantizer_attr, op=op, group=dp_group)
+
+        print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr)
+        print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr))
+        print(f"--- End Rank {world_rank} ---\n")
+
         assert torch.allclose(quantizer_attr, getattr(quantizer, attr))

     # Input quantizer amax
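
Side note on the new debug helper in the hunk above: the filter in `get_group_ranks` compares `dist.get_rank(group=group)` with itself, which is always true, so the comprehension returns every rank in the world rather than the members of the given group (and the first `ranks = []` is immediately overwritten). Below is a minimal sketch of the intended lookup, not part of this commit, assuming a PyTorch version that provides `torch.distributed.get_process_group_ranks`:

import torch.distributed as dist

def get_group_ranks(group):
    # Sketch only: resolve a process group to the global ranks it contains,
    # e.g. [0, 2] for one TP pair in a 2x2 DP x TP layout. Requires
    # torch.distributed to be initialized before the group is created.
    if group is None:
        return None
    return dist.get_process_group_ranks(group)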

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 2 additions & 4 deletions
@@ -42,7 +42,6 @@
 import megatron.core
 from megatron.core.parallel_state import (
     destroy_model_parallel,
-    get_context_parallel_group,
     get_data_parallel_group,
     get_tensor_model_parallel_group,
 )
@@ -152,7 +151,7 @@ def _test_context_parallel_helper(config, rank, size):
     ) # modify seed so data is different across ranks
     model = MegatronModel(cp_size=size).cuda()

-    dp_cp_parallel_test_helper(model, config, get_context_parallel_group())
+    dp_cp_parallel_test_helper(model, config, get_data_parallel_group(with_context_parallel=True))


 @pytest.mark.parametrize(
@@ -181,9 +180,8 @@ def _test_data_tensor_context_parallel_helper(config, rank, size):
     data_tensor_context_parallel_test_helper(
         model,
         config,
-        get_data_parallel_group(),
+        get_data_parallel_group(with_context_parallel=True),
         get_tensor_model_parallel_group(),
-        get_context_parallel_group(),
     )
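
For readers who want to poke at the reduce-and-compare pattern the helper asserts (a MAX all-reduce should leave an already-synchronized amax unchanged) without a multi-GPU launch, here is a minimal single-process sketch on the gloo backend; the address, port, and tensor value are illustrative only:

import os

import torch
import torch.distributed as dist

# Single-rank gloo group so the example runs on CPU; with more ranks the
# same code checks that the attribute is identical across the group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

amax = torch.tensor([1.5])  # stand-in for getattr(quantizer, "amax")
reduced = amax.clone()
dist.all_reduce(reduced, op=dist.ReduceOp.MAX)  # default (world) group
assert torch.allclose(reduced, amax)

dist.destroy_process_group()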
