
Commit 169677c

code cleanup and bug fixes
Signed-off-by: Kinjal Patel <[email protected]>
1 parent 1ea4ed1 commit 169677c

File tree: 5 files changed (+105, -105 lines)


modelopt/torch/quantization/model_calib.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
         return
 
     def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
-        """Synchronize the amax across all ranks in the data parallel and context parallel groups."""
+        """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
            if isinstance(quantizer, SequentialQuantizer):
                for _q in quantizer:
                    sync_quantizer_amax_across_dp_ep(_q, parallel_state)
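For readers following the calibration flow: synchronizing amax across a process group is, at its core, a max all-reduce over each quantizer's amax buffer so every rank ends up with the same quantization scale. A minimal standalone sketch of that idea (the helper name and group handling here are assumptions, not the exact modelopt implementation):

    import torch.distributed as dist

    def _sync_amax_over_group(quantizer, group):
        # Illustration only: skip quantizers that never collected an amax,
        # and groups that were never initialized.
        if getattr(quantizer, "_amax", None) is None or group is None:
            return
        amax = quantizer.amax.detach().clone()
        # Every rank contributes its local maximum; all ranks keep the global max.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
        quantizer.amax = amax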

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 82 additions & 91 deletions
@@ -52,37 +52,34 @@
 
 def sync_amax_across_sequential_mlp(model: torch.nn.Module):
     """Sync amax across experts in a SequentialMLP."""
-    amax_dict = {
-        "linear_fc1.input_quantizer": {},
-        "linear_fc1.weight_quantizer": {},
-        "linear_fc2.input_quantizer": {},
-        "linear_fc2.weight_quantizer": {},
-    }
-    # gather amax values from SequentialMLP experts
-    for name, module in model.named_modules():
+    amax_dict = {}
+
+    def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
         if (
-            not isinstance(module, TensorQuantizer)
-            or not hasattr(module, "_amax")
-            or "local_experts" not in name
+            isinstance(module, TensorQuantizer)
+            and hasattr(module, "_amax")
+            and ".local_experts." in name
         ):
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax)
+            expert_name, local_expert_name = name.split(".local_experts.")
+            # extract quantizer name by removing local_expert number from the name
+            local_expert_name = ".".join(local_expert_name.split(".")[1:])
+            return expert_name, local_expert_name
+        return None, None
+
+    # gather amax values from SequentialMLP experts
+    for name, module in model.named_modules():
+        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and local_expert_name:
+            amax_dict[local_expert_name] = amax_dict.get(local_expert_name, {})
+            amax_dict[local_expert_name][expert_name] = max(
+                amax_dict[local_expert_name].get(expert_name, 0), module.amax
+            )
 
     # sync amax values across experts in SequentialMLP
     for name, module in model.named_modules():
-        if (
-            not isinstance(module, TensorQuantizer)
-            or not hasattr(module, "_amax")
-            or "local_experts" not in name
-        ):
-            continue
-        expert_name, local_expert_name = name.split("local_experts")
-        for key in amax_dict:
-            if key in local_expert_name:
-                module.amax = amax_dict[key][expert_name]
+        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and local_expert_name:
+            module.amax = amax_dict[local_expert_name][expert_name]
 
 
 CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)
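To see what the refactored helper keys on, here is a small standalone illustration of the name splitting (the module path is hypothetical): the expert index is dropped, so every expert of one SequentialMLP maps to the same quantizer key.

    # Hypothetical quantizer name inside a SequentialMLP
    name = "decoder.layers.0.mlp.experts.local_experts.3.linear_fc1.weight_quantizer"

    expert_name, local_expert_name = name.split(".local_experts.")
    # expert_name       -> "decoder.layers.0.mlp.experts"
    local_expert_name = ".".join(local_expert_name.split(".")[1:])
    # local_expert_name -> "linear_fc1.weight_quantizer"  (expert index "3" removed)

Keying amax_dict by whatever quantizer name appears after the expert index means the sync no longer depends on a hard-coded list of four quantizer names.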
@@ -523,6 +520,11 @@ def forward(self, input, *args, **kwargs):
 # Register the public te.pytorch.GroupedLinear class
 @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
 class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear):
+    _functionals_to_replace = [
+        (te_grouped_linear._GroupedLinear, "forward"),
+        (te_grouped_linear._GroupedLinear, "apply"),
+    ]
+
     def _setup(self):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
         # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
@@ -531,46 +533,17 @@ def _setup(self):
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
         super()._setup()
-        # Revert the weight to None after setup.
-        self.weight = None
-
-    @property
-    def functionals_to_replace(self):
-        original_forward = te_grouped_linear._GroupedLinear.forward
-
-        def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args):
-            num_gemms = len(m_splits)
-            weights_and_biases = args[-2 * num_gemms :]
-            weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
-            quantized_inputs = self.input_quantizer(inp)
-            quantized_weights = [self.weight_quantizer(weight) for weight in weights]
-
-            output = original_forward(
-                ctx,
-                quantized_inputs,
-                m_splits,
-                *args[: -2 * num_gemms],
-                *quantized_weights,
-                *biases,
-            )
-            return self.output_quantizer(output)
-
-        return [
-            (
-                te_grouped_linear._GroupedLinear,
-                "forward",
-                te_grouped_quantized_linear_fn,
-            ),
-        ]
+        # Remove self.weight after setup.
+        delattr(self, "weight")
 
     def modelopt_post_restore(self, prefix: str = ""):
         # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
         # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
         # self.weight0 to self.weight to run the quantizer states initialization.
         self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Revert the weight to None after post_restore.
-        self.weight = None
+        # Remove self.weight after post_restore.
+        delattr(self, "weight")
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
@@ -585,10 +558,34 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
 
     def _process_quantizer_amax(self, k, v, quantizer_state_dict):
-        if v.ndim == 4:
-            quantizer_state_dict[k] = v.squeeze(1).squeeze(-1)
-        else:
-            quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1)
+        assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization"
+        quantizer_state_dict[k] = v.view(-1)
+
+    @staticmethod
+    def te_grouped_quantized_linear_fn(package, func_name, self, *args):
+        idx = 1 if func_name == "_forward" else 0
+        inp = args[idx]
+        num_gemms = len(args[idx + 1])
+        weights_and_biases = args[-2 * num_gemms :]
+        weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:]
+        quantized_inputs = self.input_quantizer(inp)
+        quantized_weights = [self.weight_quantizer(weight) for weight in weights]
+
+        output = getattr(package, func_name)(
+            *(
+                args[0],
+                quantized_inputs,
+            )
+            if func_name == "_forward"
+            else (quantized_inputs,),
+            *args[idx + 1 : -2 * num_gemms],
+            *quantized_weights,
+            *biases,
+        )
+        return self.output_quantizer(output)
+
+    # Override the quantized linear function
+    _quantized_linear_fn = te_grouped_quantized_linear_fn
 
 
 @QuantModuleRegistry.register(
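The new static hook assumes a call convention in which the grouped-linear functional receives an optional ctx, then the input and m_splits, with the per-GEMM weights and biases packed as the last 2 * num_gemms positional arguments. A standalone sketch of just that slicing logic (function and argument names are illustrative, not part of the diff):

    def split_grouped_linear_args(args, has_ctx):
        """Illustrative slicing of the assumed argument layout:
        [ctx?], inp, m_splits, <other args...>, w_0..w_{n-1}, b_0..b_{n-1}."""
        idx = 1 if has_ctx else 0              # position of the input tensor
        inp, m_splits = args[idx], args[idx + 1]
        num_gemms = len(m_splits)
        weights_and_biases = args[-2 * num_gemms:]
        weights = weights_and_biases[:num_gemms]
        biases = weights_and_biases[num_gemms:]
        middle = args[idx + 1:-2 * num_gemms]  # m_splits plus the remaining args
        return inp, m_splits, middle, weights, biases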
@@ -614,42 +611,36 @@ class _MegatronTEGroupedRowParallelLinear(
 class _MegatronTEGroupedMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
-            data_parallel_group = None
-            try:
-                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-            except AssertionError:
-                logger.warning(
-                    "Context parallel group is not initialized, using data parallel group"
-                )
-                data_parallel_group = get_data_parallel_group()
-
-            try:
-                expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
-            except AssertionError:
-                expert_tensor_parallel_group = None
             self.parallel_state = ParallelState(
-                data_parallel_group,
-                tensor_parallel_group=expert_tensor_parallel_group,
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
+                    check_initialized=False
+                ),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
+                    check_initialized=False
+                ),
             )
+        # initialize parallel state for submodules linear_fc1 and linear_fc2
+        self.linear_fc1.parallel_state = self.parallel_state
+        self.linear_fc2.parallel_state = self.parallel_state
 
 
 # Register the public megatron_moe.SequentialMLP class
 @QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
 class _MegatronSequentialMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
-            try:
-                data_parallel_group = mcore_parallel.get_expert_data_parallel_group()
-            except AssertionError:
-                data_parallel_group = None
-
-            try:
-                expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group()
-            except AssertionError:
-                expert_tensor_parallel_group = None
             self.parallel_state = ParallelState(
-                data_parallel_group,
-                tensor_parallel_group=expert_tensor_parallel_group,
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
+                    check_initialized=False
+                ),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
+                    check_initialized=False
+                ),
             )
+
+        # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
+        for expert in self.local_experts:
+            expert.linear_fc1.parallel_state = self.parallel_state
+            expert.linear_fc2.parallel_state = self.parallel_state
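The simplification above rests on one assumption worth stating: with check_initialized=False the Megatron-Core group getters return None instead of asserting when a group was never created, which is why the earlier try/except AssertionError blocks go away. A toy sketch of that assumed contract (not Megatron-Core code):

    # Toy illustration of the assumed check_initialized=False contract.
    def get_group(group_registry, key, check_initialized=True):
        group = group_registry.get(key)
        if check_initialized:
            assert group is not None, f"parallel group '{key}' is not initialized"
        return group

    groups = {"expert_tensor_parallel": None}   # e.g. an EP run without expert TP
    tp_group = get_group(groups, "expert_tensor_parallel", check_initialized=False)
    print(tp_group)                              # None, and no AssertionError raised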

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 9 additions & 0 deletions
@@ -588,6 +588,15 @@ def compare_amax_sync_across_expert_parallel(model):
 
             if quantizer_type not in expert_quantizers:
                 expert_quantizers[quantizer_type] = {}
+            if (
+                quantizer_type in expert_quantizers
+                and rank_idx in expert_quantizers[quantizer_type]
+            ):
+                # compare expert value across expert for sequential MoE
+                assert expert_quantizers[quantizer_type][rank_idx] == amax_val, (
+                    f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: "
+                    f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}"
+                )
             expert_quantizers[quantizer_type][rank_idx] = amax_val
 
         # Check synchronization - fail fast on first inconsistency

tests/gpu/torch/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -40,6 +40,12 @@ def need_8_gpus():
         pytest.skip("Need at least 8 GPUs to run this test")
 
 
+@pytest.fixture
+def need_4_gpus():
+    if torch.cuda.device_count() < 4:
+        pytest.skip("Need at least 4 GPUs to run this test")
+
+
 @pytest.fixture(scope="module")
 def set_torch_dtype(request):
     orig_dtype = torch.get_default_dtype()
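Requesting the fixture by name is all a test needs to do to get the skip behavior; a minimal usage sketch (the test name is hypothetical, the pattern matches how test_megatron.py consumes it below):

    import torch

    # Hypothetical test: listing `need_4_gpus` as a parameter makes pytest run the
    # fixture first, which skips the test on machines with fewer than 4 GPUs.
    def test_expert_parallel_roundtrip(need_4_gpus):
        assert torch.cuda.device_count() >= 4  # only reached when the skip did not fire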

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 7 additions & 13 deletions
@@ -35,7 +35,6 @@
     auto_quantize_helper,
     data_tensor_context_parallel_test_helper,
     dp_cp_parallel_test_helper,
-    tensor_parallel_test_helper,
 )
 
 skip_if_no_megatron()
@@ -621,12 +620,10 @@ def test_fp8_real_quantize():
         mtq.NVFP4_DEFAULT_CFG,
     ],
 )
-@pytest.mark.parametrize("moe_grouped_gemm", [False, True])
-def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm):
+@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
+def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm):
     size = torch.cuda.device_count()
     # TODO: Add support for compress=True for TEGroupedMLP
-    if size < 4:
-        pytest.skip("Requires at least 4 GPUs for expert parallel test")
     moe_config = {
         "tp_size": 2,
         "ep_size": 2,
@@ -720,13 +717,9 @@ def forward_fn(model):
     )
 
 
-def test_te_grouped_vs_sequential_quantize():
+def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     """Test that TEGrouped and sequential MoE models produce similar quantized models."""
-
     size = torch.cuda.device_count()
-    if size < 4:
-        pytest.skip("Requires at least 4 GPUs for expert parallel test")
-
     spawn_multiprocess_job(
         size=size,
         job=partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, 2),
@@ -763,7 +756,6 @@ def forward_fn(model):
 
     # quantize the model
     model = mtq.quantize(model, config, forward_fn)
-
     # Check initial sync status
     initial_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
     assert initial_sync, (
@@ -790,8 +782,10 @@ def forward_fn(model):
         "Consistent amax across expert parallel ranks, "
         "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled"
     )
-    # Re-enable parallel groups and test synchronization
-    mtq.model_calib.max_calibrate(model, forward_fn)
+    # Re-calibrate the model and test synchronization
+    mtq.mode.wrapped_calib_func(
+        model, mtq.config.MaxCalibConfig(), forward_fn, mtq.model_calib.max_calibrate
+    )
 
     final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
     assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"
