
Commit 22b8b73

fix tests
Signed-off-by: Jennifer Chen <[email protected]>
1 parent d1fac44 commit 22b8b73

6 files changed, +42 -112 lines changed

modelopt/torch/quantization/model_calib.py
Lines changed: 7 additions & 14 deletions

@@ -80,22 +80,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     if not distributed_sync:
         return
 
-    def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state):
-        """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
+    def sync_quantizer_amax_across_dp(quantizer, parallel_state):
+        """Synchronize the amax across all ranks in the data parallel group."""
         if isinstance(quantizer, SequentialQuantizer):
             for _q in quantizer:
-                sync_quantizer_amax_across_dp_cp(_q, parallel_state)
+                sync_quantizer_amax_across_dp(_q, parallel_state)
             return
         if getattr(quantizer, "_amax", None) is not None:
             quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group)
-            quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group)
         # TODO: create sync_bias_across_distributed_group
 
     for name, module in model.named_modules():
         if isinstance(module, QuantModule):
             for child in module.children():
                 if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
-                    sync_quantizer_amax_across_dp_cp(child, module.parallel_state)
+                    sync_quantizer_amax_across_dp(child, module.parallel_state)
     # TP sync:
     # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same

@@ -600,17 +599,12 @@ def forward(self, input, *args, **kwargs):
     # This will also perform distributed amax sync for input_quantizers
     max_calibrate(model, lambda model: None)
 
-    def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group):
-        # Sync across Data Parallel (DP)
+    def sync_act_scale_across_dp(module, data_parallel_group):
+        """Sync activation scale across Data Parallel (DP)."""
         if data_parallel_group.is_initialized():
             dist.all_reduce(
                 module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group
             )
-        # Sync across Context Parallel (CP)
-        if context_parallel_group.is_initialized():
-            dist.all_reduce(
-                module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group
-            )
 
     for name, module in model.named_modules():
         if (

@@ -627,10 +621,9 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr
         ):
            module.awq_lite.is_enabled = False
        else:
-            sync_act_scale_across_dp_cp(
+            sync_act_scale_across_dp(
                module,
                module.parallel_state.data_parallel_group,
-                module.parallel_state.context_parallel_group,
            )
 
     AWQLiteHelper.cache_mode = False
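
The context-parallel sync removed above is not lost: the data-parallel group that sync_quantizer_amax_across_dp receives is now created with with_context_parallel=True (see the megatron.py hunk below), so a single all-reduce over that group already covers the CP ranks. As a rough standalone sketch of what the amax sync boils down to (hypothetical helper name; only the torch.distributed calls shown are assumed to exist):

import torch
import torch.distributed as dist

def sync_amax_max(amax: torch.Tensor, group: dist.ProcessGroup | None = None) -> torch.Tensor:
    # Each rank contributes its local amax; the group-wide max is written back in place.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    return amax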

modelopt/torch/quantization/plugins/megatron.py
Lines changed: 4 additions & 1 deletion

@@ -15,6 +15,7 @@
 
 """Support quantization for megatron linear layers."""
 
+import logging
 import warnings
 from typing import Any
 

@@ -39,6 +40,8 @@
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
 
+logger = logging.getLogger(__name__)
+
 __all__ = []
 
 

@@ -222,11 +225,11 @@ def _setup(self):
         try:
             data_parallel_group = get_data_parallel_group(with_context_parallel=True)
         except AssertionError:
+            logger.warning("Context parallel group is not initialized, using data parallel group")
             data_parallel_group = get_data_parallel_group()
         self.parallel_state = ParallelState(
             data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
         )
         super()._setup()
 
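
To surface the new warning when the context-parallel group is not initialized, standard logging configuration is enough; the logger name below is inferred from the module path and is shown purely as an illustration:

import logging

# The module creates its logger via logging.getLogger(__name__), so the name mirrors the import path.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("modelopt.torch.quantization.plugins.megatron").setLevel(logging.WARNING)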

modelopt/torch/utils/distributed.py
Lines changed: 0 additions & 3 deletions

@@ -241,18 +241,15 @@ def __init__(
         self,
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
-        context_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
-        self.context_parallel_group = DistributedProcessGroup(context_parallel_group)
 
     def __repr__(self) -> str:
         return (
             f"data_parallel_group: {self.data_parallel_group}, "
             f"tensor_parallel_group: {self.tensor_parallel_group}, "
-            f"context_parallel_group: {self.context_parallel_group}"
         )
 
 
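
With the context-parallel field gone, ParallelState carries only the data- and tensor-parallel groups. A minimal usage sketch, assuming the import path matches the file path above and relying only on the defaults shown in the hunk:

from modelopt.torch.utils.distributed import ParallelState

# Defaults from the signature above: data_parallel_group=None, tensor_parallel_group=-1 (disabled).
state = ParallelState()
print(state)  # data_parallel_group: ..., tensor_parallel_group: ...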

tests/_test_utils/torch_dist/plugins/megatron_common.py
Lines changed: 5 additions & 1 deletion

@@ -127,7 +127,11 @@ def forward(self, x):
             x = x[0]
         return x
 
-    def get_dummy_input(self) -> torch.Tensor:
+    def get_dummy_input(self, seed: int | None = None) -> torch.Tensor:
+        if seed is not None:
+            gen = torch.Generator()
+            gen.manual_seed(seed)
+            return torch.randn(1, 4, 32, generator=gen)
         return torch.randn(1, 4, 32)
 
 
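
The seeded branch makes the dummy batch reproducible for a given seed, which the updated quantize_common.py helper relies on by passing seed=dp_rank. A standalone sketch of the same pattern (hypothetical function name, plain torch only):

import torch

def make_calib_batch(seed: int | None = None) -> torch.Tensor:
    # A dedicated Generator keeps the draw independent of the global RNG state.
    if seed is not None:
        gen = torch.Generator()
        gen.manual_seed(seed)
        return torch.randn(1, 4, 32, generator=gen)
    return torch.randn(1, 4, 32)

assert torch.equal(make_calib_batch(seed=3), make_calib_batch(seed=3))  # same seed, same batch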

tests/_test_utils/torch_quantization/quantize_common.py
Lines changed: 25 additions & 92 deletions

@@ -172,12 +172,6 @@ def forward_loop(model):
            dist.ReduceOp.AVG,
            group=tp_group,
        )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-            group=tp_group,
-        )
 
     dist.destroy_process_group()
 

@@ -191,6 +185,9 @@ def forward_loop(model):
 
     model = mtq.quantize(model, config, forward_loop)
 
+    # Sanity check
+    forward_loop(model)
+
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group)

@@ -226,105 +223,46 @@ def forward_loop(model):
 
 @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite)
 def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite):
-    # Print rank information for debugging
-    world_rank = dist.get_rank()
-    world_size = dist.get_world_size()
-
-    print("\n=== RANK INFORMATION ===")
-    print(f"World Rank: {world_rank}, World Size: {world_size}")
-
-    # Get group information with actual ranks
-    def get_group_ranks(group):
-        if group is None:
-            return None
-        ranks = []
-        ranks = [
-            i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group)
-        ]
-        return ranks
-
-    if dp_group is not None:
-        dp_rank = dist.get_rank(group=dp_group)
-        dp_size = dist.get_world_size(group=dp_group)
-        print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}")
-
-    if tp_group is not None:
-        tp_rank = dist.get_rank(group=tp_group)
-        tp_size = dist.get_world_size(group=tp_group)
-        print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}")
-
-    print("=== END RANK INFO ===\n")
-
-    # Print a summary of all ranks
-    print("=== ALL RANKS SUMMARY ===")
-    print(f"Total GPUs: {world_size}")
-    print(f"Current rank: {world_rank}")
-    if dp_group is not None:
-        print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each")
-    if tp_group is not None:
-        print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each")
-    print("=== END SUMMARY ===\n")
-
-    calib_data = model.get_dummy_input().cuda()
-    # data should be same across each TP rank
-    dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group)
+    # Calib data should be same across each DP rank
+    dp_rank = dist.get_rank(group=dp_group)
+    calib_data = model.get_dummy_input(seed=dp_rank).cuda()
 
     def forward_loop(model):
         model(calib_data)
 
     model = mtq.quantize(model, config, forward_loop)
 
     def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
-        world_rank = dist.get_rank()
-        print(f"\n--- Rank {world_rank}: Reducing {attr} ---")
-        from megatron.core.parallel_state import (
-            _CONTEXT_PARALLEL_GLOBAL_RANKS,
-            _DATA_PARALLEL_GLOBAL_RANKS,
-            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP,
-            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS,
-        )
-
-        print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}")
-        print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}")
-        print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}")
-        print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}")
         quantizer_attr = getattr(quantizer, attr).clone()
-        print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr)
-        print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr))
 
         # Perform all-reduce operations
-        if tp_group is not None:
-            tp_rank = dist.get_rank(group=tp_group)
-            print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})")
-            dist.all_reduce(quantizer_attr, op=op, group=tp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=tp_group)
 
-        if dp_group is not None:
-            dp_rank = dist.get_rank(group=dp_group)
-            print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})")
-            dist.all_reduce(quantizer_attr, op=op, group=dp_group)
+        dist.all_reduce(quantizer_attr, op=op, group=dp_group)
 
-        print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr)
-        print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr))
-        print(f"--- End Rank {world_rank} ---\n")
-
-        assert torch.allclose(quantizer_attr, getattr(quantizer, attr))
+        assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr)
 
     # Input quantizer amax
     if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]:
         _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX)
         _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX)
 
-    if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc1.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-    else:
-        _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
-
-    if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
-        for quantizer in model.fc2.weight_quantizer:
-            _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
-    else:
-        _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
+    # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks
+    # Channel-wise (INT8) only expects same amax across row parallel ranks
+    # Block-wise quantization does not expect same amax across row and column parallel ranks
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]:
+        if isinstance(model.fc1.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc1.weight_quantizer:
+                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+        else:
+            _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX)
+
+    if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]:
+        if isinstance(model.fc2.weight_quantizer, SequentialQuantizer):
+            for quantizer in model.fc2.weight_quantizer:
+                _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX)
+        else:
+            _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX)
 
     # Check act scale
     if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]:

@@ -333,11 +271,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX):
            "act_scale",
            dist.ReduceOp.AVG,
        )
-        _reduce_quantizer_attr(
-            model.fc2.awq_lite,
-            "act_scale",
-            dist.ReduceOp.AVG,
-        )
 
 
 def auto_quantize_helper(model):
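
The slimmed-down _reduce_quantizer_attr reduces to one idiom: all-reduce a copy of the attribute over the TP and DP groups and assert that the reduction changed nothing, which holds only if calibration already synchronized the value. A self-contained sketch of that check (hypothetical helper name; only the torch.distributed calls shown are assumed):

import torch
import torch.distributed as dist

def assert_already_synced(value: torch.Tensor, groups, op=dist.ReduceOp.MAX) -> None:
    # If the value was synced during calibration, reducing it again must be a no-op.
    reduced = value.clone()
    for group in groups:
        dist.all_reduce(reduced, op=op, group=group)
    assert torch.allclose(reduced, value), value

# e.g. assert_already_synced(model.fc1.input_quantizer.amax, [tp_group, dp_group])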

tests/gpu/torch/quantization/plugins/test_megatron.py
Lines changed: 1 addition & 1 deletion

@@ -199,7 +199,7 @@ def _test_data_tensor_context_parallel_helper(config, rank, size):
 )
 def test_data_tensor_context_parallel(need_8_gpus, config):
     spawn_multiprocess_job(
-        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
+        size=4, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
     )
 
 
