From 9cac53c0c6d387de6f8bad7a98b455feae4e79b0 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 24 Sep 2025 00:27:11 +0000 Subject: [PATCH 01/28] sync amax in context parallel and awq act scale Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 3 +- .../torch/quantization/plugins/megatron.py | 2 + modelopt/torch/utils/distributed.py | 2 + .../torch_quantization/quantize_common.py | 53 ++++++++++++ tests/gpu/torch/conftest.py | 6 ++ .../quantization/plugins/test_megatron.py | 80 +++++++++++++++++++ 6 files changed, 145 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f987efcd6..bb49fdf4a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -84,10 +84,11 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state): """Synchronize the amax across all ranks in the data parallel group.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp(_q, parallel_state) + sync_quantizer_amax_across_dp_cp(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 85784d2fe..2876c66b8 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -25,6 +25,7 @@ import torch from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.parallel_state import get_data_parallel_group from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -230,6 +231,7 @@ def _setup(self): self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_context_parallel_group(), ) super()._setup() diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index f11a736db..1807f4f34 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,10 +241,12 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + context_parallel_group: torch.distributed.ProcessGroup | int | None = -1, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) + self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: return ( diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 8647aaa00..5a8ba55b4 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -209,6 +209,59 @@ def forward_loop(model): model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, 
groups=[dp_group, tp_group] ) +def data_parallel_test_helper(model, config, dp_group): + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + +def context_parallel_test_helper(model, config, cp_group): + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + +def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): + calib_data = model.get_dummy_input().cuda() + # data should be same across each TP rank + dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) def auto_quantize_helper(model): model, search_state = mtq.auto_quantize( diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index f32065bce..b6a1aa287 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -33,6 +33,12 @@ def need_2_gpus(): if torch.cuda.device_count() < 2: pytest.skip("Need at least 2 GPUs to run this test") +@pytest.fixture +def need_8_gpus(): + if torch.cuda.device_count() < 8: + pytest.skip("Need at least 8 GPUs to run this test") + + @pytest.fixture def need_8_gpus(): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index b63462ef3..5484cde50 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -39,6 +39,7 @@ from megatron.core.parallel_state import ( destroy_model_parallel, get_data_parallel_group, + get_context_parallel_group, get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -226,6 +227,85 @@ def test_data_tensor_context_parallel(need_8_gpus, config): backend="nccl", ) +# 2. Data Parallel Test +def _test_data_parallel_helper(config, rank, size): + # TODO does this model automatically get copied to both DP ranks? 
+ initialize_for_megatron(seed=SEED) + model = MegatronModel().cuda() + + data_parallel_test_helper( + model, config, get_data_parallel_group() + ) + + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_parallel(need_2_gpus, config): + spawn_multiprocess_job( + size=2, job=partial(_test_data_parallel_helper, config), backend="nccl" + ) + +# 3. Context Parallel Test +def _test_context_parallel_helper(config, rank, size): + initialize_for_megatron(context_parallel_size=size, seed=SEED) + model = MegatronModel(cp_size=size).cuda() + + context_parallel_test_helper( + model, config, get_context_parallel_group() + ) + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_context_parallel(need_2_gpus, config): + spawn_multiprocess_job( + size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" + ) + +# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) +def _test_data_tensor_context_parallel_helper(config, rank, size): + initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) + model = MegatronModel(tp_size=2, cp_size=2).cuda() + + data_tensor_context_parallel_test_helper( + model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group() + ) + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_tensor_context_parallel(need_8_gpus, config): + spawn_multiprocess_job( + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + ) def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" From be5e8388eafce51b8045ce8b733f1863deac9ba4 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 25 Sep 2025 18:54:28 +0000 Subject: [PATCH 02/28] lint Signed-off-by: Jennifer Chen --- .../torch/quantization/plugins/megatron.py | 1 + .../torch_quantization/quantize_common.py | 4 +++ tests/gpu/torch/conftest.py | 1 + .../quantization/plugins/test_megatron.py | 29 +++++++++++-------- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 2876c66b8..a48840832 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -26,6 +26,7 @@ from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_data_parallel_group +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 5a8ba55b4..145e48e3b 100644 --- 
a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -209,6 +209,7 @@ def forward_loop(model): model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, groups=[dp_group, tp_group] ) + def data_parallel_test_helper(model, config, dp_group): calib_data = model.get_dummy_input().cuda() @@ -225,6 +226,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def context_parallel_test_helper(model, config, cp_group): calib_data = model.get_dummy_input().cuda() @@ -241,6 +243,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank @@ -263,6 +266,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def auto_quantize_helper(model): model, search_state = mtq.auto_quantize( model, diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index b6a1aa287..7a8223d9c 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -33,6 +33,7 @@ def need_2_gpus(): if torch.cuda.device_count() < 2: pytest.skip("Need at least 2 GPUs to run this test") + @pytest.fixture def need_8_gpus(): if torch.cuda.device_count() < 8: diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 5484cde50..61044f8d6 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -32,14 +32,15 @@ from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, data_tensor_context_parallel_test_helper, + tensor_parallel_test_helper, ) skip_if_no_megatron() from megatron.core.parallel_state import ( destroy_model_parallel, - get_data_parallel_group, get_context_parallel_group, + get_data_parallel_group, get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -227,15 +228,14 @@ def test_data_tensor_context_parallel(need_8_gpus, config): backend="nccl", ) + # 2. Data Parallel Test def _test_data_parallel_helper(config, rank, size): # TODO does this model automatically get copied to both DP ranks? initialize_for_megatron(seed=SEED) model = MegatronModel().cuda() - data_parallel_test_helper( - model, config, get_data_parallel_group() - ) + data_parallel_test_helper(model, config, get_data_parallel_group()) @pytest.mark.parametrize( @@ -251,18 +251,16 @@ def _test_data_parallel_helper(config, rank, size): ], ) def test_data_parallel(need_2_gpus, config): - spawn_multiprocess_job( - size=2, job=partial(_test_data_parallel_helper, config), backend="nccl" - ) + spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl") + # 3. 
Context Parallel Test def _test_context_parallel_helper(config, rank, size): initialize_for_megatron(context_parallel_size=size, seed=SEED) model = MegatronModel(cp_size=size).cuda() - context_parallel_test_helper( - model, config, get_context_parallel_group() - ) + context_parallel_test_helper(model, config, get_context_parallel_group()) + @pytest.mark.parametrize( "config", @@ -281,15 +279,21 @@ def test_context_parallel(need_2_gpus, config): size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" ) + # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) def _test_data_tensor_context_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) model = MegatronModel(tp_size=2, cp_size=2).cuda() data_tensor_context_parallel_test_helper( - model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group() + model, + config, + get_data_parallel_group(), + get_tensor_model_parallel_group(), + get_context_parallel_group(), ) + @pytest.mark.parametrize( "config", [ @@ -304,9 +308,10 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) + def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" From 4a2a8d7ec5ec8b117f6eb10b62093d39132b300d Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 25 Sep 2025 21:46:30 +0000 Subject: [PATCH 03/28] test weight quantizer too Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 64 ++++++++++++++----- 1 file changed, 47 insertions(+), 17 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 145e48e3b..2a6a7ea8e 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -218,13 +218,22 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + + # Weight quantizer amax + fc1_amax = model.fc1.weight_quantizer.amax.clone() dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - - fc2_amax = model.fc2.input_quantizer.amax.clone() + assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) + fc2_amax = model.fc2.weight_quantizer.amax.clone() dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) def context_parallel_test_helper(model, config, cp_group): @@ -235,13 +244,22 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = 
model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + + # Weight quantizer amax + fc1_weight_amax = model.fc1.weight_quantizer.amax.clone() + dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax) + fc2_weight_amax = model.fc2.weight_quantizer.amax.clone() + dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -254,17 +272,29 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + + fc1_amax = model.fc1.weight_quantizer.amax.clone() dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - - fc2_amax = model.fc2.input_quantizer.amax.clone() + assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) + fc2_amax = model.fc2.weight_quantizer.amax.clone() dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) def auto_quantize_helper(model): From cacee61301dfef159eb155f1ede2acc792056edd Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 26 Sep 2025 01:07:13 +0000 Subject: [PATCH 04/28] fix test Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 97 ++++++++----------- .../quantization/plugins/test_megatron.py | 5 +- 2 files changed, 41 insertions(+), 61 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py 
b/tests/_test_utils/torch_quantization/quantize_common.py index 2a6a7ea8e..1c9d7aae0 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -210,7 +210,7 @@ def forward_loop(model): ) -def data_parallel_test_helper(model, config, dp_group): +def dp_cp_parallel_test_helper(model, config, group): calib_data = model.get_dummy_input().cuda() def forward_loop(model): @@ -218,48 +218,27 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - # Input quantizer amax - if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - - # Weight quantizer amax - fc1_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) - fc2_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) - - -def context_parallel_test_helper(model, config, cp_group): - calib_data = model.get_dummy_input().cuda() - - def forward_loop(model): - model(calib_data) - - model = mtq.quantize(model, config, forward_loop) + def reduce_amax(quantizer): + amax = quantizer.amax.clone() + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group) + assert torch.allclose(amax, quantizer.amax) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + reduce_amax(model.fc1.input_quantizer) + reduce_amax(model.fc2.input_quantizer) # Weight quantizer amax - fc1_weight_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax) - fc2_weight_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax) + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc1.weight_quantizer) + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc2.weight_quantizer) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -272,29 +251,29 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) + def reduce_amax(quantizer): + amax = quantizer.amax.clone() + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(amax, 
op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(amax, quantizer.amax) + # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - - fc1_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) - fc2_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) + reduce_amax(model.fc1.input_quantizer) + reduce_amax(model.fc2.input_quantizer) + + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc1.weight_quantizer) + + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc2.weight_quantizer) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 61044f8d6..ada356d9c 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -32,6 +32,7 @@ from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, data_tensor_context_parallel_test_helper, + dp_cp_parallel_test_helper, tensor_parallel_test_helper, ) @@ -235,7 +236,7 @@ def _test_data_parallel_helper(config, rank, size): initialize_for_megatron(seed=SEED) model = MegatronModel().cuda() - data_parallel_test_helper(model, config, get_data_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) @pytest.mark.parametrize( @@ -259,7 +260,7 @@ def _test_context_parallel_helper(config, rank, size): initialize_for_megatron(context_parallel_size=size, seed=SEED) model = MegatronModel(cp_size=size).cuda() - context_parallel_test_helper(model, config, get_context_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) @pytest.mark.parametrize( From 41cc9bd7330e1e2efe918c1fa4f76866f74747f9 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 29 Sep 2025 21:38:48 +0000 Subject: [PATCH 05/28] awq test Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 8 +++-- .../quantization/plugins/test_megatron.py | 7 ++-- .../torch/quantization/test_model_calib.py | 33 +++++++++++++++++++ 3 files changed, 43 insertions(+), 5 deletions(-) create mode 100644 
tests/gpu/torch/quantization/test_model_calib.py diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 1c9d7aae0..478ce9d5f 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -253,9 +253,13 @@ def forward_loop(model): def reduce_amax(quantizer): amax = quantizer.amax.clone() - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + print("amax before reduce", amax) + print("quantizer.amax before reduce", quantizer.amax) dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) + print("amax after reduce", amax) + print("quantizer.amax after reduce", quantizer.amax) assert torch.allclose(amax, quantizer.amax) # Input quantizer amax diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index ada356d9c..8880ed69d 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -232,8 +232,7 @@ def test_data_tensor_context_parallel(need_8_gpus, config): # 2. Data Parallel Test def _test_data_parallel_helper(config, rank, size): - # TODO does this model automatically get copied to both DP ranks? - initialize_for_megatron(seed=SEED) + initialize_for_megatron(seed=SEED + rank) # modify seed so data is different across ranks model = MegatronModel().cuda() dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) @@ -257,7 +256,9 @@ def test_data_parallel(need_2_gpus, config): # 3. 
Context Parallel Test def _test_context_parallel_helper(config, rank, size): - initialize_for_megatron(context_parallel_size=size, seed=SEED) + initialize_for_megatron( + context_parallel_size=size, seed=SEED + rank + ) # modify seed so data is different across ranks model = MegatronModel(cp_size=size).cuda() dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) diff --git a/tests/gpu/torch/quantization/test_model_calib.py b/tests/gpu/torch/quantization/test_model_calib.py new file mode 100644 index 000000000..2179b8532 --- /dev/null +++ b/tests/gpu/torch/quantization/test_model_calib.py @@ -0,0 +1,33 @@ +import torch +import torch.distributed as dist +from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job +from _test_utils.torch_dist.plugins.megatron_common import MegatronModel, initialize_for_megatron +from megatron.core.parallel_state import get_data_parallel_group + +from modelopt.torch.quantization.model_calib import awq_lite + + +def _test_awq_lite_act_scale_sync_helper(rank, size): + initialize_for_megatron(seed=1234 + rank) + model = MegatronModel().cuda() + + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = awq_lite(model, forward_loop) + # Sanity check + forward_loop(model) + + act_scale = model.fc1.weight_quantizer.awq_lite.act_scale.clone() + dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) + assert torch.allclose(act_scale, model.fc1.weight_quantizer.awq_lite.act_scale) + + act_scale = model.fc2.weight_quantizer.awq_lite.act_scale.clone() + dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) + assert torch.allclose(act_scale, model.fc2.weight_quantizer.awq_lite.act_scale) + + +def test_awq_lite_act_scale_sync(need_2_gpus): + spawn_multiprocess_job(size=2, job=_test_awq_lite_act_scale_sync_helper, backend="nccl") From 4a706eff8a4c411333a768f44debcf6974ce50db Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 29 Sep 2025 23:17:00 +0000 Subject: [PATCH 06/28] move awq test inside megatron tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 + .../torch_quantization/quantize_common.py | 83 +++++++++++++------ .../quantization/plugins/test_megatron.py | 2 +- .../torch/quantization/test_model_calib.py | 33 -------- 4 files changed, 58 insertions(+), 61 deletions(-) delete mode 100644 tests/gpu/torch/quantization/test_model_calib.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index bb49fdf4a..d2072e748 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -582,6 +582,7 @@ def forward(self, input, *args, **kwargs): return out_actual for name, module in model.named_modules(): + print(name, module, module.weight_quantizer.is_enabled) if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model): module.awq_lite = AWQLiteHelper(module, name) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 478ce9d5f..0d14ce91a 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -218,27 +218,37 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - def reduce_amax(quantizer): - amax = quantizer.amax.clone() - dist.all_reduce(amax, op=dist.ReduceOp.MAX, 
group=group) - assert torch.allclose(amax, quantizer.amax) - # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - reduce_amax(model.fc1.input_quantizer) - reduce_amax(model.fc2.input_quantizer) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) # Weight quantizer amax if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) else: - reduce_amax(model.fc1.weight_quantizer) + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) else: - reduce_amax(model.fc2.weight_quantizer) + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) + + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + # Check act scale + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -251,33 +261,52 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - def reduce_amax(quantizer): - amax = quantizer.amax.clone() - print("amax before reduce", amax) - print("quantizer.amax before reduce", quantizer.amax) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) - print("amax after reduce", amax) - print("quantizer.amax after reduce", quantizer.amax) - assert torch.allclose(amax, quantizer.amax) + def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): + quantizer_attr = getattr(quantizer, attr).clone() + print("quantizer_attr before reduce", quantizer_attr) + print("quantizer.attr before reduce", getattr(quantizer, attr)) + dist.all_reduce(quantizer_attr, op=op, group=dp_group) + dist.all_reduce(quantizer_attr, op=op, group=cp_group) + dist.all_reduce(quantizer_attr, op=op, group=tp_group) + print("quantizer_attr after reduce", quantizer_attr) + print("quantizer.attr after reduce", getattr(quantizer, attr)) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - reduce_amax(model.fc1.input_quantizer) - reduce_amax(model.fc2.input_quantizer) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) else: - reduce_amax(model.fc1.weight_quantizer) + _reduce_quantizer_attr( + model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group + ) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): 
for quantizer in model.fc2.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) else: - reduce_amax(model.fc2.weight_quantizer) + _reduce_quantizer_attr( + model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group + ) + + # Check act scale + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 8880ed69d..6c039c025 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -284,7 +284,7 @@ def test_context_parallel(need_2_gpus, config): # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) def _test_data_tensor_context_parallel_helper(config, rank, size): - initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) + initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank) model = MegatronModel(tp_size=2, cp_size=2).cuda() data_tensor_context_parallel_test_helper( diff --git a/tests/gpu/torch/quantization/test_model_calib.py b/tests/gpu/torch/quantization/test_model_calib.py deleted file mode 100644 index 2179b8532..000000000 --- a/tests/gpu/torch/quantization/test_model_calib.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import torch.distributed as dist -from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job -from _test_utils.torch_dist.plugins.megatron_common import MegatronModel, initialize_for_megatron -from megatron.core.parallel_state import get_data_parallel_group - -from modelopt.torch.quantization.model_calib import awq_lite - - -def _test_awq_lite_act_scale_sync_helper(rank, size): - initialize_for_megatron(seed=1234 + rank) - model = MegatronModel().cuda() - - calib_data = model.get_dummy_input().cuda() - - def forward_loop(model): - model(calib_data) - - model = awq_lite(model, forward_loop) - # Sanity check - forward_loop(model) - - act_scale = model.fc1.weight_quantizer.awq_lite.act_scale.clone() - dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) - assert torch.allclose(act_scale, model.fc1.weight_quantizer.awq_lite.act_scale) - - act_scale = model.fc2.weight_quantizer.awq_lite.act_scale.clone() - dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) - assert torch.allclose(act_scale, model.fc2.weight_quantizer.awq_lite.act_scale) - - -def test_awq_lite_act_scale_sync(need_2_gpus): - spawn_multiprocess_job(size=2, job=_test_awq_lite_act_scale_sync_helper, backend="nccl") From 7b2c9698d484b96ab63c625abb7cbc120aa4c5db Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 00:20:17 +0000 Subject: [PATCH 07/28] fix amax tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 - .../torch_quantization/quantize_common.py | 16 ++++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d2072e748..bb49fdf4a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -582,7 
+582,6 @@ def forward(self, input, *args, **kwargs): return out_actual for name, module in model.named_modules(): - print(name, module, module.weight_quantizer.is_enabled) if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model): module.awq_lite = AWQLiteHelper(module, name) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 0d14ce91a..6427bf923 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -210,7 +210,8 @@ def forward_loop(model): ) -def dp_cp_parallel_test_helper(model, config, group): +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): calib_data = model.get_dummy_input().cuda() def forward_loop(model): @@ -238,20 +239,23 @@ def forward_loop(model): if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Check act scale _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, + model.fc2.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) -def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def data_tensor_context_parallel_test_helper( + model, config, dp_group, tp_group, cp_group, mock_awq_lite +): calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -296,13 +300,13 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, + model.fc2.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, From e6dc5e58a7d39af021d34596a7cc326a1dcb9aff Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 00:21:38 +0000 Subject: [PATCH 08/28] fix awq lite param Signed-off-by: Jennifer Chen --- tests/_test_utils/torch_quantization/quantize_common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 6427bf923..70133832e 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -239,13 +239,13 @@ def forward_loop(model): if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Check act scale _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite, + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, @@ -300,13 +300,13 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite, + 
model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, From f17320ea43b8b8655e050f8a97d6a3aa2c9707ed Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 19:04:19 +0000 Subject: [PATCH 09/28] fix test Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 70133832e..1536683db 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -278,24 +278,20 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) else: - _reduce_quantizer_attr( - model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group - ) + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) else: - _reduce_quantizer_attr( - model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group - ) + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: @@ -303,13 +299,11 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, - group=tp_group, ) _reduce_quantizer_attr( model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, - group=tp_group, ) From a1fdf180fbf89d699d714b06724bbdabb6afd80a Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 1 Oct 2025 01:07:14 +0000 Subject: [PATCH 10/28] add print Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 81 ++++++++++++++++--- .../quantization/plugins/test_megatron.py | 6 +- 2 files changed, 73 insertions(+), 14 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 1536683db..b48db355e 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -253,9 +253,46 @@ def forward_loop(model): @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) -def data_tensor_context_parallel_test_helper( - model, config, dp_group, tp_group, cp_group, mock_awq_lite -): +def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): + # Print rank 
information for debugging + world_rank = dist.get_rank() + world_size = dist.get_world_size() + + print("\n=== RANK INFORMATION ===") + print(f"World Rank: {world_rank}, World Size: {world_size}") + + # Get group information with actual ranks + def get_group_ranks(group): + if group is None: + return None + ranks = [] + ranks = [ + i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group) + ] + return ranks + + if dp_group is not None: + dp_rank = dist.get_rank(group=dp_group) + dp_size = dist.get_world_size(group=dp_group) + print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}") + + if tp_group is not None: + tp_rank = dist.get_rank(group=tp_group) + tp_size = dist.get_world_size(group=tp_group) + print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}") + + print("=== END RANK INFO ===\n") + + # Print a summary of all ranks + print("=== ALL RANKS SUMMARY ===") + print(f"Total GPUs: {world_size}") + print(f"Current rank: {world_rank}") + if dp_group is not None: + print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each") + if tp_group is not None: + print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each") + print("=== END SUMMARY ===\n") + calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -266,14 +303,38 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): + world_rank = dist.get_rank() + print(f"\n--- Rank {world_rank}: Reducing {attr} ---") + from megatron.core.parallel_state import ( + _CONTEXT_PARALLEL_GLOBAL_RANKS, + _DATA_PARALLEL_GLOBAL_RANKS, + _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP, + _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS, + ) + + print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}") + print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}") + print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}") + print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}") quantizer_attr = getattr(quantizer, attr).clone() - print("quantizer_attr before reduce", quantizer_attr) - print("quantizer.attr before reduce", getattr(quantizer, attr)) - dist.all_reduce(quantizer_attr, op=op, group=dp_group) - dist.all_reduce(quantizer_attr, op=op, group=cp_group) - dist.all_reduce(quantizer_attr, op=op, group=tp_group) - print("quantizer_attr after reduce", quantizer_attr) - print("quantizer.attr after reduce", getattr(quantizer, attr)) + print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr) + print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr)) + + # Perform all-reduce operations + if tp_group is not None: + tp_rank = dist.get_rank(group=tp_group) + print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})") + dist.all_reduce(quantizer_attr, op=op, group=tp_group) + + if dp_group is not None: + dp_rank = dist.get_rank(group=dp_group) + print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})") + dist.all_reduce(quantizer_attr, op=op, group=dp_group) + + print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr) + print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr)) + print(f"--- End Rank {world_rank} ---\n") + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) # Input quantizer amax diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py 
b/tests/gpu/torch/quantization/plugins/test_megatron.py index 6c039c025..033e2f41a 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -40,7 +40,6 @@ from megatron.core.parallel_state import ( destroy_model_parallel, - get_context_parallel_group, get_data_parallel_group, get_tensor_model_parallel_group, ) @@ -261,7 +260,7 @@ def _test_context_parallel_helper(config, rank, size): ) # modify seed so data is different across ranks model = MegatronModel(cp_size=size).cuda() - dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_data_parallel_group(with_context_parallel=True)) @pytest.mark.parametrize( @@ -290,9 +289,8 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): data_tensor_context_parallel_test_helper( model, config, - get_data_parallel_group(), + get_data_parallel_group(with_context_parallel=True), get_tensor_model_parallel_group(), - get_context_parallel_group(), ) From cd3115968485da0487a7b9d0bfa9e7957600db3a Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 1 Oct 2025 20:58:09 +0000 Subject: [PATCH 11/28] docstring Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index bb49fdf4a..c8192a69d 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -118,6 +118,7 @@ def sync_quantizer_amax_across_tp( # Syncing amax across TP for sequential quantizer if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: + "Syncing amax across TP for sequential quantizer" sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) From 5a67acfb4447974fd9d384b3516746b46ecc755e Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 2 Oct 2025 00:09:31 +0000 Subject: [PATCH 12/28] fix tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 3 +- .../torch/quantization/plugins/megatron.py | 1 - modelopt/torch/utils/distributed.py | 2 - .../torch_quantization/quantize_common.py | 109 ++++-------------- .../quantization/plugins/test_megatron.py | 2 +- 5 files changed, 26 insertions(+), 91 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index c8192a69d..5d2a24c1f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -84,11 +84,10 @@ def sync_quantizer_amax_across_dp(quantizer, parallel_state): """Synchronize the amax across all ranks in the data parallel group.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp_cp(_q, parallel_state) + sync_quantizer_amax_across_dp(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) - quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index a48840832..77709baba 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -232,7 +232,6 @@ def _setup(self): 
self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), - mcore_parallel.get_context_parallel_group(), ) super()._setup() diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 1807f4f34..f11a736db 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,12 +241,10 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, - context_parallel_group: torch.distributed.ProcessGroup | int | None = -1, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) - self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: return ( diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index b48db355e..94e0ccc3c 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -219,6 +219,9 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) + # Sanity check + forward_loop(model) + # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) @@ -254,48 +257,9 @@ def forward_loop(model): @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): - # Print rank information for debugging - world_rank = dist.get_rank() - world_size = dist.get_world_size() - - print("\n=== RANK INFORMATION ===") - print(f"World Rank: {world_rank}, World Size: {world_size}") - - # Get group information with actual ranks - def get_group_ranks(group): - if group is None: - return None - ranks = [] - ranks = [ - i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group) - ] - return ranks - - if dp_group is not None: - dp_rank = dist.get_rank(group=dp_group) - dp_size = dist.get_world_size(group=dp_group) - print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}") - - if tp_group is not None: - tp_rank = dist.get_rank(group=tp_group) - tp_size = dist.get_world_size(group=tp_group) - print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}") - - print("=== END RANK INFO ===\n") - - # Print a summary of all ranks - print("=== ALL RANKS SUMMARY ===") - print(f"Total GPUs: {world_size}") - print(f"Current rank: {world_rank}") - if dp_group is not None: - print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each") - if tp_group is not None: - print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each") - print("=== END SUMMARY ===\n") - - calib_data = model.get_dummy_input().cuda() - # data should be same across each TP rank - dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) + # Calib data should be same across each DP rank + dp_rank = dist.get_rank(group=dp_group) + calib_data = model.get_dummy_input(seed=dp_rank).cuda() def forward_loop(model): model(calib_data) @@ -303,56 +267,36 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): - world_rank = 
dist.get_rank() - print(f"\n--- Rank {world_rank}: Reducing {attr} ---") - from megatron.core.parallel_state import ( - _CONTEXT_PARALLEL_GLOBAL_RANKS, - _DATA_PARALLEL_GLOBAL_RANKS, - _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP, - _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS, - ) - - print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}") - print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}") - print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}") - print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}") quantizer_attr = getattr(quantizer, attr).clone() - print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr) - print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr)) # Perform all-reduce operations - if tp_group is not None: - tp_rank = dist.get_rank(group=tp_group) - print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})") - dist.all_reduce(quantizer_attr, op=op, group=tp_group) + dist.all_reduce(quantizer_attr, op=op, group=tp_group) - if dp_group is not None: - dp_rank = dist.get_rank(group=dp_group) - print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})") - dist.all_reduce(quantizer_attr, op=op, group=dp_group) + dist.all_reduce(quantizer_attr, op=op, group=dp_group) - print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr) - print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr)) - print(f"--- End Rank {world_rank} ---\n") - - assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) - if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) + # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks + # Channel-wise (INT8) only expects same amax across row parallel ranks + # Block-wise quantization does not expect same amax across row and column parallel ranks + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) - if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: @@ -361,11 +305,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, 
op=dist.ReduceOp.MAX): "act_scale", dist.ReduceOp.AVG, ) - _reduce_quantizer_attr( - model.fc2.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - ) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 033e2f41a..5462d3d71 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -308,7 +308,7 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=4, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) From 9d7dff127f29705fc125ea0426039a14a894e03b Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 2 Oct 2025 21:17:08 +0000 Subject: [PATCH 13/28] fix multiprocess size Signed-off-by: Jennifer Chen --- tests/gpu/torch/quantization/plugins/test_megatron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 5462d3d71..033e2f41a 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -308,7 +308,7 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=4, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) From 3bf16e6bd06aa90ded10498819749beb30269635 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Tue, 7 Oct 2025 02:24:14 +0000 Subject: [PATCH 14/28] Added quantization support for TEGroupedMoE for megatron-lm Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/model_calib.py | 13 +- .../torch/quantization/plugins/megatron.py | 172 +++++++++- modelopt/torch/utils/distributed.py | 14 +- .../torch_dist/plugins/megatron_common.py | 226 ++++++++++++- tests/gpu/torch/conftest.py | 1 - .../quantization/plugins/test_megatron.py | 302 +++++++++++++++++- 6 files changed, 710 insertions(+), 18 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5d2a24c1f..f50e7ae29 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -80,21 +80,26 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel group.""" + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel and context parallel groups.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp(_q, parallel_state) + sync_quantizer_amax_across_dp_ep(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) + if parallel_state.expert_tensor_parallel_group is not None: + quantizer.sync_amax_across_distributed_group( + 
parallel_state.expert_tensor_parallel_group + ) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp(child, module.parallel_state) + sync_quantizer_amax_across_dp_ep(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 77709baba..eb948a48d 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,8 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch +import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear +from megatron.core.extensions import transformer_engine as megatron_te from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_data_parallel_group @@ -37,8 +39,8 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModuleRegistry, TensorQuantizer -from ..nn.modules.quant_linear import RealQuantLinear +from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer +from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear @@ -229,9 +231,17 @@ def _setup(self): except AssertionError: logger.warning("Context parallel group is not initialized, using data parallel group") data_parallel_group = get_data_parallel_group() + + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + except AssertionError: + expert_tensor_parallel_group = None + self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_expert_model_parallel_group(), + expert_tensor_parallel_group, ) super()._setup() @@ -474,3 +484,161 @@ class _RealQuantMegatronRowParallelLinear( def forward(self, input, *args, **kwargs): return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) + + +# Register the public te.pytorch.GroupedLinear class +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) +class _QuantTEGroupedLinear(_MegatronParallelLinear): + def _setup(self): + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + data_parallel_group = get_data_parallel_group() + + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + except AssertionError: + expert_tensor_parallel_group = None + self.parallel_state = ParallelState( + data_parallel_group, + mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_context_parallel_group(), + mcore_parallel.get_expert_model_parallel_group(), + expert_tensor_parallel_group, + ) + self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) + self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) + self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) + 
self.output_quantizer.disable() + + # Memorize the original weight.dtype for modelopt_post_restore given that + # the dtype can change later. + self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype + + @property + def functionals_to_replace(self): + original_forward = te_grouped_linear._GroupedLinear.forward + + def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): + num_gemms = len(m_splits) + weights_and_biases = args[-2 * num_gemms :] + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] + quantized_inputs = self.input_quantizer(inp) + quantized_weights = [self.weight_quantizer(weight) for weight in weights] + + output = original_forward( + ctx, + quantized_inputs, + m_splits, + *args[: -2 * num_gemms], + *quantized_weights, + *biases, + ) + return self.output_quantizer(output) + + return [ + ( + te_grouped_linear._GroupedLinear, + "forward", + te_grouped_quantized_linear_fn, + ), + ] + + def modelopt_post_restore(self, prefix: str = ""): + """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. + + ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their + shape before saving. However this is not enough for MCore/distributed frameworks since the tensor parallelism + could change between saving and restoring. If the tensor parallelism changes, the shape of the quantizer + states also changes. So we need to re-calculate the quantizer states. + """ + from modelopt.torch.quantization.model_calib import max_calibrate + + def _check_unsupported_states(quantizer: TensorQuantizer): + for k in quantizer.state_dict(): + if k not in ["_amax", "_pre_quant_scale"]: + warnings.warn( + f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " + f"incorrect. Please implement a custom restore for {k}." + ) + + def _has_state(quantizer, name): + # Handling for SequentialQuantizer + quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer + return hasattr(quantizer, name) + + # weights for TEGroupedLinear are stored in weight0, weight1, etc. 
+ if self.weight0 is None: + return + for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]: + _check_unsupported_states( + quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] + ) + if _has_state(self.weight_quantizer, "_amax"): + self.weight_quantizer.reset_amax() + for i in range(self.num_gemms): + weight = getattr(self, f"weight{i}") + assert weight is not None, "weight is None" + + max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False) + if _has_state(self.input_quantizer, "_pre_quant_scale"): + if hasattr(self.input_quantizer, "_pre_quant_scale"): + delattr(self.input_quantizer, "_pre_quant_scale") + pqs = torch.zeros( + (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype + ) + self.input_quantizer.register_buffer("_pre_quant_scale", pqs) + + if _has_state(self.input_quantizer, "_amax"): + self.input_quantizer.reset_amax() + dummy_input = torch.ones( + (1, 1, self.weight0.shape[1]), + device=self.weight0.device, + dtype=self.original_weight_dtype, + ) + max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False) + if _has_state(self.output_quantizer, "_amax"): + self.output_quantizer.reset_amax() + dummy_input = torch.ones( + (1, 1, self.weight0.shape[0]), + device=self.weight0.device, + dtype=self.original_weight_dtype, + ) + max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False) + # If there are any other states, lets move them to the correct device + + self.weight = None + super().modelopt_post_restore(prefix=prefix) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in + # sharded_state_dict which is same as _extra_state. 
The _extra_state{gemm_idx} is used for + # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] + # for modelopt checkpoint restore + filtered_state_dict = { + k: v + for k, v in state_dict.items() + if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) + } + return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) + + def _process_quantizer_amax(self, k, v, quantizer_state_dict): + if v.ndim == 4: + quantizer_state_dict[k] = v.squeeze(1).squeeze(-1) + else: + quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1) + + +@QuantModuleRegistry.register( + {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} +) +class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear): + _is_column_parallel = True + + +@QuantModuleRegistry.register( + {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} +) +class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear): + _is_row_parallel = True diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index f11a736db..3d80d1464 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,16 +241,28 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) + self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group) + self.expert_tensor_parallel_group = None + if expert_tensor_parallel_group is not None: + self.expert_tensor_parallel_group = DistributedProcessGroup( + expert_tensor_parallel_group + ) def __repr__(self) -> str: - return ( + parallel_groups = ( f"data_parallel_group: {self.data_parallel_group}, " f"tensor_parallel_group: {self.tensor_parallel_group}, " + f"expert_model_parallel_group: {self.expert_model_parallel_group}" ) + if self.expert_tensor_parallel_group: + parallel_groups += f"expert_tensor_parallel_group: {self.expert_tensor_parallel_group}" + return parallel_groups def get_group(ranks: list[int]): diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index ca6b9bff7..d408f258e 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -38,6 +38,8 @@ ) from megatron.core.models.mamba import MambaModel from megatron.core.parallel_state import ( + get_expert_model_parallel_group, + get_expert_tensor_parallel_group, initialize_model_parallel, is_pipeline_first_stage, is_pipeline_last_stage, @@ -49,12 +51,14 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig +import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_megatron import import_mcore_gpt_from_hf from modelopt.torch.opt.plugins.mcore_dist_checkpointing import ( restore_sharded_modelopt_state, save_sharded_modelopt_state, ) from 
modelopt.torch.utils import to_empty_if_meta_device +from modelopt.torch.utils.distributed import DistributedProcessGroup try: from megatron.core.extensions.transformer_engine import TENorm @@ -143,6 +147,8 @@ def get_dummy_input(self, seed: int | None = None) -> torch.Tensor: def get_mcore_gpt_model( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + expert_tensor_parallel_size: int = 1, initialize_megatron: bool = False, *, num_layers: int = 2, @@ -158,7 +164,10 @@ def get_mcore_gpt_model( normalization: str = "LayerNorm", transformer_impl: str = "modelopt" if HAS_TE else "local", use_cpu_initialization: bool = False, + num_moe_experts: int | None = None, + moe_grouped_gemm: bool = False, bf16: bool = True, + use_te: bool = False, ) -> GPTModel: assert activation_func in ["swiglu", "squared_relu"] assert normalization in ["LayerNorm", "RMSNorm"] @@ -166,7 +175,12 @@ def get_mcore_gpt_model( print(f"Using `{transformer_impl=}` model spec for building GPT Model.") if initialize_megatron: - initialize_for_megatron(tensor_model_parallel_size, pipeline_model_parallel_size) + initialize_for_megatron( + tensor_model_parallel_size, + pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + ) def squared_relu(x): return torch.pow(F.relu(x), 2) @@ -174,7 +188,10 @@ def squared_relu(x): config = TransformerConfig( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, sequence_parallel=False, + moe_grouped_gemm=moe_grouped_gemm, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, @@ -182,6 +199,7 @@ def squared_relu(x): num_attention_heads=num_attention_heads, num_query_groups=num_query_groups, ffn_hidden_size=ffn_hidden_size, + num_moe_experts=num_moe_experts, activation_func=squared_relu if activation_func == "squared_relu" else F.silu, normalization=normalization, gated_linear_unit=(activation_func == "swiglu"), @@ -193,7 +211,12 @@ def squared_relu(x): if transformer_impl == "local": assert HAS_APEX, "Apex not installed" - transformer_layer_spec = get_gpt_layer_local_spec(normalization=normalization) + transformer_layer_spec = get_gpt_layer_local_spec( + num_experts=num_moe_experts, + normalization=normalization, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + ) else: assert HAS_TE, "Transformer Engine not installed" transformer_layer_spec = ( @@ -212,6 +235,7 @@ def squared_relu(x): share_embeddings_and_output_weights=False, position_embedding_type="rope", ) + if bf16: model = model.to(torch.bfloat16) @@ -403,6 +427,8 @@ def initialize_for_megatron( pipeline_model_parallel_size=1, seed=1234, context_parallel_size=1, + expert_model_parallel_size=1, + expert_tensor_parallel_size=None, ): """Initialize Megatron model parallelism. 
@@ -412,6 +438,9 @@ def initialize_for_megatron( tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + order="tp-ep-dp-pp", ) model_parallel_cuda_manual_seed(seed) @@ -478,3 +507,196 @@ def convert_maybe_fp8(v): assert torch.allclose(logits_ref, logits_test), ( f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}" ) + + +def compare_model_outputs(grouped_model, non_grouped_model, forward_fn, tolerance=1e-6): + """Compare outputs of grouped and non-grouped models.""" + # Set both models to eval mode + grouped_model.eval() + non_grouped_model.eval() + + with torch.no_grad(): + # Get outputs from both models + grouped_output = forward_fn(grouped_model) + non_grouped_output = forward_fn(non_grouped_model) + + # Compare outputs + if isinstance(grouped_output, tuple): + grouped_output = grouped_output[0] + if isinstance(non_grouped_output, tuple): + non_grouped_output = non_grouped_output[0] + + output_close = torch.allclose( + grouped_output, non_grouped_output, atol=tolerance, rtol=tolerance + ) + return output_close + + +def sync_amax(model): + amax_dict = { + "linear_fc1.input_quantizer": {}, + "linear_fc1.weight_quantizer": {}, + "linear_fc2.input_quantizer": {}, + "linear_fc2.weight_quantizer": {}, + } + for name, module in model.named_modules(): + if not isinstance(module, mtq.nn.TensorQuantizer): + continue + if not hasattr(module, "_amax"): + continue + if "local_experts" not in name: + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax) + + for name, module in model.named_modules(): + if not isinstance(module, mtq.nn.TensorQuantizer): + continue + if not hasattr(module, "_amax"): + continue + if "local_experts" not in name: + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + module.amax = amax_dict[key][expert_name] + + +def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model): + """Copy weights from grouped MoE model to non-grouped MoE model.""" + grouped_state = grouped_model.state_dict() + non_grouped_state = non_grouped_model.state_dict() + + # Map grouped weights to non-grouped weights + weight_mapping = {} + non_grouped_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight" + for key, value in grouped_state.items(): + if "experts.linear_fc" in key and "weight" in key: + # Extract expert index from grouped weight name + # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ + parts = key.split(".") + layer_idx = parts[2] # X + fc_idx = parts[5] # Y (linear_fc1 or linear_fc2) + weight_idx = parts[6] # Z (weight0, weight1, etc.) 
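            # Worked example (hypothetical key, purely illustrative of the split above,
            # not part of the patch):
            #   key        = "decoder.layers.1.mlp.experts.linear_fc2.weight3"
            #   layer_idx  = "1", fc_idx = "linear_fc2", weight_idx = "weight3"
            #   which maps to the non-grouped key
            #   "decoder.layers.1.mlp.experts.local_experts.3.linear_fc2.weight"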
+ + # Map to non-grouped format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight + expert_idx = weight_idx.replace("weight", "") + non_grouped_key = non_grouped_key_template.format(layer_idx, expert_idx, fc_idx[-1]) + weight_mapping[non_grouped_key] = value + elif isinstance(value, torch.Tensor): + weight_mapping[key] = value + + # Copy weights to non-grouped model + for non_grouped_key in non_grouped_state: + if non_grouped_key in weight_mapping: + non_grouped_state[non_grouped_key] = weight_mapping[non_grouped_key].clone() + + non_grouped_model.load_state_dict(non_grouped_state) + + +def compare_amax_sync_across_expert_parallel(model): + """ + Test if amax values are synchronized across expert parallel groups. + + Returns True if synchronized, False otherwise. + """ + + ep_group = get_expert_model_parallel_group(check_initialized=False) + etp_group = get_expert_tensor_parallel_group(check_initialized=False) + + # Check if we have either expert model parallel or expert tensor parallel + has_expert_parallel = (ep_group is not None and ep_group.size() > 1) or ( + etp_group is not None and etp_group.size() > 1 + ) + + assert has_expert_parallel, "No expert parallelism detected" + # Collect amax values from expert quantizers only + expert_amax_values = {} + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"): + # Check for both grouped and non-grouped MoE patterns + if "local_experts" in name or ("experts" in name and "linear_fc" in name): + expert_amax_values[name] = ( + module.amax.item() if hasattr(module.amax, "item") else module.amax + ) + + # Early return if no expert quantizers found + assert expert_amax_values, "No expert quantizers found" + + # Gather amax values from all ranks + world_size = torch.distributed.get_world_size() + all_amax_values = [None] * world_size + torch.distributed.all_gather_object(all_amax_values, expert_amax_values) + + # Group quantizers by type (ignoring specific expert indices) and check sync + expert_quantizers = {} + for rank_idx, rank_amax in enumerate(all_amax_values): + for name, amax_val in rank_amax.items(): + # Create quantizer type key by normalizing the name + if "local_experts" in name: + # Non-grouped MoE: replace expert index with wildcard + import re + + quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name) + else: + # Grouped MoE: use the name as-is since experts are grouped + quantizer_type = name + + if quantizer_type not in expert_quantizers: + expert_quantizers[quantizer_type] = {} + expert_quantizers[quantizer_type][rank_idx] = amax_val + + # Check synchronization - fail fast on first inconsistency + for quantizer_type, rank_values in expert_quantizers.items(): + if len(rank_values) > 1: # Only check if we have multiple ranks + values = list(rank_values.values()) + max_diff = max(values) - min(values) + + if max_diff > 1e-6: # Allow for small floating point differences + return False + + return True + + +def disable_distributed_parallel_sync(model, expert_parallel_type: str = "tensor"): + """Disable distributed parallel synchronization groups.""" + module_parallel_groups = {} + + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.QuantModule): + # Store original groups + module_parallel_groups[name] = { + "data_parallel_group": module.parallel_state.data_parallel_group, + "expert_tensor_parallel_group": module.parallel_state.expert_tensor_parallel_group, + "expert_model_parallel_group": 
module.parallel_state.expert_model_parallel_group, + } + + # Disable groups + module.parallel_state.data_parallel_group = DistributedProcessGroup(-1) + + if expert_parallel_type in ["tensor", "both"]: + module.parallel_state.expert_tensor_parallel_group = DistributedProcessGroup(-1) + if expert_parallel_type in ["model", "both"]: + module.parallel_state.expert_model_parallel_group = DistributedProcessGroup(-1) + + return module_parallel_groups + + +def enable_distributed_parallel_sync( + model, module_parallel_groups, expert_parallel_type: str = "tensor" +): + """Re-enable distributed parallel synchronization groups.""" + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.QuantModule) and name in module_parallel_groups: + groups = module_parallel_groups[name] + + if expert_parallel_type in ["tensor", "both"]: + module.parallel_state.expert_tensor_parallel_group = groups[ + "expert_tensor_parallel_group" + ] + if expert_parallel_type in ["model", "both"]: + module.parallel_state.expert_model_parallel_group = groups[ + "expert_model_parallel_group" + ] diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index 7a8223d9c..8297d3f4e 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -40,7 +40,6 @@ def need_8_gpus(): pytest.skip("Need at least 8 GPUs to run this test") - @pytest.fixture def need_8_gpus(): if torch.cuda.device_count() < 8: diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 033e2f41a..0184c211d 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -21,10 +21,16 @@ from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job from _test_utils.torch_dist.plugins.megatron_common import ( MegatronModel, + compare_amax_sync_across_expert_parallel, + compare_model_outputs, + copy_weights_from_grouped_to_non_grouped, + disable_distributed_parallel_sync, + enable_distributed_parallel_sync, get_mcore_gpt_model, initialize_for_megatron, run_mcore_inference, sharded_state_dict_test_helper, + sync_amax, ) from _test_utils.torch_misc import set_seed from _test_utils.torch_quantization.models import RegularQuantModelForTP @@ -44,6 +50,7 @@ get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP import modelopt import modelopt.torch.opt as mto @@ -312,13 +319,25 @@ def test_data_tensor_context_parallel(need_8_gpus, config): ) -def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): +def _gpt_model_provider( + tp_size: int, + hidden_size=256, + vocab_size=64, + num_moe_experts=None, + moe_grouped_gemm=False, + meta_device=False, + ep_size=1, + etp_size=None, + use_te=False, +): """Build the model.""" if meta_device: with torch.device("meta"): gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=8, @@ -327,10 +346,15 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic hidden_size=hidden_size, vocab_size=vocab_size, use_cpu_initialization=meta_device, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ) else: gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + 
expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=8, @@ -338,12 +362,15 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic transformer_impl="local", hidden_size=hidden_size, vocab_size=vocab_size, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ).cuda() return gpt_model.eval() def _test_sharded_state_dict( - tmp_path, config, hidden_size, modelopt_version, compress, meta_device, rank, size + tmp_path, config, hidden_size, modelopt_version, compress, meta_device, moe_config, rank, size ): # Must disable output_layer quantization since output_layer amax cannot be restore via # sharded_state_dict. All output_layer quantizers state are removed. @@ -353,10 +380,42 @@ def _test_sharded_state_dict( mto.conversion.__version__ = modelopt_version mtq.plugins.megatron.__version__ = modelopt_version - initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED) + tp_size = moe_config.get("tp_size", size) + ep_size = moe_config.get("ep_size", 1) + etp_size = moe_config.get("etp_size", None) + num_moe_experts = moe_config.get("num_moe_experts", None) + moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False) + use_te = moe_config.get("use_te", False) + + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + seed=SEED, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + ) - model_ref = _gpt_model_provider(size, hidden_size, vocab_size=256) - model_test = _gpt_model_provider(size, hidden_size, vocab_size=256, meta_device=meta_device) + model_ref = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + meta_device=meta_device, + ep_size=ep_size, + etp_size=etp_size, + ) + model_test = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + meta_device=meta_device, + ep_size=ep_size, + etp_size=etp_size, + ) prompt_tokens = torch.randint( 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) @@ -437,7 +496,9 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) spawn_multiprocess_job( size=size, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device), + job=partial( + _test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device, {} + ), backend="nccl", ) @@ -452,7 +513,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config): spawn_multiprocess_job( size=2, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False), + job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False, {}), backend="nccl", ) @@ -473,7 +534,7 @@ def test_sharded_state_dict_old_checkpoints(need_2_gpus, tmp_path, config, model spawn_multiprocess_job( size=2, job=partial( - _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False + _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False, {} ), backend="nccl", ) @@ -556,3 +617,228 @@ def forward_fn(model): def test_fp8_real_quantize(): size = torch.cuda.device_count() spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl") + + +@pytest.mark.parametrize( + "config", + [ + 
mtq.FP8_DEFAULT_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config): + size = torch.cuda.device_count() + # TODO: Meta device doesn't work with TE + # TODO: Add support for compress=True for TEGroupedMLP + moe_config = { + "tp_size": 2, + "ep_size": 2, + "etp_size": 2, + "num_moe_experts": 4, + "moe_grouped_gemm": True, + "use_te": True, + } + spawn_multiprocess_job( + size=size, + job=partial( + _test_sharded_state_dict, + tmp_path, + config, + 256, + None, + False, + False, + moe_config, + ), + backend="nccl", + ) + + +def _test_grouped_vs_non_grouped_amax_helper(tp_size, ep_size, etp_size, rank, size): + """Test that grouped and non-grouped MoE models produce similar amax values.""" + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=SEED, + ) + + # Create input + prompt_tokens = torch.randint(0, 64, (2, 16)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Create grouped MoE model + grouped_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=True, + use_te=True, + num_moe_experts=4, + ) + num_grouped_mlp = sum(isinstance(module, TEGroupedMLP) for module in grouped_model.modules()) + assert num_grouped_mlp == 4, ( + f"TEGrupedMoEModel has {num_grouped_mlp} TEGroupedMLP modules, it should have 4" + ) + + # Create non-grouped MoE model + non_grouped_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=False, + num_moe_experts=4, + ) + num_sequential_mlp = sum( + isinstance(module, SequentialMLP) for module in non_grouped_model.modules() + ) + assert num_sequential_mlp == 4, ( + f"SequentialMoEModel has {num_sequential_mlp} SequentialMLP modules, it should have 4" + ) + # Copy weights from grouped to non-grouped model + copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model) + + output_comparison_before = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) + assert output_comparison_before, "Outputs are not close before quantization" + + # Quantize grouped model + mtq.quantize(grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # Quantize non-grouped model + mtq.quantize(non_grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # sync amax across expert parallel + # TODO: Remove once amax sync is enabled by default for SequentialGroupedMLP + sync_amax(non_grouped_model) + + # Compare model outputs after quantization + output_comparison_after = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) + assert output_comparison_after, "Outputs are not close after quantization" + + +def test_grouped_vs_non_grouped_amax(): + """Test that grouped and non-grouped MoE models produce similar amax values.""" + import time + + size = torch.cuda.device_count() + if size < 4: + pytest.skip("Requires at least 4 GPUs for expert parallel test") + + # Add small delay to avoid port conflicts + time.sleep(0.1) + + spawn_multiprocess_job( + size=size, job=partial(_test_grouped_vs_non_grouped_amax_helper, 1, 2, 2), backend="nccl" + ) + + +def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm): + """ + Test that demonstrates the requirement for expert parallel sync in model_calib.py + """ + # Create model with expert parallelism + model = _gpt_model_provider( + tp_size=1, + ep_size=ep_size, + etp_size=etp_size, + 
hidden_size=256, + moe_grouped_gemm=moe_grouped_gemm, + use_te=moe_grouped_gemm, + num_moe_experts=4, + ) + + # Create input and forward function + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Run forward pass and quantize + forward_fn(model) + config = mtq.FP8_DEFAULT_CFG + model = mtq.quantize(model, config, forward_fn) + + # Check initial sync status + initial_sync = compare_amax_sync_across_expert_parallel(model) + assert initial_sync, ( + "Inconsistent amax across expert parallel ranks, Amax should be synchronized across expert parallel ranks" + ) + + # Create inconsistent amax values + rank = torch.distributed.get_rank() + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.TensorQuantizer): + # Check if this is an expert quantizer + is_expert_quantizer = ( + "local_experts" in name # Non-grouped MoE + or ("experts" in name and "linear_fc" in name) # Grouped MoE + ) + + if is_expert_quantizer and hasattr(module, "_amax"): + # Create rank-specific amax values to simulate missing sync + rank_offset = rank * 0.1 + module.amax = module.amax + rank_offset + + # Determine expert parallel type + expert_parallel_type = ( + "both" if ep_size > 1 and etp_size > 1 else ("model" if ep_size > 1 else "tensor") + ) + + # Disable parallel groups and test inconsistency + module_parallel_groups = disable_distributed_parallel_sync(model, expert_parallel_type) + mtq.model_calib.max_calibrate(model, forward_fn) + + inconsistent_sync = compare_amax_sync_across_expert_parallel(model) + assert not inconsistent_sync, ( + "Consistent amax across expert parallel ranks, " + "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" + ) + + # Re-enable parallel groups and test synchronization + enable_distributed_parallel_sync(model, module_parallel_groups, expert_parallel_type) + mtq.model_calib.max_calibrate(model, forward_fn) + + final_sync = compare_amax_sync_across_expert_parallel(model) + assert final_sync, ( + "Inconsistent amax across expert parallel ranks, Amax should be synchronized across expert parallel ranks" + ) + + +def _test_expert_parallel_sync_helper(ep_size, etp_size, moe_grouped_gemm, rank, size): + """Test expert parallel synchronization with different configurations.""" + initialize_for_megatron( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=42 + rank, + ) + + # Run the actual test + _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm) + + +@pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)]) +@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) +def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): + """Test expert model parallel synchronization.""" + import time + + size = torch.cuda.device_count() + total_size = ep_size * etp_size + if size < total_size: + pytest.skip(f"Requires at least {total_size} GPUs for expert model parallel test") + + # Add small delay to avoid port conflicts + time.sleep(0.1) + + spawn_multiprocess_job( + size=total_size, + job=partial(_test_expert_parallel_sync_helper, ep_size, etp_size, moe_grouped_gemm), + backend="nccl", + ) From 70776c3c35ab4e14d52ce22d69b8678ef022ab46 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Tue, 7 Oct 2025 22:49:53 +0000 Subject: [PATCH 
15/28] code cleanup Signed-off-by: Kinjal Patel --- .../torch/quantization/plugins/megatron.py | 72 ++----------------- .../torch_dist/plugins/megatron_common.py | 1 - .../quantization/plugins/test_megatron.py | 45 +++++------- 3 files changed, 26 insertions(+), 92 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index eb948a48d..26133ee59 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -39,7 +39,7 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer +from ..nn import QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear @@ -503,7 +503,6 @@ def _setup(self): self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), - mcore_parallel.get_context_parallel_group(), mcore_parallel.get_expert_model_parallel_group(), expert_tensor_parallel_group, ) @@ -546,70 +545,13 @@ def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): ] def modelopt_post_restore(self, prefix: str = ""): - """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. - - ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their - shape before saving. However this is not enough for MCore/distributed frameworks since the tensor parallelism - could change between saving and restoring. If the tensor parallelism changes, the shape of the quantizer - states also changes. So we need to re-calculate the quantizer states. - """ - from modelopt.torch.quantization.model_calib import max_calibrate - - def _check_unsupported_states(quantizer: TensorQuantizer): - for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: - warnings.warn( - f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " - f"incorrect. Please implement a custom restore for {k}." - ) - - def _has_state(quantizer, name): - # Handling for SequentialQuantizer - quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer - return hasattr(quantizer, name) - - # weights for TEGroupedLinear are stored in weight0, weight1, etc. 
- if self.weight0 is None: - return - for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]: - _check_unsupported_states( - quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] - ) - if _has_state(self.weight_quantizer, "_amax"): - self.weight_quantizer.reset_amax() - for i in range(self.num_gemms): - weight = getattr(self, f"weight{i}") - assert weight is not None, "weight is None" - - max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False) - if _has_state(self.input_quantizer, "_pre_quant_scale"): - if hasattr(self.input_quantizer, "_pre_quant_scale"): - delattr(self.input_quantizer, "_pre_quant_scale") - pqs = torch.zeros( - (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype - ) - self.input_quantizer.register_buffer("_pre_quant_scale", pqs) - - if _has_state(self.input_quantizer, "_amax"): - self.input_quantizer.reset_amax() - dummy_input = torch.ones( - (1, 1, self.weight0.shape[1]), - device=self.weight0.device, - dtype=self.original_weight_dtype, - ) - max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False) - if _has_state(self.output_quantizer, "_amax"): - self.output_quantizer.reset_amax() - dummy_input = torch.ones( - (1, 1, self.weight0.shape[0]), - device=self.weight0.device, - dtype=self.original_weight_dtype, - ) - max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False) - # If there are any other states, lets move them to the correct device - - self.weight = None + # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to + # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning + # self.weight0 to self.weight to run the quantizer states initialization. + self.weight = self.weight0 super().modelopt_post_restore(prefix=prefix) + # Revert the weight to None after post_restore to avoid the weight being None during forward pass. 
+ self.weight = None def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index d408f258e..636e8f0fa 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -440,7 +440,6 @@ def initialize_for_megatron( context_parallel_size=context_parallel_size, expert_tensor_parallel_size=expert_tensor_parallel_size, expert_model_parallel_size=expert_model_parallel_size, - order="tp-ep-dp-pp", ) model_parallel_cuda_manual_seed(seed) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 0184c211d..8dbdcfd8a 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -654,7 +654,7 @@ def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config): ) -def _test_grouped_vs_non_grouped_amax_helper(tp_size, ep_size, etp_size, rank, size): +def _test_grouped_vs_non_grouped_quantize_helper(tp_size, ep_size, etp_size, rank, size): """Test that grouped and non-grouped MoE models produce similar amax values.""" initialize_for_megatron( tensor_model_parallel_size=tp_size, @@ -720,8 +720,8 @@ def forward_fn(model): assert output_comparison_after, "Outputs are not close after quantization" -def test_grouped_vs_non_grouped_amax(): - """Test that grouped and non-grouped MoE models produce similar amax values.""" +def test_grouped_vs_non_grouped_quantize(): + """Test that grouped and non-grouped MoE models produce similar quantized models.""" import time size = torch.cuda.device_count() @@ -732,14 +732,22 @@ def test_grouped_vs_non_grouped_amax(): time.sleep(0.1) spawn_multiprocess_job( - size=size, job=partial(_test_grouped_vs_non_grouped_amax_helper, 1, 2, 2), backend="nccl" + size=size, + job=partial(_test_grouped_vs_non_grouped_quantize_helper, 1, 2, 2), + backend="nccl", ) -def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm): - """ - Test that demonstrates the requirement for expert parallel sync in model_calib.py - """ +def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, rank, size): + """Test expert parallel synchronization with different configurations.""" + initialize_for_megatron( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=SEED, + ) + # Create model with expert parallelism model = _gpt_model_provider( tp_size=1, @@ -769,7 +777,7 @@ def forward_fn(model): ) # Create inconsistent amax values - rank = torch.distributed.get_rank() + cur_rank = torch.distributed.get_rank() for name, module in model.named_modules(): if isinstance(module, mtq.nn.TensorQuantizer): # Check if this is an expert quantizer @@ -780,7 +788,7 @@ def forward_fn(model): if is_expert_quantizer and hasattr(module, "_amax"): # Create rank-specific amax values to simulate missing sync - rank_offset = rank * 0.1 + rank_offset = cur_rank * 0.1 module.amax = module.amax + rank_offset # Determine expert parallel type @@ -808,21 +816,6 @@ def forward_fn(model): ) -def _test_expert_parallel_sync_helper(ep_size, etp_size, moe_grouped_gemm, rank, size): - """Test expert parallel synchronization with different configurations.""" - 
initialize_for_megatron( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - expert_model_parallel_size=ep_size, - expert_tensor_parallel_size=etp_size, - seed=42 + rank, - ) - - # Run the actual test - _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm) - - @pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)]) @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): @@ -839,6 +832,6 @@ def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): spawn_multiprocess_job( size=total_size, - job=partial(_test_expert_parallel_sync_helper, ep_size, etp_size, moe_grouped_gemm), + job=partial(_test_expert_model_parallel_amax_sync, ep_size, etp_size, moe_grouped_gemm), backend="nccl", ) From bab9ca29075306c70de60ef14a714922dbaf3f87 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Wed, 8 Oct 2025 23:53:58 +0000 Subject: [PATCH 16/28] code and test cleanup Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/mode.py | 8 +- modelopt/torch/quantization/model_calib.py | 4 - modelopt/torch/quantization/plugins/custom.py | 7 ++ .../torch/quantization/plugins/megatron.py | 109 +++++++++++------- modelopt/torch/utils/distributed.py | 8 -- .../torch_dist/plugins/megatron_common.py | 108 +---------------- tests/gpu/torch/conftest.py | 6 - .../quantization/plugins/test_megatron.py | 99 ++++++---------- 8 files changed, 125 insertions(+), 224 deletions(-) diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 4e6e9fd49..1e72a291a 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -208,6 +208,8 @@ def wrapped_calib_func( forward_loop and the relevant kwargs and are independent of the ModelOpt framework. So lets wrap them to be compatible with the ModelOpt convert entrypoint. 
""" + from .plugins.custom import register_custom_post_calibration_plugins + kwargs = config.model_dump() method = kwargs.pop("method") if method is not None and "awq" in method: @@ -218,6 +220,7 @@ def wrapped_calib_func( # Call the function with forward_loop as a separate argument func(model, forward_loop=forward_loop, **kwargs) + register_custom_post_calibration_plugins(model) # Lets get the latest metadata for the quantizer states metadata = {} update_quantize_metadata(model, config, metadata) @@ -290,7 +293,10 @@ def convert(self) -> ConvertEntrypoint: def wrapped_func(model, config, forward_loop=None): # Access _calib_func as a class attribute to avoid binding # Check if _calib_func is defined as a class attribute - return wrapped_calib_func(model, config, forward_loop, func=self.__class__._calib_func) + calib_results = wrapped_calib_func( + model, config, forward_loop, func=self.__class__._calib_func + ) + return calib_results return wrapped_func diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f50e7ae29..414e13e2d 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -89,10 +89,6 @@ def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) - if parallel_state.expert_tensor_parallel_group is not None: - quantizer.sync_amax_across_distributed_group( - parallel_state.expert_tensor_parallel_group - ) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 4227f3c49..38317bd52 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -30,6 +30,7 @@ CUSTOM_MODEL_PLUGINS = set() CUSTOM_POST_CONVERSION_PLUGINS = set() +CUSTOM_POST_CALIBRATION_PLUGINS = set() # TODO: This is a temporary solution @@ -46,6 +47,12 @@ def register_custom_post_conversion_plugins(model): callback(model) +def register_custom_post_calibration_plugins(model): + """Registers custom modules as QUANT_MODULE after calibration.""" + for callback in CUSTOM_POST_CALIBRATION_PLUGINS: + callback(model) + + class _QuantFunctionalMixin(QuantModule): """Mixin class for quantized functionals. 
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 26133ee59..3de8fe9b7 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -42,13 +42,51 @@ from ..nn import QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear from ..qtensor import QTensorWrapper -from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear logger = logging.getLogger(__name__) __all__ = [] +def sync_amax_across_sequential_mlp(model: torch.nn.Module): + """Sync amax across experts in a SequentialMLP.""" + amax_dict = { + "linear_fc1.input_quantizer": {}, + "linear_fc1.weight_quantizer": {}, + "linear_fc2.input_quantizer": {}, + "linear_fc2.weight_quantizer": {}, + } + # gather amax values from SequentialMLP experts + for name, module in model.named_modules(): + if ( + not isinstance(module, TensorQuantizer) + or not hasattr(module, "_amax") + or "local_experts" not in name + ): + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax) + + # sync amax values across experts in SequentialMLP + for name, module in model.named_modules(): + if ( + not isinstance(module, TensorQuantizer) + or not hasattr(module, "_amax") + or "local_experts" not in name + ): + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + module.amax = amax_dict[key][expert_name] + + +CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp) + + def real_quant_module_get_extra_state(self) -> dict: """Populating real_quantizer_state and q_tensor_state.""" extra_state = {} @@ -225,24 +263,19 @@ class _MegatronParallelLinear(_ParallelLinear): ] def _setup(self): - data_parallel_group = None - try: - data_parallel_group = get_data_parallel_group(with_context_parallel=True) - except AssertionError: - logger.warning("Context parallel group is not initialized, using data parallel group") - data_parallel_group = get_data_parallel_group() - - try: - expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() - except AssertionError: - expert_tensor_parallel_group = None - - self.parallel_state = ParallelState( - data_parallel_group, - mcore_parallel.get_tensor_model_parallel_group(), - mcore_parallel.get_expert_model_parallel_group(), - expert_tensor_parallel_group, - ) + if not hasattr(self, "parallel_state") or self.parallel_state is None: + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + logger.warning( + "Context parallel group is not initialized, using data parallel group" + ) + data_parallel_group = get_data_parallel_group() + self.parallel_state = ParallelState( + data_parallel_group, + mcore_parallel.get_tensor_model_parallel_group(), + ) super()._setup() def _process_quantizer_amax(self, k, v, quantizer_state_dict): @@ -490,26 +523,22 @@ def forward(self, input, *args, **kwargs): @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) class _QuantTEGroupedLinear(_MegatronParallelLinear): def _setup(self): - data_parallel_group = None - try: - data_parallel_group = 
get_data_parallel_group(with_context_parallel=True) - except AssertionError: - data_parallel_group = get_data_parallel_group() - - try: - expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() - except AssertionError: - expert_tensor_parallel_group = None - self.parallel_state = ParallelState( - data_parallel_group, - mcore_parallel.get_tensor_model_parallel_group(), - mcore_parallel.get_expert_model_parallel_group(), - expert_tensor_parallel_group, - ) - self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) - self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) - self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) - self.output_quantizer.disable() + if not hasattr(self, "parallel_state") or self.parallel_state is None: + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + data_parallel_group = get_data_parallel_group() + + self.parallel_state = ParallelState( + data_parallel_group, + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), + ) + self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) + self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) + self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) + self.output_quantizer.disable() # Memorize the original weight.dtype for modelopt_post_restore given that # the dtype can change later. @@ -582,5 +611,5 @@ class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumn @QuantModuleRegistry.register( {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} ) -class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear): +class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear): _is_row_parallel = True diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 3d80d1464..bcebd0492 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -242,17 +242,11 @@ def __init__( data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1, - expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group) - self.expert_tensor_parallel_group = None - if expert_tensor_parallel_group is not None: - self.expert_tensor_parallel_group = DistributedProcessGroup( - expert_tensor_parallel_group - ) def __repr__(self) -> str: parallel_groups = ( @@ -260,8 +254,6 @@ def __repr__(self) -> str: f"tensor_parallel_group: {self.tensor_parallel_group}, " f"expert_model_parallel_group: {self.expert_model_parallel_group}" ) - if self.expert_tensor_parallel_group: - parallel_groups += f"expert_tensor_parallel_group: {self.expert_tensor_parallel_group}" return parallel_groups diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py 
b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 636e8f0fa..691002eef 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import re from warnings import warn import torch @@ -58,7 +59,6 @@ save_sharded_modelopt_state, ) from modelopt.torch.utils import to_empty_if_meta_device -from modelopt.torch.utils.distributed import DistributedProcessGroup try: from megatron.core.extensions.transformer_engine import TENorm @@ -148,7 +148,7 @@ def get_mcore_gpt_model( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, expert_model_parallel_size: int = 1, - expert_tensor_parallel_size: int = 1, + expert_tensor_parallel_size: int | None = None, initialize_megatron: bool = False, *, num_layers: int = 2, @@ -508,61 +508,6 @@ def convert_maybe_fp8(v): ) -def compare_model_outputs(grouped_model, non_grouped_model, forward_fn, tolerance=1e-6): - """Compare outputs of grouped and non-grouped models.""" - # Set both models to eval mode - grouped_model.eval() - non_grouped_model.eval() - - with torch.no_grad(): - # Get outputs from both models - grouped_output = forward_fn(grouped_model) - non_grouped_output = forward_fn(non_grouped_model) - - # Compare outputs - if isinstance(grouped_output, tuple): - grouped_output = grouped_output[0] - if isinstance(non_grouped_output, tuple): - non_grouped_output = non_grouped_output[0] - - output_close = torch.allclose( - grouped_output, non_grouped_output, atol=tolerance, rtol=tolerance - ) - return output_close - - -def sync_amax(model): - amax_dict = { - "linear_fc1.input_quantizer": {}, - "linear_fc1.weight_quantizer": {}, - "linear_fc2.input_quantizer": {}, - "linear_fc2.weight_quantizer": {}, - } - for name, module in model.named_modules(): - if not isinstance(module, mtq.nn.TensorQuantizer): - continue - if not hasattr(module, "_amax"): - continue - if "local_experts" not in name: - continue - expert_name, local_expert_name = name.split("local_experts") - for key in amax_dict: - if key in local_expert_name: - amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax) - - for name, module in model.named_modules(): - if not isinstance(module, mtq.nn.TensorQuantizer): - continue - if not hasattr(module, "_amax"): - continue - if "local_experts" not in name: - continue - expert_name, local_expert_name = name.split("local_experts") - for key in amax_dict: - if key in local_expert_name: - module.amax = amax_dict[key][expert_name] - - def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model): """Copy weights from grouped MoE model to non-grouped MoE model.""" grouped_state = grouped_model.state_dict() @@ -636,8 +581,6 @@ def compare_amax_sync_across_expert_parallel(model): # Create quantizer type key by normalizing the name if "local_experts" in name: # Non-grouped MoE: replace expert index with wildcard - import re - quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name) else: # Grouped MoE: use the name as-is since experts are grouped @@ -652,50 +595,7 @@ def compare_amax_sync_across_expert_parallel(model): if len(rank_values) > 1: # Only check if we have multiple ranks values = list(rank_values.values()) max_diff = max(values) - min(values) - if max_diff > 1e-6: # Allow for small floating point differences - return False + return False, quantizer_type, 
rank_values - return True - - -def disable_distributed_parallel_sync(model, expert_parallel_type: str = "tensor"): - """Disable distributed parallel synchronization groups.""" - module_parallel_groups = {} - - for name, module in model.named_modules(): - if isinstance(module, mtq.nn.QuantModule): - # Store original groups - module_parallel_groups[name] = { - "data_parallel_group": module.parallel_state.data_parallel_group, - "expert_tensor_parallel_group": module.parallel_state.expert_tensor_parallel_group, - "expert_model_parallel_group": module.parallel_state.expert_model_parallel_group, - } - - # Disable groups - module.parallel_state.data_parallel_group = DistributedProcessGroup(-1) - - if expert_parallel_type in ["tensor", "both"]: - module.parallel_state.expert_tensor_parallel_group = DistributedProcessGroup(-1) - if expert_parallel_type in ["model", "both"]: - module.parallel_state.expert_model_parallel_group = DistributedProcessGroup(-1) - - return module_parallel_groups - - -def enable_distributed_parallel_sync( - model, module_parallel_groups, expert_parallel_type: str = "tensor" -): - """Re-enable distributed parallel synchronization groups.""" - for name, module in model.named_modules(): - if isinstance(module, mtq.nn.QuantModule) and name in module_parallel_groups: - groups = module_parallel_groups[name] - - if expert_parallel_type in ["tensor", "both"]: - module.parallel_state.expert_tensor_parallel_group = groups[ - "expert_tensor_parallel_group" - ] - if expert_parallel_type in ["model", "both"]: - module.parallel_state.expert_model_parallel_group = groups[ - "expert_model_parallel_group" - ] + return True, None, None diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index 8297d3f4e..f32065bce 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -40,12 +40,6 @@ def need_8_gpus(): pytest.skip("Need at least 8 GPUs to run this test") -@pytest.fixture -def need_8_gpus(): - if torch.cuda.device_count() < 8: - pytest.skip("Need at least 8 GPUs to run this test") - - @pytest.fixture(scope="module") def set_torch_dtype(request): orig_dtype = torch.get_default_dtype() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 8dbdcfd8a..eb46968d4 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -22,15 +22,11 @@ from _test_utils.torch_dist.plugins.megatron_common import ( MegatronModel, compare_amax_sync_across_expert_parallel, - compare_model_outputs, copy_weights_from_grouped_to_non_grouped, - disable_distributed_parallel_sync, - enable_distributed_parallel_sync, get_mcore_gpt_model, initialize_for_megatron, run_mcore_inference, sharded_state_dict_test_helper, - sync_amax, ) from _test_utils.torch_misc import set_seed from _test_utils.torch_quantization.models import RegularQuantModelForTP @@ -401,7 +397,6 @@ def _test_sharded_state_dict( num_moe_experts=num_moe_experts, moe_grouped_gemm=moe_grouped_gemm, use_te=use_te, - meta_device=meta_device, ep_size=ep_size, etp_size=etp_size, ) @@ -670,7 +665,7 @@ def forward_fn(model): return megatron_prefill(model, prompt_tokens) # Create grouped MoE model - grouped_model = _gpt_model_provider( + grouped_moe_model = _gpt_model_provider( tp_size=tp_size, ep_size=ep_size, etp_size=etp_size, @@ -679,13 +674,15 @@ def forward_fn(model): use_te=True, num_moe_experts=4, ) - num_grouped_mlp = sum(isinstance(module, TEGroupedMLP) for module in 
grouped_model.modules()) + num_grouped_mlp = sum( + isinstance(module, TEGroupedMLP) for module in grouped_moe_model.modules() + ) assert num_grouped_mlp == 4, ( f"TEGrupedMoEModel has {num_grouped_mlp} TEGroupedMLP modules, it should have 4" ) # Create non-grouped MoE model - non_grouped_model = _gpt_model_provider( + sequential_moe_model = _gpt_model_provider( tp_size=tp_size, ep_size=ep_size, etp_size=etp_size, @@ -694,43 +691,40 @@ def forward_fn(model): num_moe_experts=4, ) num_sequential_mlp = sum( - isinstance(module, SequentialMLP) for module in non_grouped_model.modules() + isinstance(module, SequentialMLP) for module in sequential_moe_model.modules() ) assert num_sequential_mlp == 4, ( f"SequentialMoEModel has {num_sequential_mlp} SequentialMLP modules, it should have 4" ) # Copy weights from grouped to non-grouped model - copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model) + copy_weights_from_grouped_to_non_grouped(grouped_moe_model, sequential_moe_model) - output_comparison_before = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) - assert output_comparison_before, "Outputs are not close before quantization" + # Compare model outputs before quantization + grouped_moe_output = forward_fn(grouped_moe_model) + non_grouped_moe_output = forward_fn(sequential_moe_model) + assert torch.allclose(grouped_moe_output, non_grouped_moe_output, atol=1e-6, rtol=1e-6) # Quantize grouped model - mtq.quantize(grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) + mtq.quantize(grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) # Quantize non-grouped model - mtq.quantize(non_grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) - - # sync amax across expert parallel - # TODO: Remove once amax sync is enabled by default for SequentialGroupedMLP - sync_amax(non_grouped_model) + mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) # Compare model outputs after quantization - output_comparison_after = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) - assert output_comparison_after, "Outputs are not close after quantization" + grouped_moe_quant_output = forward_fn(grouped_moe_model) + non_grouped_moe_quant_output = forward_fn(sequential_moe_model) + assert torch.allclose( + grouped_moe_quant_output, non_grouped_moe_quant_output, atol=1e-6, rtol=1e-6 + ) def test_grouped_vs_non_grouped_quantize(): """Test that grouped and non-grouped MoE models produce similar quantized models.""" - import time size = torch.cuda.device_count() if size < 4: pytest.skip("Requires at least 4 GPUs for expert parallel test") - # Add small delay to avoid port conflicts - time.sleep(0.1) - spawn_multiprocess_job( size=size, job=partial(_test_grouped_vs_non_grouped_quantize_helper, 1, 2, 2), @@ -738,7 +732,7 @@ def test_grouped_vs_non_grouped_quantize(): ) -def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, rank, size): +def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, config, rank, size): """Test expert parallel synchronization with different configurations.""" initialize_for_megatron( tensor_model_parallel_size=1, @@ -758,24 +752,19 @@ def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, r use_te=moe_grouped_gemm, num_moe_experts=4, ) - - # Create input and forward function prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() def forward_fn(model): return megatron_prefill(model, prompt_tokens) - # Run forward pass and quantize - 
forward_fn(model) - config = mtq.FP8_DEFAULT_CFG + # quantize the model model = mtq.quantize(model, config, forward_fn) # Check initial sync status - initial_sync = compare_amax_sync_across_expert_parallel(model) + initial_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert initial_sync, ( - "Inconsistent amax across expert parallel ranks, Amax should be synchronized across expert parallel ranks" + f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" ) - # Create inconsistent amax values cur_rank = torch.distributed.get_rank() for name, module in model.named_modules(): @@ -791,47 +780,35 @@ def forward_fn(model): rank_offset = cur_rank * 0.1 module.amax = module.amax + rank_offset - # Determine expert parallel type - expert_parallel_type = ( - "both" if ep_size > 1 and etp_size > 1 else ("model" if ep_size > 1 else "tensor") - ) - - # Disable parallel groups and test inconsistency - module_parallel_groups = disable_distributed_parallel_sync(model, expert_parallel_type) - mtq.model_calib.max_calibrate(model, forward_fn) - - inconsistent_sync = compare_amax_sync_across_expert_parallel(model) - assert not inconsistent_sync, ( + # Test if the amax values are inconsistent + inconsistent_amax, _, _ = compare_amax_sync_across_expert_parallel(model) + assert not inconsistent_amax, ( "Consistent amax across expert parallel ranks, " "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" ) - # Re-enable parallel groups and test synchronization - enable_distributed_parallel_sync(model, module_parallel_groups, expert_parallel_type) mtq.model_calib.max_calibrate(model, forward_fn) - final_sync = compare_amax_sync_across_expert_parallel(model) - assert final_sync, ( - "Inconsistent amax across expert parallel ranks, Amax should be synchronized across expert parallel ranks" - ) + final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) + assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" @pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)]) @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) -def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): +def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm): """Test expert model parallel synchronization.""" - import time - size = torch.cuda.device_count() - total_size = ep_size * etp_size - if size < total_size: - pytest.skip(f"Requires at least {total_size} GPUs for expert model parallel test") - - # Add small delay to avoid port conflicts - time.sleep(0.1) + if size < ep_size * etp_size: + pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test") spawn_multiprocess_job( - size=total_size, - job=partial(_test_expert_model_parallel_amax_sync, ep_size, etp_size, moe_grouped_gemm), + size=size, + job=partial( + _test_expert_model_parallel_amax_sync, + ep_size, + etp_size, + moe_grouped_gemm, + mtq.FP8_DEFAULT_CFG, + ), backend="nccl", ) From f9ba6e8154f8dd7942e117518101576744c6f4ce Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 9 Oct 2025 00:05:41 +0000 Subject: [PATCH 17/28] Updated moe names in tests Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/mode.py | 5 +-- .../torch_dist/plugins/megatron_common.py | 36 +++++++-------- .../quantization/plugins/test_megatron.py | 44 +++++++++---------- 3 files changed, 41 insertions(+), 44 deletions(-) diff --git 
a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 1e72a291a..5943a991d 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -293,10 +293,7 @@ def convert(self) -> ConvertEntrypoint: def wrapped_func(model, config, forward_loop=None): # Access _calib_func as a class attribute to avoid binding # Check if _calib_func is defined as a class attribute - calib_results = wrapped_calib_func( - model, config, forward_loop, func=self.__class__._calib_func - ) - return calib_results + return wrapped_calib_func(model, config, forward_loop, func=self.__class__._calib_func) return wrapped_func diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 691002eef..c2a9fb745 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -508,15 +508,15 @@ def convert_maybe_fp8(v): ) -def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model): - """Copy weights from grouped MoE model to non-grouped MoE model.""" - grouped_state = grouped_model.state_dict() - non_grouped_state = non_grouped_model.state_dict() +def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model): + """Copy weights from TEGrouped MoE model to sequential MoE model.""" + te_grouped_state = te_grouped_moe_model.state_dict() + sequential_state = sequential_moe_model.state_dict() - # Map grouped weights to non-grouped weights + # Map grouped weights to sequential weights weight_mapping = {} - non_grouped_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight" - for key, value in grouped_state.items(): + sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight" + for key, value in te_grouped_state.items(): if "experts.linear_fc" in key and "weight" in key: # Extract expert index from grouped weight name # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ @@ -525,19 +525,19 @@ def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model): fc_idx = parts[5] # Y (linear_fc1 or linear_fc2) weight_idx = parts[6] # Z (weight0, weight1, etc.) 
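As a concrete illustration of the key mapping performed in this helper (the name templates come from the diff; the specific key used here is a hypothetical example):

    # TEGrouped key: decoder.layers.0.mlp.experts.linear_fc1.weight2
    # maps to sequential key: decoder.layers.0.mlp.experts.local_experts.2.linear_fc1.weight
    grouped_key = "decoder.layers.0.mlp.experts.linear_fc1.weight2"
    parts = grouped_key.split(".")
    layer_idx, fc_idx, weight_idx = parts[2], parts[5], parts[6]  # "0", "linear_fc1", "weight2"
    expert_idx = weight_idx.replace("weight", "")  # "2"
    sequential_key = (
        f"decoder.layers.{layer_idx}.mlp.experts."
        f"local_experts.{expert_idx}.linear_fc{fc_idx[-1]}.weight"
    )
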
- # Map to non-grouped format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight + # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight expert_idx = weight_idx.replace("weight", "") - non_grouped_key = non_grouped_key_template.format(layer_idx, expert_idx, fc_idx[-1]) - weight_mapping[non_grouped_key] = value + sequential_key = sequential_key_template.format(layer_idx, expert_idx, fc_idx[-1]) + weight_mapping[sequential_key] = value elif isinstance(value, torch.Tensor): weight_mapping[key] = value - # Copy weights to non-grouped model - for non_grouped_key in non_grouped_state: - if non_grouped_key in weight_mapping: - non_grouped_state[non_grouped_key] = weight_mapping[non_grouped_key].clone() + # Copy weights to sequential model + for sequential_key in sequential_state: + if sequential_key in weight_mapping: + sequential_state[sequential_key] = weight_mapping[sequential_key].clone() - non_grouped_model.load_state_dict(non_grouped_state) + sequential_moe_model.load_state_dict(sequential_state) def compare_amax_sync_across_expert_parallel(model): @@ -560,7 +560,7 @@ def compare_amax_sync_across_expert_parallel(model): expert_amax_values = {} for name, module in model.named_modules(): if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"): - # Check for both grouped and non-grouped MoE patterns + # Check for both TEGrouped and sequential MoE patterns if "local_experts" in name or ("experts" in name and "linear_fc" in name): expert_amax_values[name] = ( module.amax.item() if hasattr(module.amax, "item") else module.amax @@ -580,10 +580,10 @@ def compare_amax_sync_across_expert_parallel(model): for name, amax_val in rank_amax.items(): # Create quantizer type key by normalizing the name if "local_experts" in name: - # Non-grouped MoE: replace expert index with wildcard + # sequential MoE: replace expert index with wildcard quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name) else: - # Grouped MoE: use the name as-is since experts are grouped + # TEGrouped MoE: use the name as-is since experts are grouped quantizer_type = name if quantizer_type not in expert_quantizers: diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index eb46968d4..856dd009e 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -649,8 +649,8 @@ def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config): ) -def _test_grouped_vs_non_grouped_quantize_helper(tp_size, ep_size, etp_size, rank, size): - """Test that grouped and non-grouped MoE models produce similar amax values.""" +def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, rank, size): + """Test that TEGrouped and sequential MoE models produce similar amax values.""" initialize_for_megatron( tensor_model_parallel_size=tp_size, expert_model_parallel_size=ep_size, @@ -664,8 +664,8 @@ def _test_grouped_vs_non_grouped_quantize_helper(tp_size, ep_size, etp_size, ran def forward_fn(model): return megatron_prefill(model, prompt_tokens) - # Create grouped MoE model - grouped_moe_model = _gpt_model_provider( + # Create TEGrouped MoE model + te_grouped_moe_model = _gpt_model_provider( tp_size=tp_size, ep_size=ep_size, etp_size=etp_size, @@ -674,14 +674,14 @@ def forward_fn(model): use_te=True, num_moe_experts=4, ) - num_grouped_mlp = sum( - isinstance(module, TEGroupedMLP) for module in grouped_moe_model.modules() 
+ num_te_grouped_mlp = sum( + isinstance(module, TEGroupedMLP) for module in te_grouped_moe_model.modules() ) - assert num_grouped_mlp == 4, ( - f"TEGrupedMoEModel has {num_grouped_mlp} TEGroupedMLP modules, it should have 4" + assert num_te_grouped_mlp == 4, ( + f"TEGrupedMoEModel has {num_te_grouped_mlp} TEGroupedMLP modules, it should have 4" ) - # Create non-grouped MoE model + # Create sequential MoE model sequential_moe_model = _gpt_model_provider( tp_size=tp_size, ep_size=ep_size, @@ -697,29 +697,29 @@ def forward_fn(model): f"SequentialMoEModel has {num_sequential_mlp} SequentialMLP modules, it should have 4" ) # Copy weights from grouped to non-grouped model - copy_weights_from_grouped_to_non_grouped(grouped_moe_model, sequential_moe_model) + copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model) # Compare model outputs before quantization - grouped_moe_output = forward_fn(grouped_moe_model) - non_grouped_moe_output = forward_fn(sequential_moe_model) - assert torch.allclose(grouped_moe_output, non_grouped_moe_output, atol=1e-6, rtol=1e-6) + te_grouped_moe_output = forward_fn(te_grouped_moe_model) + sequential_moe_output = forward_fn(sequential_moe_model) + assert torch.allclose(te_grouped_moe_output, sequential_moe_output, atol=1e-6, rtol=1e-6) # Quantize grouped model - mtq.quantize(grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) + mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) # Quantize non-grouped model mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) # Compare model outputs after quantization - grouped_moe_quant_output = forward_fn(grouped_moe_model) - non_grouped_moe_quant_output = forward_fn(sequential_moe_model) + te_grouped_moe_quant_output = forward_fn(te_grouped_moe_model) + sequential_moe_quant_output = forward_fn(sequential_moe_model) assert torch.allclose( - grouped_moe_quant_output, non_grouped_moe_quant_output, atol=1e-6, rtol=1e-6 + te_grouped_moe_quant_output, sequential_moe_quant_output, atol=1e-6, rtol=1e-6 ) -def test_grouped_vs_non_grouped_quantize(): - """Test that grouped and non-grouped MoE models produce similar quantized models.""" +def test_te_grouped_vs_sequential_quantize(): + """Test that TEGrouped and sequential MoE models produce similar quantized models.""" size = torch.cuda.device_count() if size < 4: @@ -727,7 +727,7 @@ def test_grouped_vs_non_grouped_quantize(): spawn_multiprocess_job( size=size, - job=partial(_test_grouped_vs_non_grouped_quantize_helper, 1, 2, 2), + job=partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, 2), backend="nccl", ) @@ -771,8 +771,8 @@ def forward_fn(model): if isinstance(module, mtq.nn.TensorQuantizer): # Check if this is an expert quantizer is_expert_quantizer = ( - "local_experts" in name # Non-grouped MoE - or ("experts" in name and "linear_fc" in name) # Grouped MoE + "local_experts" in name # sequential MoE + or ("experts" in name and "linear_fc" in name) # TEGrouped MoE ) if is_expert_quantizer and hasattr(module, "_amax"): From a917c2bb6c9f3549c811ea39dafda2ccb0decd6e Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 9 Oct 2025 16:49:19 +0000 Subject: [PATCH 18/28] updated parallel state for experts Signed-off-by: Kinjal Patel --- .../torch/quantization/plugins/megatron.py | 70 +++++++++++++------ .../quantization/plugins/test_megatron.py | 21 +++--- 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 
3de8fe9b7..111bb446a 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,6 +22,7 @@ import megatron.core.parallel_state as mcore_parallel import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp +import megatron.core.transformer.moe.experts as megatron_moe import torch import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear from megatron.core.extensions import transformer_engine as megatron_te @@ -40,7 +41,7 @@ from modelopt.torch.utils.distributed import ParallelState from ..nn import QuantModuleRegistry, TensorQuantizer -from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear +from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear @@ -520,29 +521,18 @@ def forward(self, input, *args, **kwargs): # Register the public te.pytorch.GroupedLinear class -@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) class _QuantTEGroupedLinear(_MegatronParallelLinear): def _setup(self): - if not hasattr(self, "parallel_state") or self.parallel_state is None: - data_parallel_group = None - try: - data_parallel_group = get_data_parallel_group(with_context_parallel=True) - except AssertionError: - data_parallel_group = get_data_parallel_group() - - self.parallel_state = ParallelState( - data_parallel_group, - tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), - expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), - ) - self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) - self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) - self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) - self.output_quantizer.disable() - + # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to + # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning + # self.weight0 to self.weight to run the quantizer states initialization. + self.weight = self.weight0 # Memorize the original weight.dtype for modelopt_post_restore given that # the dtype can change later. - self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype + super()._setup() + # Revert the weight to None after setup. + self.weight = None @property def functionals_to_replace(self): @@ -579,7 +569,7 @@ def modelopt_post_restore(self, prefix: str = ""): # self.weight0 to self.weight to run the quantizer states initialization. self.weight = self.weight0 super().modelopt_post_restore(prefix=prefix) - # Revert the weight to None after post_restore to avoid the weight being None during forward pass. + # Revert the weight to None after post_restore. 
self.weight = None def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): @@ -613,3 +603,41 @@ class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumn ) class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear): _is_row_parallel = True + + +# Register the public megatron_moe.TEGroupedMLP class +@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"}) +class _QuantTEGroupedMLP(_MegatronMLP): + def _setup(self): + if not hasattr(self, "parallel_state") or self.parallel_state is None: + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + logger.warning( + "Context parallel group is not initialized, using data parallel group" + ) + data_parallel_group = get_data_parallel_group() + + self.parallel_state = ParallelState( + data_parallel_group, + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), + ) + + +# Register the public megatron_moe.SequentialMLP class +@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"}) +class _QuantSequentialMLP(_MegatronMLP): + def _setup(self): + if not hasattr(self, "parallel_state") or self.parallel_state is None: + try: + data_parallel_group = mcore_parallel.get_expert_data_parallel_group() + except AssertionError: + data_parallel_group = None + + self.parallel_state = ParallelState( + data_parallel_group, + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), + ) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 856dd009e..88a4bb87b 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -621,17 +621,19 @@ def test_fp8_real_quantize(): mtq.NVFP4_DEFAULT_CFG, ], ) -def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config): +@pytest.mark.parametrize("moe_grouped_gemm", [False, True]) +def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm): size = torch.cuda.device_count() - # TODO: Meta device doesn't work with TE # TODO: Add support for compress=True for TEGroupedMLP + if size < 4: + pytest.skip("Requires at least 4 GPUs for expert parallel test") moe_config = { - "tp_size": 2, + "tp_size": 1, "ep_size": 2, "etp_size": 2, "num_moe_experts": 4, - "moe_grouped_gemm": True, - "use_te": True, + "moe_grouped_gemm": moe_grouped_gemm, + "use_te": moe_grouped_gemm, } spawn_multiprocess_job( size=size, @@ -732,10 +734,12 @@ def test_te_grouped_vs_sequential_quantize(): ) -def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, config, rank, size): +def _test_expert_model_parallel_amax_sync( + tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size +): """Test expert parallel synchronization with different configurations.""" initialize_for_megatron( - tensor_model_parallel_size=1, + tensor_model_parallel_size=tp_size, pipeline_model_parallel_size=1, expert_model_parallel_size=ep_size, expert_tensor_parallel_size=etp_size, @@ -744,7 +748,7 @@ def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, c # Create model with expert parallelism model = _gpt_model_provider( - tp_size=1, + tp_size=tp_size, ep_size=ep_size, etp_size=etp_size, 
hidden_size=256, @@ -805,6 +809,7 @@ def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm): size=size, job=partial( _test_expert_model_parallel_amax_sync, + 1, ep_size, etp_size, moe_grouped_gemm, From 1ea4ed1b42f2dbf3639c04582df68693a2a6f47d Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 9 Oct 2025 20:50:15 +0000 Subject: [PATCH 19/28] fixed bug for is_quantized_linear check Signed-off-by: Kinjal Patel --- .../torch/quantization/plugins/megatron.py | 26 ++++++++++++++----- modelopt/torch/quantization/utils.py | 6 +++-- .../quantization/plugins/test_megatron.py | 4 +-- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 111bb446a..274738516 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -522,7 +522,7 @@ def forward(self, input, *args, **kwargs): # Register the public te.pytorch.GroupedLinear class @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) -class _QuantTEGroupedLinear(_MegatronParallelLinear): +class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear): def _setup(self): # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning @@ -594,20 +594,24 @@ def _process_quantizer_amax(self, k, v, quantizer_state_dict): @QuantModuleRegistry.register( {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} ) -class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear): +class _MegatronTEGroupedColumnParallelLinear( + _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear +): _is_column_parallel = True @QuantModuleRegistry.register( {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} ) -class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear): +class _MegatronTEGroupedRowParallelLinear( + _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear +): _is_row_parallel = True # Register the public megatron_moe.TEGroupedMLP class @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"}) -class _QuantTEGroupedMLP(_MegatronMLP): +class _MegatronTEGroupedMLP(_MegatronMLP): def _setup(self): if not hasattr(self, "parallel_state") or self.parallel_state is None: data_parallel_group = None @@ -619,16 +623,20 @@ def _setup(self): ) data_parallel_group = get_data_parallel_group() + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + except AssertionError: + expert_tensor_parallel_group = None self.parallel_state = ParallelState( data_parallel_group, - tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + tensor_parallel_group=expert_tensor_parallel_group, expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), ) # Register the public megatron_moe.SequentialMLP class @QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"}) -class _QuantSequentialMLP(_MegatronMLP): +class _MegatronSequentialMLP(_MegatronMLP): def _setup(self): if not hasattr(self, "parallel_state") or self.parallel_state is None: try: @@ -636,8 +644,12 @@ def _setup(self): except AssertionError: data_parallel_group = None + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + 
except AssertionError: + expert_tensor_parallel_group = None self.parallel_state = ParallelState( data_parallel_group, - tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + tensor_parallel_group=expert_tensor_parallel_group, expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), ) diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 6167daf23..3287f6d91 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -251,8 +251,10 @@ def is_quantized_linear(module): isinstance(module, QuantModule) and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer) and hasattr(module, "weight_quantizer") - and getattr(module, "weight", None) is not None - and module.weight.dim() == 2 + and ( + (getattr(module, "weight", None) is not None and module.weight.dim() == 2) + or (getattr(module, "weight0", None) is not None and module.weight0.dim() == 2) + ) ) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 88a4bb87b..5ac3b5bfe 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -628,7 +628,7 @@ def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm): if size < 4: pytest.skip("Requires at least 4 GPUs for expert parallel test") moe_config = { - "tp_size": 1, + "tp_size": 2, "ep_size": 2, "etp_size": 2, "num_moe_experts": 4, @@ -809,7 +809,7 @@ def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm): size=size, job=partial( _test_expert_model_parallel_amax_sync, - 1, + 2, ep_size, etp_size, moe_grouped_gemm, From 169677c6471ffc40da484250f49995535eef8ffa Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Sat, 11 Oct 2025 03:43:25 +0000 Subject: [PATCH 20/28] code cleanup and bug fixes Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/model_calib.py | 2 +- .../torch/quantization/plugins/megatron.py | 173 +++++++++--------- .../torch_dist/plugins/megatron_common.py | 9 + tests/gpu/torch/conftest.py | 6 + .../quantization/plugins/test_megatron.py | 20 +- 5 files changed, 105 insertions(+), 105 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 414e13e2d..175f31670 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -81,7 +81,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis return def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel and context parallel groups.""" + """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: sync_quantizer_amax_across_dp_ep(_q, parallel_state) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 274738516..0aa97fd21 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -52,37 +52,34 @@ def sync_amax_across_sequential_mlp(model: torch.nn.Module): """Sync amax across experts in a SequentialMLP.""" - amax_dict = { - "linear_fc1.input_quantizer": {}, - "linear_fc1.weight_quantizer": {}, - "linear_fc2.input_quantizer": {}, - "linear_fc2.weight_quantizer": {}, - } - # gather amax values from SequentialMLP 
experts - for name, module in model.named_modules(): + amax_dict = {} + + def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module): if ( - not isinstance(module, TensorQuantizer) - or not hasattr(module, "_amax") - or "local_experts" not in name + isinstance(module, TensorQuantizer) + and hasattr(module, "_amax") + and ".local_experts." in name ): - continue - expert_name, local_expert_name = name.split("local_experts") - for key in amax_dict: - if key in local_expert_name: - amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax) + expert_name, local_expert_name = name.split(".local_experts.") + # extract quantizer name by removing local_expert number from the name + local_expert_name = ".".join(local_expert_name.split(".")[1:]) + return expert_name, local_expert_name + return None, None + + # gather amax values from SequentialMLP experts + for name, module in model.named_modules(): + expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module) + if expert_name and local_expert_name: + amax_dict[local_expert_name] = amax_dict.get(local_expert_name, {}) + amax_dict[local_expert_name][expert_name] = max( + amax_dict[local_expert_name].get(expert_name, 0), module.amax + ) # sync amax values across experts in SequentialMLP for name, module in model.named_modules(): - if ( - not isinstance(module, TensorQuantizer) - or not hasattr(module, "_amax") - or "local_experts" not in name - ): - continue - expert_name, local_expert_name = name.split("local_experts") - for key in amax_dict: - if key in local_expert_name: - module.amax = amax_dict[key][expert_name] + expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module) + if expert_name and local_expert_name: + module.amax = amax_dict[local_expert_name][expert_name] CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp) @@ -523,6 +520,11 @@ def forward(self, input, *args, **kwargs): # Register the public te.pytorch.GroupedLinear class @QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear): + _functionals_to_replace = [ + (te_grouped_linear._GroupedLinear, "forward"), + (te_grouped_linear._GroupedLinear, "apply"), + ] + def _setup(self): # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning @@ -531,37 +533,8 @@ def _setup(self): # Memorize the original weight.dtype for modelopt_post_restore given that # the dtype can change later. super()._setup() - # Revert the weight to None after setup. - self.weight = None - - @property - def functionals_to_replace(self): - original_forward = te_grouped_linear._GroupedLinear.forward - - def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): - num_gemms = len(m_splits) - weights_and_biases = args[-2 * num_gemms :] - weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] - quantized_inputs = self.input_quantizer(inp) - quantized_weights = [self.weight_quantizer(weight) for weight in weights] - - output = original_forward( - ctx, - quantized_inputs, - m_splits, - *args[: -2 * num_gemms], - *quantized_weights, - *biases, - ) - return self.output_quantizer(output) - - return [ - ( - te_grouped_linear._GroupedLinear, - "forward", - te_grouped_quantized_linear_fn, - ), - ] + # Remove self.weight after setup. 
+ delattr(self, "weight") def modelopt_post_restore(self, prefix: str = ""): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to @@ -569,8 +542,8 @@ def modelopt_post_restore(self, prefix: str = ""): # self.weight0 to self.weight to run the quantizer states initialization. self.weight = self.weight0 super().modelopt_post_restore(prefix=prefix) - # Revert the weight to None after post_restore. - self.weight = None + # Remove self.weight after post_restore. + delattr(self, "weight") def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in @@ -585,10 +558,34 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) def _process_quantizer_amax(self, k, v, quantizer_state_dict): - if v.ndim == 4: - quantizer_state_dict[k] = v.squeeze(1).squeeze(-1) - else: - quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1) + assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization" + quantizer_state_dict[k] = v.view(-1) + + @staticmethod + def te_grouped_quantized_linear_fn(package, func_name, self, *args): + idx = 1 if func_name == "_forward" else 0 + inp = args[idx] + num_gemms = len(args[idx + 1]) + weights_and_biases = args[-2 * num_gemms :] + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] + quantized_inputs = self.input_quantizer(inp) + quantized_weights = [self.weight_quantizer(weight) for weight in weights] + + output = getattr(package, func_name)( + *( + args[0], + quantized_inputs, + ) + if func_name == "_forward" + else (quantized_inputs,), + *args[idx + 1 : -2 * num_gemms], + *quantized_weights, + *biases, + ) + return self.output_quantizer(output) + + # Override the quantized linear function + _quantized_linear_fn = te_grouped_quantized_linear_fn @QuantModuleRegistry.register( @@ -614,24 +611,18 @@ class _MegatronTEGroupedRowParallelLinear( class _MegatronTEGroupedMLP(_MegatronMLP): def _setup(self): if not hasattr(self, "parallel_state") or self.parallel_state is None: - data_parallel_group = None - try: - data_parallel_group = get_data_parallel_group(with_context_parallel=True) - except AssertionError: - logger.warning( - "Context parallel group is not initialized, using data parallel group" - ) - data_parallel_group = get_data_parallel_group() - - try: - expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() - except AssertionError: - expert_tensor_parallel_group = None self.parallel_state = ParallelState( - data_parallel_group, - tensor_parallel_group=expert_tensor_parallel_group, - expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), + mcore_parallel.get_expert_data_parallel_group(check_initialized=False), + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group( + check_initialized=False + ), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group( + check_initialized=False + ), ) + # initialize parallel state for submodules linear_fc1 and linear_fc2 + self.linear_fc1.parallel_state = self.parallel_state + self.linear_fc2.parallel_state = self.parallel_state # Register the public megatron_moe.SequentialMLP class @@ -639,17 +630,17 @@ def _setup(self): class _MegatronSequentialMLP(_MegatronMLP): def _setup(self): if not hasattr(self, "parallel_state") or self.parallel_state is None: - try: - 
data_parallel_group = mcore_parallel.get_expert_data_parallel_group() - except AssertionError: - data_parallel_group = None - - try: - expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() - except AssertionError: - expert_tensor_parallel_group = None self.parallel_state = ParallelState( - data_parallel_group, - tensor_parallel_group=expert_tensor_parallel_group, - expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), + mcore_parallel.get_expert_data_parallel_group(check_initialized=False), + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group( + check_initialized=False + ), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group( + check_initialized=False + ), ) + + # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2 + for expert in self.local_experts: + expert.linear_fc1.parallel_state = self.parallel_state + expert.linear_fc2.parallel_state = self.parallel_state diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index c2a9fb745..02629ecf0 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -588,6 +588,15 @@ def compare_amax_sync_across_expert_parallel(model): if quantizer_type not in expert_quantizers: expert_quantizers[quantizer_type] = {} + if ( + quantizer_type in expert_quantizers + and rank_idx in expert_quantizers[quantizer_type] + ): + # compare expert value across expert for sequential MoE + assert expert_quantizers[quantizer_type][rank_idx] == amax_val, ( + f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: " + f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}" + ) expert_quantizers[quantizer_type][rank_idx] = amax_val # Check synchronization - fail fast on first inconsistency diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index f32065bce..d1ba9dd47 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -40,6 +40,12 @@ def need_8_gpus(): pytest.skip("Need at least 8 GPUs to run this test") +@pytest.fixture +def need_4_gpus(): + if torch.cuda.device_count() < 4: + pytest.skip("Need at least 4 GPUs to run this test") + + @pytest.fixture(scope="module") def set_torch_dtype(request): orig_dtype = torch.get_default_dtype() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 5ac3b5bfe..6da3732e4 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -35,7 +35,6 @@ auto_quantize_helper, data_tensor_context_parallel_test_helper, dp_cp_parallel_test_helper, - tensor_parallel_test_helper, ) skip_if_no_megatron() @@ -621,12 +620,10 @@ def test_fp8_real_quantize(): mtq.NVFP4_DEFAULT_CFG, ], ) -@pytest.mark.parametrize("moe_grouped_gemm", [False, True]) -def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm): +@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) +def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm): size = torch.cuda.device_count() # TODO: Add support for compress=True for TEGroupedMLP - if size < 4: - pytest.skip("Requires at least 4 GPUs for expert parallel test") moe_config = { "tp_size": 2, "ep_size": 2, @@ -720,13 +717,9 @@ def forward_fn(model): ) -def 
test_te_grouped_vs_sequential_quantize(): +def test_te_grouped_vs_sequential_quantize(need_4_gpus): """Test that TEGrouped and sequential MoE models produce similar quantized models.""" - size = torch.cuda.device_count() - if size < 4: - pytest.skip("Requires at least 4 GPUs for expert parallel test") - spawn_multiprocess_job( size=size, job=partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, 2), @@ -763,7 +756,6 @@ def forward_fn(model): # quantize the model model = mtq.quantize(model, config, forward_fn) - # Check initial sync status initial_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert initial_sync, ( @@ -790,8 +782,10 @@ def forward_fn(model): "Consistent amax across expert parallel ranks, " "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" ) - # Re-enable parallel groups and test synchronization - mtq.model_calib.max_calibrate(model, forward_fn) + # Re-calibrate the model and test synchronization + mtq.mode.wrapped_calib_func( + model, mtq.config.MaxCalibConfig(), forward_fn, mtq.model_calib.max_calibrate + ) final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" From 153e376243ff2db37b6aefa571c71d049ace9967 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Sat, 11 Oct 2025 03:47:55 +0000 Subject: [PATCH 21/28] rebase bug fixes Signed-off-by: Kinjal Patel --- .../torch_quantization/quantize_common.py | 97 ------------------- .../quantization/plugins/test_megatron.py | 83 ---------------- 2 files changed, 180 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 94e0ccc3c..8647aaa00 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -210,103 +210,6 @@ def forward_loop(model): ) -@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) -def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): - calib_data = model.get_dummy_input().cuda() - - def forward_loop(model): - model(calib_data) - - model = mtq.quantize(model, config, forward_loop) - - # Sanity check - forward_loop(model) - - # Input quantizer amax - if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) - - # Weight quantizer amax - if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) - else: - _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) - if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) - else: - _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) - - if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: - # Check act scale - _reduce_quantizer_attr( - model.fc1.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - group=group, - ) - _reduce_quantizer_attr( - model.fc2.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - group=group, - ) - - 
-@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) -def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): - # Calib data should be same across each DP rank - dp_rank = dist.get_rank(group=dp_group) - calib_data = model.get_dummy_input(seed=dp_rank).cuda() - - def forward_loop(model): - model(calib_data) - - model = mtq.quantize(model, config, forward_loop) - - def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): - quantizer_attr = getattr(quantizer, attr).clone() - - # Perform all-reduce operations - dist.all_reduce(quantizer_attr, op=op, group=tp_group) - - dist.all_reduce(quantizer_attr, op=op, group=dp_group) - - assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) - - # Input quantizer amax - if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) - - # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks - # Channel-wise (INT8) only expects same amax across row parallel ranks - # Block-wise quantization does not expect same amax across row and column parallel ranks - if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: - if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) - - if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: - if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) - - # Check act scale - if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: - _reduce_quantizer_attr( - model.fc1.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - ) - - def auto_quantize_helper(model): model, search_state = mtq.auto_quantize( model, diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 6da3732e4..c0ed237d7 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -231,89 +231,6 @@ def test_data_tensor_context_parallel(need_8_gpus, config): ) -# 2. Data Parallel Test -def _test_data_parallel_helper(config, rank, size): - initialize_for_megatron(seed=SEED + rank) # modify seed so data is different across ranks - model = MegatronModel().cuda() - - dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) - - -@pytest.mark.parametrize( - "config", - [ - mtq.INT8_DEFAULT_CFG, - mtq.FP8_DEFAULT_CFG, - mtq.W4A8_AWQ_BETA_CFG, - mtq.INT8_SMOOTHQUANT_CFG, - mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, - mtq.INT4_AWQ_CFG, - mtq.NVFP4_DEFAULT_CFG, - ], -) -def test_data_parallel(need_2_gpus, config): - spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl") - - -# 3. 
Context Parallel Test -def _test_context_parallel_helper(config, rank, size): - initialize_for_megatron( - context_parallel_size=size, seed=SEED + rank - ) # modify seed so data is different across ranks - model = MegatronModel(cp_size=size).cuda() - - dp_cp_parallel_test_helper(model, config, get_data_parallel_group(with_context_parallel=True)) - - -@pytest.mark.parametrize( - "config", - [ - mtq.INT8_DEFAULT_CFG, - mtq.FP8_DEFAULT_CFG, - mtq.W4A8_AWQ_BETA_CFG, - mtq.INT8_SMOOTHQUANT_CFG, - mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, - mtq.INT4_AWQ_CFG, - mtq.NVFP4_DEFAULT_CFG, - ], -) -def test_context_parallel(need_2_gpus, config): - spawn_multiprocess_job( - size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" - ) - - -# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) -def _test_data_tensor_context_parallel_helper(config, rank, size): - initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank) - model = MegatronModel(tp_size=2, cp_size=2).cuda() - - data_tensor_context_parallel_test_helper( - model, - config, - get_data_parallel_group(with_context_parallel=True), - get_tensor_model_parallel_group(), - ) - - -@pytest.mark.parametrize( - "config", - [ - mtq.INT8_DEFAULT_CFG, - mtq.FP8_DEFAULT_CFG, - mtq.W4A8_AWQ_BETA_CFG, - mtq.INT8_SMOOTHQUANT_CFG, - mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, - mtq.INT4_AWQ_CFG, - mtq.NVFP4_DEFAULT_CFG, - ], -) -def test_data_tensor_context_parallel(need_8_gpus, config): - spawn_multiprocess_job( - size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" - ) - - def _gpt_model_provider( tp_size: int, hidden_size=256, From 5bc99e0cf26193a3dec5d5a31f000482f6c2f094 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Sat, 11 Oct 2025 04:46:48 +0000 Subject: [PATCH 22/28] fixing test and comments Signed-off-by: Kinjal Patel --- .../torch/quantization/plugins/megatron.py | 23 +++++++++---------- .../torch_dist/plugins/megatron_common.py | 5 ++-- .../quantization/plugins/test_megatron.py | 1 - 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 0aa97fd21..caa2cb6f0 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -28,8 +28,6 @@ from megatron.core.extensions import transformer_engine as megatron_te from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region -from megatron.core.parallel_state import get_data_parallel_group -from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -63,23 +61,24 @@ def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module): expert_name, local_expert_name = name.split(".local_experts.") # extract quantizer name by removing local_expert number from the name local_expert_name = ".".join(local_expert_name.split(".")[1:]) - return expert_name, local_expert_name - return None, None + return f"{expert_name}.{local_expert_name}" + return None # gather amax values from SequentialMLP experts for name, module in model.named_modules(): - expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module) - if expert_name and 
local_expert_name: - amax_dict[local_expert_name] = amax_dict.get(local_expert_name, {}) - amax_dict[local_expert_name][expert_name] = max( - amax_dict[local_expert_name].get(expert_name, 0), module.amax + expert_name = get_sequential_mlp_expert_names(name, module) + if expert_name and module.amax is not None: + stored_amax = amax_dict.get(expert_name) + amax_tensor = module.amax.detach().clone() + amax_dict[expert_name] = ( + amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor) ) # sync amax values across experts in SequentialMLP for name, module in model.named_modules(): - expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module) - if expert_name and local_expert_name: - module.amax = amax_dict[local_expert_name][expert_name] + expert_name = get_sequential_mlp_expert_names(name, module) + if expert_name and module.amax is not None: + module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device) CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp) diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 02629ecf0..8f47d060a 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -562,9 +562,8 @@ def compare_amax_sync_across_expert_parallel(model): if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"): # Check for both TEGrouped and sequential MoE patterns if "local_experts" in name or ("experts" in name and "linear_fc" in name): - expert_amax_values[name] = ( - module.amax.item() if hasattr(module.amax, "item") else module.amax - ) + amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax + expert_amax_values[name] = amax_val # Early return if no expert quantizers found assert expert_amax_values, "No expert quantizers found" diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index c0ed237d7..5cd5d88fc 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -34,7 +34,6 @@ from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, data_tensor_context_parallel_test_helper, - dp_cp_parallel_test_helper, ) skip_if_no_megatron() From 23daf38254c8f45ac082ad682c68e77880f6a85e Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Mon, 13 Oct 2025 20:32:16 +0000 Subject: [PATCH 23/28] Code cleanup Signed-off-by: Kinjal Patel --- .../torch/quantization/plugins/megatron.py | 20 ++++--------- .../torch_dist/plugins/megatron_common.py | 28 ++++++++++--------- .../quantization/plugins/test_megatron.py | 27 ++++++------------ 3 files changed, 29 insertions(+), 46 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index caa2cb6f0..af1e4a327 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -611,13 +611,9 @@ class _MegatronTEGroupedMLP(_MegatronMLP): def _setup(self): if not hasattr(self, "parallel_state") or self.parallel_state is None: self.parallel_state = ParallelState( - mcore_parallel.get_expert_data_parallel_group(check_initialized=False), - tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group( - check_initialized=False - ), - expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group( - 
check_initialized=False - ), + mcore_parallel.get_expert_data_parallel_group(), + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), ) # initialize parallel state for submodules linear_fc1 and linear_fc2 self.linear_fc1.parallel_state = self.parallel_state @@ -630,13 +626,9 @@ class _MegatronSequentialMLP(_MegatronMLP): def _setup(self): if not hasattr(self, "parallel_state") or self.parallel_state is None: self.parallel_state = ParallelState( - mcore_parallel.get_expert_data_parallel_group(check_initialized=False), - tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group( - check_initialized=False - ), - expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group( - check_initialized=False - ), + mcore_parallel.get_expert_data_parallel_group(), + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), ) # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2 diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 8f47d060a..8d31a0aba 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -515,20 +515,21 @@ def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_mo # Map grouped weights to sequential weights weight_mapping = {} - sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight" + sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}" for key, value in te_grouped_state.items(): - if "experts.linear_fc" in key and "weight" in key: + if "experts.linear_fc" in key and any(param in key for param in ("weight", "bias")): # Extract expert index from grouped weight name # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ parts = key.split(".") layer_idx = parts[2] # X fc_idx = parts[5] # Y (linear_fc1 or linear_fc2) - weight_idx = parts[6] # Z (weight0, weight1, etc.) - - # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight - expert_idx = weight_idx.replace("weight", "") + param_idx = parts[6] # weight0 / bias0 / etc. + match = re.search(r"\d+", param_idx) + expert_idx = match.group(0) if match else "0" # Z for expert index + # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ sequential_key = sequential_key_template.format(layer_idx, expert_idx, fc_idx[-1]) - weight_mapping[sequential_key] = value + param_name = "weight" if "weight" in param_idx else "bias" + weight_mapping[f"{sequential_key}.{param_name}"] = value elif isinstance(value, torch.Tensor): weight_mapping[key] = value @@ -540,7 +541,7 @@ def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_mo sequential_moe_model.load_state_dict(sequential_state) -def compare_amax_sync_across_expert_parallel(model): +def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True): """ Test if amax values are synchronized across expert parallel groups. 
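
At its core, compare_amax_sync_across_expert_parallel gathers each rank's expert-quantizer amax values and checks that matching quantizers agree across ranks. A rough, simplified sketch of that step (hypothetical `gather_and_compare`, assuming torch.distributed is initialized):

    import re
    import torch
    import torch.distributed as dist

    def gather_and_compare(expert_amax: dict) -> bool:
        world_size = dist.get_world_size()
        gathered = [None] * world_size
        dist.all_gather_object(gathered, {k: v.detach().cpu() for k, v in expert_amax.items()})

        # Normalize away the local expert index so the same logical quantizer
        # (e.g. "...local_experts.3.linear_fc1...") maps to one key on every rank.
        merged = {}
        for per_rank in gathered:
            for name, amax in per_rank.items():
                key = re.sub(r"local_experts\.\d+", "local_experts.*", name)
                merged.setdefault(key, []).append(amax)

        return all(
            all(torch.allclose(vals[0], v, rtol=1e-6, atol=1e-6) for v in vals[1:])
            for vals in merged.values()
        )

This omits the refinement added later in the series, where per-channel fc1 amax is only compared among ranks sharing the same expert tensor parallel rank.
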
@@ -591,11 +592,12 @@ def compare_amax_sync_across_expert_parallel(model): quantizer_type in expert_quantizers and rank_idx in expert_quantizers[quantizer_type] ): - # compare expert value across expert for sequential MoE - assert expert_quantizers[quantizer_type][rank_idx] == amax_val, ( - f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: " - f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}" - ) + if compare_across_experts: + # compare expert value across expert for sequential MoE + assert expert_quantizers[quantizer_type][rank_idx] == amax_val, ( + f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: " + f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}" + ) expert_quantizers[quantizer_type][rank_idx] = amax_val # Check synchronization - fail fast on first inconsistency diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 5cd5d88fc..52ac4ce51 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -677,31 +677,20 @@ def forward_fn(model): assert initial_sync, ( f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" ) - # Create inconsistent amax values - cur_rank = torch.distributed.get_rank() - for name, module in model.named_modules(): - if isinstance(module, mtq.nn.TensorQuantizer): - # Check if this is an expert quantizer - is_expert_quantizer = ( - "local_experts" in name # sequential MoE - or ("experts" in name and "linear_fc" in name) # TEGrouped MoE - ) - if is_expert_quantizer and hasattr(module, "_amax"): - # Create rank-specific amax values to simulate missing sync - rank_offset = cur_rank * 0.1 - module.amax = module.amax + rank_offset + # Test if the amax values are inconsistent when distributed sync is disabled + mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=False) + inconsistent_amax, _, _ = compare_amax_sync_across_expert_parallel( + model, compare_across_experts=False + ) - # Test if the amax values are inconsistent - inconsistent_amax, _, _ = compare_amax_sync_across_expert_parallel(model) assert not inconsistent_amax, ( "Consistent amax across expert parallel ranks, " "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" ) - # Re-calibrate the model and test synchronization - mtq.mode.wrapped_calib_func( - model, mtq.config.MaxCalibConfig(), forward_fn, mtq.model_calib.max_calibrate - ) + # calibrate the model with distributed sync and test synchronization + mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=True) + mtq.plugins.megatron.sync_amax_across_sequential_mlp(model) final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" From 15ffb87213b2253c9454c48851e5e470a1ed4d0f Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Tue, 14 Oct 2025 21:49:23 +0000 Subject: [PATCH 24/28] Code cleanup and test update Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/model_calib.py | 4 + .../torch/quantization/plugins/megatron.py | 225 +++++++----------- .../plugins/transformer_engine.py | 56 +++++ .../torch_dist/plugins/megatron_common.py | 31 ++- .../quantization/plugins/test_megatron.py | 10 +- 5 files changed, 175 insertions(+), 151 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py 
b/modelopt/torch/quantization/model_calib.py index 175f31670..1abcddd12 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -176,6 +176,10 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) + for name, module in model.named_modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + def enable_stats_collection(model: nn.Module): """Enable stats collection for all quantizers in the model.""" diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index af1e4a327..39054e5c3 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -24,8 +24,6 @@ import megatron.core.transformer.mlp as megatron_mlp import megatron.core.transformer.moe.experts as megatron_moe import torch -import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear -from megatron.core.extensions import transformer_engine as megatron_te from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule @@ -41,47 +39,23 @@ from ..nn import QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper -from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear +from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear -logger = logging.getLogger(__name__) - -__all__ = [] - - -def sync_amax_across_sequential_mlp(model: torch.nn.Module): - """Sync amax across experts in a SequentialMLP.""" - amax_dict = {} - - def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module): - if ( - isinstance(module, TensorQuantizer) - and hasattr(module, "_amax") - and ".local_experts." 
in name - ): - expert_name, local_expert_name = name.split(".local_experts.") - # extract quantizer name by removing local_expert number from the name - local_expert_name = ".".join(local_expert_name.split(".")[1:]) - return f"{expert_name}.{local_expert_name}" - return None +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TERowParallelGroupedLinear, + ) - # gather amax values from SequentialMLP experts - for name, module in model.named_modules(): - expert_name = get_sequential_mlp_expert_names(name, module) - if expert_name and module.amax is not None: - stored_amax = amax_dict.get(expert_name) - amax_tensor = module.amax.detach().clone() - amax_dict[expert_name] = ( - amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor) - ) + from .transformer_engine import _QuantTEGroupedLinear - # sync amax values across experts in SequentialMLP - for name, module in model.named_modules(): - expert_name = get_sequential_mlp_expert_names(name, module) - if expert_name and module.amax is not None: - module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device) + HAS_TE = True +except ImportError: + HAS_TE = False +logger = logging.getLogger(__name__) -CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp) +__all__ = [] def real_quant_module_get_extra_state(self) -> dict: @@ -516,111 +490,6 @@ def forward(self, input, *args, **kwargs): return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) -# Register the public te.pytorch.GroupedLinear class -@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) -class _QuantMegatronTEGroupedLinear(_MegatronParallelLinear): - _functionals_to_replace = [ - (te_grouped_linear._GroupedLinear, "forward"), - (te_grouped_linear._GroupedLinear, "apply"), - ] - - def _setup(self): - # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to - # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning - # self.weight0 to self.weight to run the quantizer states initialization. - self.weight = self.weight0 - # Memorize the original weight.dtype for modelopt_post_restore given that - # the dtype can change later. - super()._setup() - # Remove self.weight after setup. - delattr(self, "weight") - - def modelopt_post_restore(self, prefix: str = ""): - # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to - # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning - # self.weight0 to self.weight to run the quantizer states initialization. - self.weight = self.weight0 - super().modelopt_post_restore(prefix=prefix) - # Remove self.weight after post_restore. - delattr(self, "weight") - - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): - # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in - # sharded_state_dict which is same as _extra_state. 
The _extra_state{gemm_idx} is used for - # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] - # for modelopt checkpoint restore - filtered_state_dict = { - k: v - for k, v in state_dict.items() - if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) - } - return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) - - def _process_quantizer_amax(self, k, v, quantizer_state_dict): - assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization" - quantizer_state_dict[k] = v.view(-1) - - @staticmethod - def te_grouped_quantized_linear_fn(package, func_name, self, *args): - idx = 1 if func_name == "_forward" else 0 - inp = args[idx] - num_gemms = len(args[idx + 1]) - weights_and_biases = args[-2 * num_gemms :] - weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] - quantized_inputs = self.input_quantizer(inp) - quantized_weights = [self.weight_quantizer(weight) for weight in weights] - - output = getattr(package, func_name)( - *( - args[0], - quantized_inputs, - ) - if func_name == "_forward" - else (quantized_inputs,), - *args[idx + 1 : -2 * num_gemms], - *quantized_weights, - *biases, - ) - return self.output_quantizer(output) - - # Override the quantized linear function - _quantized_linear_fn = te_grouped_quantized_linear_fn - - -@QuantModuleRegistry.register( - {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} -) -class _MegatronTEGroupedColumnParallelLinear( - _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear -): - _is_column_parallel = True - - -@QuantModuleRegistry.register( - {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} -) -class _MegatronTEGroupedRowParallelLinear( - _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear -): - _is_row_parallel = True - - -# Register the public megatron_moe.TEGroupedMLP class -@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"}) -class _MegatronTEGroupedMLP(_MegatronMLP): - def _setup(self): - if not hasattr(self, "parallel_state") or self.parallel_state is None: - self.parallel_state = ParallelState( - mcore_parallel.get_expert_data_parallel_group(), - tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), - expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), - ) - # initialize parallel state for submodules linear_fc1 and linear_fc2 - self.linear_fc1.parallel_state = self.parallel_state - self.linear_fc2.parallel_state = self.parallel_state - - -# Register the public megatron_moe.SequentialMLP class @QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"}) class _MegatronSequentialMLP(_MegatronMLP): def _setup(self): @@ -635,3 +504,73 @@ def _setup(self): for expert in self.local_experts: expert.linear_fc1.parallel_state = self.parallel_state expert.linear_fc2.parallel_state = self.parallel_state + + def sync_moe_local_experts_amax(self): + """Sync amax across experts in a SequentialMLP.""" + amax_dict = {} + # gather amax values from SequentialMLP experts + for expert in self.local_experts: + for name, module in expert.named_modules(): + if isinstance(module, TensorQuantizer) and module.amax is not None: + stored_amax = amax_dict.get(name) + amax_tensor = module.amax.detach().clone() + amax_dict[name] = ( + amax_tensor + if stored_amax is None + else torch.maximum(stored_amax, amax_tensor) + ) + + # sync amax values 
across experts in SequentialMLP + for expert in self.local_experts: + for name, module in expert.named_modules(): + if isinstance(module, TensorQuantizer) and module.amax is not None: + module.amax = amax_dict[name].detach().clone().to(module.amax.device) + + +if HAS_TE: + # Quantized subclasses to support TEGroupedMLP quantization + class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear): + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in + # sharded_state_dict which is same as _extra_state. The _extra_state{gemm_idx} is used for + # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] + # for modelopt checkpoint restore + filtered_state_dict = { + k: v + for k, v in state_dict.items() + if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) + } + return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) + + def _process_quantizer_amax(self, k, v, quantizer_state_dict): + assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization" + quantizer_state_dict[k] = v.view(-1) + + @QuantModuleRegistry.register( + {TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} + ) + class _MegatronTEGroupedColumnParallelLinear( + _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear + ): + pass + + @QuantModuleRegistry.register( + {TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} + ) + class _MegatronTEGroupedRowParallelLinear( + _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear + ): + pass + + @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"}) + class _MegatronTEGroupedMLP(_MegatronMLP): + def _setup(self): + if not hasattr(self, "parallel_state") or self.parallel_state is None: + self.parallel_state = ParallelState( + mcore_parallel.get_expert_data_parallel_group(), + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), + ) + # initialize parallel state for submodules linear_fc1 and linear_fc2 + self.linear_fc1.parallel_state = self.parallel_state + self.linear_fc2.parallel_state = self.parallel_state diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index b068ebca7..41a6c8321 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -17,6 +17,7 @@ import torch import transformer_engine as te +import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear import transformer_engine.pytorch.module.linear as te_linear from ..nn import QuantModuleRegistry @@ -58,3 +59,58 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs): # Override the quantized linear function _quantized_linear_fn = te_quantized_linear_fn + + +# Register the public te.pytorch.GroupedLinear class +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) +class _QuantTEGroupedLinear(_ParallelLinear): + _functionals_to_replace = [ + (te_grouped_linear._GroupedLinear, "forward"), + (te_grouped_linear._GroupedLinear, "apply"), + ] + + def _setup(self): + # GroupedMLP stores the weights as weight0, weight1, etc. 
To run setup in order to + # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning + # self.weight0 to self.weight to run the quantizer states initialization. + self.weight = self.weight0 + # Memorize the original weight.dtype for modelopt_post_restore given that + # the dtype can change later. + super()._setup() + # Remove self.weight after setup. + delattr(self, "weight") + + def modelopt_post_restore(self, prefix: str = ""): + # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to + # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning + # self.weight0 to self.weight to run the quantizer states initialization. + self.weight = self.weight0 + super().modelopt_post_restore(prefix=prefix) + # Remove self.weight after post_restore. + delattr(self, "weight") + + @staticmethod + def te_grouped_quantized_linear_fn(package, func_name, self, *args): + idx = 1 if func_name == "_forward" else 0 + inp = args[idx] + num_gemms = len(args[idx + 1]) + weights_and_biases = args[-2 * num_gemms :] + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] + quantized_inputs = self.input_quantizer(inp) + quantized_weights = [self.weight_quantizer(weight) for weight in weights] + + output = getattr(package, func_name)( + *( + args[0], + quantized_inputs, + ) + if func_name == "_forward" + else (quantized_inputs,), + *args[idx + 1 : -2 * num_gemms], + *quantized_weights, + *biases, + ) + return self.output_quantizer(output) + + # Override the quantized linear function + _quantized_linear_fn = te_grouped_quantized_linear_fn diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 8d31a0aba..c9562ac96 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -190,7 +190,7 @@ def squared_relu(x): pipeline_model_parallel_size=pipeline_model_parallel_size, expert_model_parallel_size=expert_model_parallel_size, expert_tensor_parallel_size=expert_tensor_parallel_size, - sequence_parallel=False, + sequence_parallel=expert_model_parallel_size > 1, moe_grouped_gemm=moe_grouped_gemm, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, @@ -215,7 +215,8 @@ def squared_relu(x): num_experts=num_moe_experts, normalization=normalization, moe_grouped_gemm=moe_grouped_gemm, - use_te=use_te, + # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM + # use_te=use_te, ) else: assert HAS_TE, "Transformer Engine not installed" @@ -563,7 +564,8 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True) if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"): # Check for both TEGrouped and sequential MoE patterns if "local_experts" in name or ("experts" in name and "linear_fc" in name): - amax_val = module.amax.item() if hasattr(module.amax, "item") else module.amax + # Convert to scalar only if tensor has a single element + amax_val = module.amax.detach().clone().cpu() expert_amax_values[name] = amax_val # Early return if no expert quantizers found @@ -594,7 +596,13 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True) ): if compare_across_experts: # compare expert value across expert for sequential MoE - assert expert_quantizers[quantizer_type][rank_idx] == amax_val, ( + prev_val = 
expert_quantizers[quantizer_type][rank_idx] + # Handle both scalar and tensor comparisons + if isinstance(amax_val, torch.Tensor) and isinstance(prev_val, torch.Tensor): + are_equal = torch.allclose(prev_val, amax_val, rtol=1e-6, atol=1e-6) + else: + are_equal = prev_val == amax_val + assert are_equal, ( f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: " f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}" ) @@ -604,8 +612,17 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True) for quantizer_type, rank_values in expert_quantizers.items(): if len(rank_values) > 1: # Only check if we have multiple ranks values = list(rank_values.values()) - max_diff = max(values) - min(values) - if max_diff > 1e-6: # Allow for small floating point differences - return False, quantizer_type, rank_values + # Handle both scalar and tensor comparisons + first_val = values[0] + if isinstance(first_val, torch.Tensor): + # For tensors, check if all values are close to the first one + for val in values[1:]: + if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6): + return False, quantizer_type, rank_values + else: + # For scalars, use numeric comparison + max_diff = max(values) - min(values) + if max_diff > 1e-6: # Allow for small floating point differences + return False, quantizer_type, rank_values return True, None, None diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 52ac4ce51..bf27fdd11 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -538,6 +538,8 @@ def test_fp8_real_quantize(): ) @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm): + if moe_grouped_gemm: + pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") size = torch.cuda.device_count() # TODO: Add support for compress=True for TEGroupedMLP moe_config = { @@ -635,6 +637,7 @@ def forward_fn(model): def test_te_grouped_vs_sequential_quantize(need_4_gpus): """Test that TEGrouped and sequential MoE models produce similar quantized models.""" + pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") size = torch.cuda.device_count() spawn_multiprocess_job( size=size, @@ -690,7 +693,9 @@ def forward_fn(model): ) # calibrate the model with distributed sync and test synchronization mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=True) - mtq.plugins.megatron.sync_amax_across_sequential_mlp(model) + for module in model.modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model) assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" @@ -704,6 +709,9 @@ def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm): if size < ep_size * etp_size: pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test") + if moe_grouped_gemm: + pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") + spawn_multiprocess_job( size=size, job=partial( From 28c8bbf578649b64863da5519b97ae72247d302b Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Tue, 14 Oct 2025 23:59:03 +0000 Subject: [PATCH 25/28] remove post calib hook Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/mode.py | 3 --- 
modelopt/torch/quantization/plugins/custom.py | 7 ------- 2 files changed, 10 deletions(-) diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 5943a991d..4e6e9fd49 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -208,8 +208,6 @@ def wrapped_calib_func( forward_loop and the relevant kwargs and are independent of the ModelOpt framework. So lets wrap them to be compatible with the ModelOpt convert entrypoint. """ - from .plugins.custom import register_custom_post_calibration_plugins - kwargs = config.model_dump() method = kwargs.pop("method") if method is not None and "awq" in method: @@ -220,7 +218,6 @@ def wrapped_calib_func( # Call the function with forward_loop as a separate argument func(model, forward_loop=forward_loop, **kwargs) - register_custom_post_calibration_plugins(model) # Lets get the latest metadata for the quantizer states metadata = {} update_quantize_metadata(model, config, metadata) diff --git a/modelopt/torch/quantization/plugins/custom.py b/modelopt/torch/quantization/plugins/custom.py index 38317bd52..4227f3c49 100644 --- a/modelopt/torch/quantization/plugins/custom.py +++ b/modelopt/torch/quantization/plugins/custom.py @@ -30,7 +30,6 @@ CUSTOM_MODEL_PLUGINS = set() CUSTOM_POST_CONVERSION_PLUGINS = set() -CUSTOM_POST_CALIBRATION_PLUGINS = set() # TODO: This is a temporary solution @@ -47,12 +46,6 @@ def register_custom_post_conversion_plugins(model): callback(model) -def register_custom_post_calibration_plugins(model): - """Registers custom modules as QUANT_MODULE after calibration.""" - for callback in CUSTOM_POST_CALIBRATION_PLUGINS: - callback(model) - - class _QuantFunctionalMixin(QuantModule): """Mixin class for quantized functionals. From 5481d10335a768581e58238223bf3430262b23b6 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 16 Oct 2025 22:06:28 +0000 Subject: [PATCH 26/28] fixed tests for per-channel support Signed-off-by: Kinjal Patel --- .../torch/quantization/plugins/megatron.py | 12 ++- .../torch_dist/plugins/megatron_common.py | 93 +++++++++++++------ .../quantization/plugins/test_megatron.py | 31 +++++-- 3 files changed, 94 insertions(+), 42 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 39054e5c3..fd6b0660d 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -506,9 +506,15 @@ def _setup(self): expert.linear_fc2.parallel_state = self.parallel_state def sync_moe_local_experts_amax(self): - """Sync amax across experts in a SequentialMLP.""" + """Sync amax across local experts in a SequentialMLP. + + amax across EP and ETP (for RowParallel) are synchronized as part of model_calib.max_calibrate(). + This function is called to synchronize the amax values across local experts s.t. all localexperts will + share the same amax. 
+ """ + torch.distributed.barrier() + # Collect amax from all local experts amax_dict = {} - # gather amax values from SequentialMLP experts for expert in self.local_experts: for name, module in expert.named_modules(): if isinstance(module, TensorQuantizer) and module.amax is not None: @@ -520,7 +526,7 @@ def sync_moe_local_experts_amax(self): else torch.maximum(stored_amax, amax_tensor) ) - # sync amax values across experts in SequentialMLP + # Apply synchronized amax values back to all local experts for expert in self.local_experts: for name, module in expert.named_modules(): if isinstance(module, TensorQuantizer) and module.amax is not None: diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index c9562ac96..4c91c0904 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -14,6 +14,7 @@ # limitations under the License. import copy import re +from collections import defaultdict from warnings import warn import torch @@ -41,6 +42,7 @@ from megatron.core.parallel_state import ( get_expert_model_parallel_group, get_expert_tensor_parallel_group, + get_expert_tensor_parallel_rank, initialize_model_parallel, is_pipeline_first_stage, is_pipeline_last_stage, @@ -190,7 +192,7 @@ def squared_relu(x): pipeline_model_parallel_size=pipeline_model_parallel_size, expert_model_parallel_size=expert_model_parallel_size, expert_tensor_parallel_size=expert_tensor_parallel_size, - sequence_parallel=expert_model_parallel_size > 1, + sequence_parallel=False, moe_grouped_gemm=moe_grouped_gemm, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, @@ -221,7 +223,12 @@ def squared_relu(x): else: assert HAS_TE, "Transformer Engine not installed" transformer_layer_spec = ( - get_gpt_modelopt_spec(config, remap_te_layernorm=True) + get_gpt_modelopt_spec( + config, + remap_te_layernorm=True, + # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM + # moe_grouped_gemm=moe_grouped_gemm + ) if transformer_impl == "modelopt" else get_gpt_layer_with_transformer_engine_spec() ) @@ -565,8 +572,7 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True) # Check for both TEGrouped and sequential MoE patterns if "local_experts" in name or ("experts" in name and "linear_fc" in name): # Convert to scalar only if tensor has a single element - amax_val = module.amax.detach().clone().cpu() - expert_amax_values[name] = amax_val + expert_amax_values[name] = module.amax.detach().clone().cpu() # Early return if no expert quantizers found assert expert_amax_values, "No expert quantizers found" @@ -577,19 +583,16 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True) torch.distributed.all_gather_object(all_amax_values, expert_amax_values) # Group quantizers by type (ignoring specific expert indices) and check sync - expert_quantizers = {} + expert_quantizers = defaultdict(dict) for rank_idx, rank_amax in enumerate(all_amax_values): for name, amax_val in rank_amax.items(): # Create quantizer type key by normalizing the name - if "local_experts" in name: - # sequential MoE: replace expert index with wildcard - quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name) - else: - # TEGrouped MoE: use the name as-is since experts are grouped - quantizer_type = name - - if quantizer_type not in expert_quantizers: - expert_quantizers[quantizer_type] = {} + quantizer_type = 
( + re.sub(r"local_experts\.\d+", "local_experts.*", name) + if "local_experts" in name + else name + ) + if ( quantizer_type in expert_quantizers and rank_idx in expert_quantizers[quantizer_type] @@ -608,21 +611,53 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True) ) expert_quantizers[quantizer_type][rank_idx] = amax_val - # Check synchronization - fail fast on first inconsistency + rank_info = { + "global_rank": torch.distributed.get_rank(), + "etp_rank": get_expert_tensor_parallel_rank(), + } + + all_rank_info = [None] * world_size + torch.distributed.all_gather_object(all_rank_info, rank_info) + + # Group ranks by ETP rank for fc1 (ColumnParallel: same output channels should match) + etp_groups = defaultdict(list) + for info in all_rank_info: + etp_groups[info["etp_rank"] if info["etp_rank"] else 0].append(info["global_rank"]) + for quantizer_type, rank_values in expert_quantizers.items(): - if len(rank_values) > 1: # Only check if we have multiple ranks - values = list(rank_values.values()) - # Handle both scalar and tensor comparisons - first_val = values[0] - if isinstance(first_val, torch.Tensor): - # For tensors, check if all values are close to the first one - for val in values[1:]: - if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6): - return False, quantizer_type, rank_values - else: - # For scalars, use numeric comparison - max_diff = max(values) - min(values) - if max_diff > 1e-6: # Allow for small floating point differences - return False, quantizer_type, rank_values + # Determine which ranks should have same amax + # Find which rank should have same amax + # + # fc1: ColumnParallel: X @ [A_1, A_2] (weights split along Cout) + # so amax should be the same across same ETP rank + # if EP is 2, ETP is 2, we have 4 ranks, EP1, ETP1: 0, EP1, ETP2: 1, EP2, ETP1: 2, EP2, ETP2: 3 + # so we need to compare amax across same ETP rank [0, 2] [1, 3] for per-channel quantization + # + # fc2: RowParallel: [X_1, X_2] @ [A_1 + # A_2] (weights split along Cin) + # amax should be the same across all ranks + + rank_groups = ( + list(etp_groups.values()) + if "linear_fc1" in quantizer_type and rank_values[0].ndim > 0 + else [list(range(world_size))] + ) + + # Check each group independently + for group in rank_groups: + group_values = [rank_values[r] for r in group if r in rank_values] + if len(group_values) > 1: + # All values in this group should be identical + first_val = group_values[0] + for val in group_values[1:]: + if isinstance(first_val, torch.Tensor): + if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6): + group_rank_values = { + r: rank_values[r] for r in group if r in rank_values + } + return False, f"{quantizer_type} (group {group})", group_rank_values + elif abs(first_val - val) > 1e-6: + group_rank_values = {r: rank_values[r] for r in group if r in rank_values} + return False, f"{quantizer_type} (group {group})", group_rank_values return True, None, None diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index bf27fdd11..d67ff44d0 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -45,6 +45,7 @@ ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.router import TopKRouter import modelopt import modelopt.torch.opt as mto @@ -240,6 
+241,7 @@ def _gpt_model_provider( ep_size=1, etp_size=None, use_te=False, + transformer_impl="local", ): """Build the model.""" @@ -253,7 +255,7 @@ def _gpt_model_provider( ffn_hidden_size=None, num_attention_heads=8, activation_func="squared_relu", - transformer_impl="local", + transformer_impl=transformer_impl, hidden_size=hidden_size, vocab_size=vocab_size, use_cpu_initialization=meta_device, @@ -270,7 +272,7 @@ def _gpt_model_provider( ffn_hidden_size=None, num_attention_heads=8, activation_func="squared_relu", - transformer_impl="local", + transformer_impl=transformer_impl, hidden_size=hidden_size, vocab_size=vocab_size, num_moe_experts=num_moe_experts, @@ -297,6 +299,7 @@ def _test_sharded_state_dict( num_moe_experts = moe_config.get("num_moe_experts", None) moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False) use_te = moe_config.get("use_te", False) + transformer_impl = moe_config.get("transformer_impl", "local") initialize_for_megatron( tensor_model_parallel_size=tp_size, @@ -314,6 +317,7 @@ def _test_sharded_state_dict( use_te=use_te, ep_size=ep_size, etp_size=etp_size, + transformer_impl=transformer_impl, ) model_test = _gpt_model_provider( tp_size, @@ -325,6 +329,7 @@ def _test_sharded_state_dict( meta_device=meta_device, ep_size=ep_size, etp_size=etp_size, + transformer_impl=transformer_impl, ) prompt_tokens = torch.randint( @@ -531,10 +536,7 @@ def test_fp8_real_quantize(): @pytest.mark.parametrize( "config", - [ - mtq.FP8_DEFAULT_CFG, - mtq.NVFP4_DEFAULT_CFG, - ], + [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG], ) @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm): @@ -549,6 +551,7 @@ def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm) "num_moe_experts": 4, "moe_grouped_gemm": moe_grouped_gemm, "use_te": moe_grouped_gemm, + "transformer_impl": "modelopt", } spawn_multiprocess_job( size=size, @@ -606,6 +609,7 @@ def forward_fn(model): hidden_size=32, moe_grouped_gemm=False, num_moe_experts=4, + transformer_impl="modelopt", ) num_sequential_mlp = sum( isinstance(module, SequentialMLP) for module in sequential_moe_model.modules() @@ -666,10 +670,16 @@ def _test_expert_model_parallel_amax_sync( hidden_size=256, moe_grouped_gemm=moe_grouped_gemm, use_te=moe_grouped_gemm, - num_moe_experts=4, + num_moe_experts=8, + transformer_impl="modelopt", ) prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + # force all expert routing + for module in model.modules(): + if isinstance(module, TopKRouter): + module.topk = module.num_experts + def forward_fn(model): return megatron_prefill(model, prompt_tokens) @@ -701,9 +711,10 @@ def forward_fn(model): assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}" +@pytest.mark.parametrize("config", [mtq.FP8_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]) @pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)]) @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) -def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm): +def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm): """Test expert model parallel synchronization.""" size = torch.cuda.device_count() if size < ep_size * etp_size: @@ -716,11 +727,11 @@ def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm): size=size, job=partial( _test_expert_model_parallel_amax_sync, - 2, + etp_size, # tp_size 
ep_size, etp_size, moe_grouped_gemm, - mtq.FP8_DEFAULT_CFG, + config, ), backend="nccl", ) From 91837c3b70340624202953f4b6f1a2e8b77dcfaa Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Thu, 16 Oct 2025 22:42:15 +0000 Subject: [PATCH 27/28] minor fix Signed-off-by: Kinjal Patel --- tests/_test_utils/torch_dist/plugins/megatron_common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 4c91c0904..c913bd7d2 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -636,10 +636,9 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True) # fc2: RowParallel: [X_1, X_2] @ [A_1 # A_2] (weights split along Cin) # amax should be the same across all ranks - rank_groups = ( list(etp_groups.values()) - if "linear_fc1" in quantizer_type and rank_values[0].ndim > 0 + if "linear_fc1" in quantizer_type and (next(iter(rank_values.values()))).ndim > 0 else [list(range(world_size))] ) From ca5534883b9b4a8f596cf6a981c49f068927cc86 Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Fri, 17 Oct 2025 00:05:18 +0000 Subject: [PATCH 28/28] Addressed MR comments Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/model_calib.py | 2 +- modelopt/torch/quantization/plugins/transformer_engine.py | 2 ++ modelopt/torch/quantization/utils.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 1abcddd12..c1e6feb06 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -118,7 +118,7 @@ def sync_quantizer_amax_across_tp( # Syncing amax across TP for sequential quantizer if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - "Syncing amax across TP for sequential quantizer" + # Syncing amax across TP for sequential quantizer sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index 41a6c8321..5199bbf34 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -73,6 +73,7 @@ def _setup(self): # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning # self.weight0 to self.weight to run the quantizer states initialization. + assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear" self.weight = self.weight0 # Memorize the original weight.dtype for modelopt_post_restore given that # the dtype can change later. @@ -84,6 +85,7 @@ def modelopt_post_restore(self, prefix: str = ""): # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning # self.weight0 to self.weight to run the quantizer states initialization. + assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear" self.weight = self.weight0 super().modelopt_post_restore(prefix=prefix) # Remove self.weight after post_restore. 
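
The asserts above guard a temporary-aliasing trick: grouped linears store per-expert parameters as weight0..weightN, and weight0 is briefly exposed as .weight so the shape/dtype-based quantizer initialization can run unchanged. A self-contained sketch of the pattern (a hypothetical context manager, not the actual implementation):

    import contextlib
    import torch.nn as nn

    @contextlib.contextmanager
    def alias_weight0_as_weight(module: nn.Module):
        # Temporarily expose weight0 as .weight, then remove the alias again.
        assert not hasattr(module, "weight"), "module.weight should not already exist"
        module.weight = module.weight0
        try:
            yield module
        finally:
            delattr(module, "weight")

Used as `with alias_weight0_as_weight(grouped_linear): grouped_linear._setup()`, this mirrors the assign / setup / delattr sequence shown in the hunk above.
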
diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 3287f6d91..43e269fa1 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -253,6 +253,7 @@ def is_quantized_linear(module): and hasattr(module, "weight_quantizer") and ( (getattr(module, "weight", None) is not None and module.weight.dim() == 2) + # module.weight0 check is required to support TEGroupedLinear or (getattr(module, "weight0", None) is not None and module.weight0.dim() == 2) ) )
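
As the final hunk notes, a grouped linear has no plain .weight; each expert's 2-D weight lives in weight0, weight1, and so on, so the duck-typed linear check falls back to weight0. A toy illustration with a hypothetical MiniGroupedLinear and a simplified check:

    import torch
    import torch.nn as nn

    class MiniGroupedLinear(nn.Module):
        def __init__(self, num_gemms: int, in_features: int, out_features: int):
            super().__init__()
            self.num_gemms = num_gemms
            for i in range(num_gemms):
                # One independent GEMM weight per expert: weight0, weight1, ...
                setattr(self, f"weight{i}", nn.Parameter(torch.empty(out_features, in_features)))

    def looks_like_linear(module: nn.Module) -> bool:
        weight = getattr(module, "weight", None)
        if weight is None:
            weight = getattr(module, "weight0", None)  # TEGroupedLinear-style storage
        return isinstance(weight, torch.Tensor) and weight.dim() == 2

    print(looks_like_linear(MiniGroupedLinear(4, 8, 16)))  # True
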