Commit 22bfe0e (1 parent: 4dc16b0)

code cleanup

Signed-off-by: Kinjal Patel <[email protected]>

3 files changed (+26, -92 lines)

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 7 additions & 65 deletions

@@ -37,7 +37,7 @@
 )
 from modelopt.torch.utils.distributed import ParallelState

-from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer
+from ..nn import QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
@@ -501,7 +501,6 @@ def _setup(self):
         self.parallel_state = ParallelState(
             data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
-            mcore_parallel.get_context_parallel_group(),
             mcore_parallel.get_expert_model_parallel_group(),
             expert_tensor_parallel_group,
         )
@@ -544,70 +543,13 @@ def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args):
         ]

     def modelopt_post_restore(self, prefix: str = ""):
-        """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks.
-
-        ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their
-        shape before saving. However this is not enough for MCore/distributed frameworks since the tensor parallelism
-        could change between saving and restoring. If the tensor parallelism changes, the shape of the quantizer
-        states also changes. So we need to re-calculate the quantizer states.
-        """
-        from modelopt.torch.quantization.model_calib import max_calibrate
-
-        def _check_unsupported_states(quantizer: TensorQuantizer):
-            for k in quantizer.state_dict():
-                if k not in ["_amax", "_pre_quant_scale"]:
-                    warnings.warn(
-                        f"Restore of {k} for {prefix} is not supported. The restore of this layer might be "
-                        f"incorrect. Please implement a custom restore for {k}."
-                    )
-
-        def _has_state(quantizer, name):
-            # Handling for SequentialQuantizer
-            quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer
-            return hasattr(quantizer, name)
-
-        # weights for TEGroupedLinear are stored in weight0, weight1, etc.
-        if self.weight0 is None:
-            return
-        for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]:
-            _check_unsupported_states(
-                quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0]
-            )
-        if _has_state(self.weight_quantizer, "_amax"):
-            self.weight_quantizer.reset_amax()
-            for i in range(self.num_gemms):
-                weight = getattr(self, f"weight{i}")
-                assert weight is not None, "weight is None"
-
-                max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False)
-        if _has_state(self.input_quantizer, "_pre_quant_scale"):
-            if hasattr(self.input_quantizer, "_pre_quant_scale"):
-                delattr(self.input_quantizer, "_pre_quant_scale")
-            pqs = torch.zeros(
-                (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype
-            )
-            self.input_quantizer.register_buffer("_pre_quant_scale", pqs)
-
-        if _has_state(self.input_quantizer, "_amax"):
-            self.input_quantizer.reset_amax()
-            dummy_input = torch.ones(
-                (1, 1, self.weight0.shape[1]),
-                device=self.weight0.device,
-                dtype=self.original_weight_dtype,
-            )
-            max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False)
-        if _has_state(self.output_quantizer, "_amax"):
-            self.output_quantizer.reset_amax()
-            dummy_input = torch.ones(
-                (1, 1, self.weight0.shape[0]),
-                device=self.weight0.device,
-                dtype=self.original_weight_dtype,
-            )
-            max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False)
-        # If there are any other states, lets move them to the correct device
-
-        self.weight = None
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
+        # Revert the weight to None after post_restore to avoid the weight being None during forward pass.
+        self.weight = None

     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
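For quick reference, this is how the slimmed-down override reads once the hunk above is applied. It is a minimal sketch assembled from the + lines; the enclosing grouped-linear quant module and its attributes (weight0, prefix handling, etc.) come from megatron.py and are not new names introduced here.

    def modelopt_post_restore(self, prefix: str = ""):
        # GroupedMLP keeps its weights in weight0, weight1, ...; the parent class's
        # post_restore uses self.weight to infer shape, dtype, etc. when it
        # re-initializes the quantizer states.
        self.weight = self.weight0
        super().modelopt_post_restore(prefix=prefix)
        # Drop the temporary alias again; the grouped weights remain in weight0..weightN.
        self.weight = None

The per-quantizer recalibration that used to live here (reset_amax plus max_calibrate over weight0..weightN and dummy activations) is now left to the parent implementation of modelopt_post_restore.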

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 0 additions & 1 deletion

@@ -429,7 +429,6 @@ def initialize_for_megatron(
         context_parallel_size=context_parallel_size,
         expert_tensor_parallel_size=expert_tensor_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
-        order="tp-ep-dp-pp",
     )
     model_parallel_cuda_manual_seed(seed)
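Dropping order="tp-ep-dp-pp" means the test helper now relies on Megatron-Core's default process-group ordering rather than forcing a custom one. A minimal sketch of the equivalent call after this change, under the assumption that initialize_for_megatron forwards these keyword arguments to megatron.core.parallel_state.initialize_model_parallel (the concrete sizes below are illustrative only):

    from megatron.core import parallel_state
    from megatron.core.tensor_parallel import model_parallel_cuda_manual_seed

    # No explicit order= override anymore, so Megatron-Core's default ordering applies.
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        context_parallel_size=1,
        expert_tensor_parallel_size=2,
        expert_model_parallel_size=2,
    )
    model_parallel_cuda_manual_seed(1234)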

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 19 additions & 26 deletions

@@ -549,7 +549,7 @@ def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config):
     )


-def _test_grouped_vs_non_grouped_amax_helper(tp_size, ep_size, etp_size, rank, size):
+def _test_grouped_vs_non_grouped_quantize_helper(tp_size, ep_size, etp_size, rank, size):
     """Test that grouped and non-grouped MoE models produce similar amax values."""
     initialize_for_megatron(
         tensor_model_parallel_size=tp_size,
@@ -615,8 +615,8 @@ def forward_fn(model):
     assert output_comparison_after, "Outputs are not close after quantization"


-def test_grouped_vs_non_grouped_amax():
-    """Test that grouped and non-grouped MoE models produce similar amax values."""
+def test_grouped_vs_non_grouped_quantize():
+    """Test that grouped and non-grouped MoE models produce similar quantized models."""
     import time

     size = torch.cuda.device_count()
@@ -627,14 +627,22 @@ def test_grouped_vs_non_grouped_amax():
     time.sleep(0.1)

     spawn_multiprocess_job(
-        size=size, job=partial(_test_grouped_vs_non_grouped_amax_helper, 1, 2, 2), backend="nccl"
+        size=size,
+        job=partial(_test_grouped_vs_non_grouped_quantize_helper, 1, 2, 2),
+        backend="nccl",
     )


-def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm):
-    """
-    Test that demonstrates the requirement for expert parallel sync in model_calib.py
-    """
+def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, rank, size):
+    """Test expert parallel synchronization with different configurations."""
+    initialize_for_megatron(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        expert_model_parallel_size=ep_size,
+        expert_tensor_parallel_size=etp_size,
+        seed=SEED,
+    )
+
     # Create model with expert parallelism
     model = _gpt_model_provider(
         tp_size=1,
@@ -664,7 +672,7 @@ def forward_fn(model):
     )

     # Create inconsistent amax values
-    rank = torch.distributed.get_rank()
+    cur_rank = torch.distributed.get_rank()
     for name, module in model.named_modules():
         if isinstance(module, mtq.nn.TensorQuantizer):
             # Check if this is an expert quantizer
@@ -675,7 +683,7 @@ def forward_fn(model):

             if is_expert_quantizer and hasattr(module, "_amax"):
                 # Create rank-specific amax values to simulate missing sync
-                rank_offset = rank * 0.1
+                rank_offset = cur_rank * 0.1
                 module.amax = module.amax + rank_offset

     # Determine expert parallel type
@@ -703,21 +711,6 @@ def forward_fn(model):
     )


-def _test_expert_parallel_sync_helper(ep_size, etp_size, moe_grouped_gemm, rank, size):
-    """Test expert parallel synchronization with different configurations."""
-    initialize_for_megatron(
-        tensor_model_parallel_size=1,
-        pipeline_model_parallel_size=1,
-        context_parallel_size=1,
-        expert_model_parallel_size=ep_size,
-        expert_tensor_parallel_size=etp_size,
-        seed=42 + rank,
-    )
-
-    # Run the actual test
-    _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm)
-
-
 @pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)])
 @pytest.mark.parametrize("moe_grouped_gemm", [True, False])
 def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm):
@@ -734,6 +727,6 @@ def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm):

     spawn_multiprocess_job(
         size=total_size,
-        job=partial(_test_expert_parallel_sync_helper, ep_size, etp_size, moe_grouped_gemm),
+        job=partial(_test_expert_model_parallel_amax_sync, ep_size, etp_size, moe_grouped_gemm),
         backend="nccl",
     )
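With the standalone _test_expert_parallel_sync_helper removed, _test_expert_model_parallel_amax_sync now accepts rank and size and calls initialize_for_megatron itself, so the test hands it directly to the process spawner. A short sketch of the dispatch (names as in test_expert_parallel_sync above), assuming spawn_multiprocess_job supplies rank and size as the trailing positional arguments, which is what the helper signatures in this diff imply:

    from functools import partial

    # ep_size, etp_size and moe_grouped_gemm are bound up front; each spawned worker
    # then receives (rank, size), matching the helper's new signature.
    spawn_multiprocess_job(
        size=total_size,
        job=partial(_test_expert_model_parallel_amax_sync, ep_size, etp_size, moe_grouped_gemm),
        backend="nccl",
    )

Note the seed also changed: every rank now initializes with the shared SEED constant instead of 42 + rank.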
