From f17131f82d2e4cd41fcda0c07129b7c67f605eb4 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 24 Sep 2025 00:27:11 +0000 Subject: [PATCH 01/16] sync amax in context parallel and awq act scale Signed-off-by: Jennifer Chen --- examples/nemo_run/qat/README.md | 2 +- modelopt/torch/quantization/model_calib.py | 19 +++- .../torch/quantization/plugins/megatron.py | 9 +- modelopt/torch/utils/distributed.py | 4 +- .../torch_dist/plugins/megatron_common.py | 7 +- .../torch_quantization/quantize_common.py | 57 +++++++++++- tests/gpu/torch/conftest.py | 6 ++ .../quantization/plugins/test_megatron.py | 89 ++++++++++++++++++- 8 files changed, 178 insertions(+), 15 deletions(-) diff --git a/examples/nemo_run/qat/README.md b/examples/nemo_run/qat/README.md index 79715953c..cd74c96e2 100644 --- a/examples/nemo_run/qat/README.md +++ b/examples/nemo_run/qat/README.md @@ -92,7 +92,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p To perform QAD training, run: ```bash -python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment +python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4 ``` ## Supported models diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5276b1334..6d1b7c86f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp(quantizer, parallel_state): + def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel and context parallel groups.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp(_q, parallel_state) + sync_quantizer_amax_across_dp_cp(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp(child, module.parallel_state) - + sync_quantizer_amax_across_dp_cp(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -624,6 +625,14 @@ def forward(self, input, *args, **kwargs): # This will also perform distributed amax sync for input_quantizers max_calibrate(model, lambda model: None) + def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group): + # Sync across Data Parallel (DP) + if data_parallel_group.is_initialized(): + dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group) + # Sync across Context Parallel (CP) + if context_parallel_group.is_initialized(): + dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group) + for name, module in model.named_modules(): if ( is_quantized_linear(module) @@ -631,6 +640,8 @@ def forward(self, input, *args, **kwargs): and module.awq_lite.num_cache_steps > 0 ): module.awq_lite.act_scale = 
module.awq_lite.act_scale / module.awq_lite.num_cache_steps + sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group) + # Hack: MoEs forward all tokens through all experts if _if_calib is True module._if_calib = True diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index ab64a795a..72442526f 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,7 @@ import megatron.core.transformer.mlp as megatron_mlp import torch from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.parallel_state import get_data_parallel_group from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -217,9 +218,15 @@ class _MegatronParallelLinear(_ParallelLinear): ] def _setup(self): + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except: + data_parallel_group = get_data_parallel_group() self.parallel_state = ParallelState( - getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(), + data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_context_parallel_group(), ) super()._setup() diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 7aebc992d..c1a313e48 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,13 +241,15 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + context_parallel_group: torch.distributed.ProcessGroup | int | None = -1, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) + self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: - return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}" + return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}, context_parallel_group: {self.context_parallel_group}" def get_group(ranks: list[int]): diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 9c1dd1bf7..ab9d467ea 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -83,9 +83,10 @@ class MegatronModel(MegatronModule): - def __init__(self, tp_size: int = 1, use_te_norm: bool = False): + def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False): config = TransformerConfig( tensor_model_parallel_size=tp_size, + context_parallel_size=cp_size, pipeline_model_parallel_size=1, normalization="LayerNorm", # Unused parameters below are set to avoid ZeroDivisionError in __post_init__ @@ -383,13 +384,13 @@ def run_mcore_inference_with_dummy_input( def initialize_for_megatron( - tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234 + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, 
context_parallel_size=1, seed=1234 ): """Initialize Megatron model parallelism. NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`. """ - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size) model_parallel_cuda_manual_seed(seed) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 505eac2b6..65cc39a5e 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -116,8 +116,8 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N mto.restore_from_modelopt_state(model_ref, state_dict) -def tensor_parallel_test_helper(model, config, tp_group, dp_group): - # The input to fist layer, the column parallel should be the same across all tp ranks +def tensor_parallel_test_helper(model, config, tp_group): + # The input to first layer, the column parallel should be the same across all tp ranks calib_data = model.get_dummy_input().cuda() dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -149,6 +149,59 @@ def forward_loop(model): dist.destroy_process_group() +def data_parallel_test_helper(model, config, dp_group): + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + +def context_parallel_test_helper(model, config, cp_group): + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + +def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): + calib_data = model.get_dummy_input().cuda() + # data should be same across each TP rank + dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) def 
auto_quantize_helper(model): model, search_state = mtq.auto_quantize( diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index 208fb2287..05de03012 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -33,6 +33,12 @@ def need_2_gpus(): if torch.cuda.device_count() < 2: pytest.skip("Need at least 2 GPUs to run this test") +@pytest.fixture +def need_8_gpus(): + if torch.cuda.device_count() < 8: + pytest.skip("Need at least 8 GPUs to run this test") + + @pytest.fixture(scope="module") def set_torch_dtype(request): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index c3630e028..05ba40f7c 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -32,6 +32,9 @@ from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, tensor_parallel_test_helper, + data_parallel_test_helper, + context_parallel_test_helper, + data_tensor_context_parallel_test_helper, ) from packaging.version import Version @@ -41,6 +44,7 @@ from megatron.core.parallel_state import ( destroy_model_parallel, get_data_parallel_group, + get_context_parallel_group, get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -91,13 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1): # Clean up since this is not a spawned process destroy_model_parallel() - +# 1. Tensor Parallel Test def _test_tensor_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED) - model = MegatronModel(size).cuda() + model = MegatronModel(tp_size=size).cuda() tensor_parallel_test_helper( - model, config, get_tensor_model_parallel_group(), get_data_parallel_group() + model, config, get_tensor_model_parallel_group() ) @@ -118,6 +122,85 @@ def test_tensor_parallel(need_2_gpus, config): size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl" ) +# 2. Data Parallel Test +def _test_data_parallel_helper(config, rank, size): + # TODO does this model automatically get copied to both DP ranks? + initialize_for_megatron(seed=SEED) + model = MegatronModel().cuda() + + data_parallel_test_helper( + model, config, get_data_parallel_group() + ) + + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_parallel(need_2_gpus, config): + spawn_multiprocess_job( + size=2, job=partial(_test_data_parallel_helper, config), backend="nccl" + ) + +# 3. Context Parallel Test +def _test_context_parallel_helper(config, rank, size): + initialize_for_megatron(context_parallel_size=size, seed=SEED) + model = MegatronModel(cp_size=size).cuda() + + context_parallel_test_helper( + model, config, get_context_parallel_group() + ) + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_context_parallel(need_2_gpus, config): + spawn_multiprocess_job( + size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" + ) + +# 4. 
DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) +def _test_data_tensor_context_parallel_helper(config, rank, size): + initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) + model = MegatronModel(tp_size=2, cp_size=2).cuda() + + data_tensor_context_parallel_test_helper( + model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group() + ) + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_tensor_context_parallel(need_8_gpus, config): + spawn_multiprocess_job( + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + ) def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" From 42519ccd809c50a73b85bd59d87b5859a885c3fe Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 25 Sep 2025 18:54:28 +0000 Subject: [PATCH 02/16] lint Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 16 ++++++-- .../torch/quantization/plugins/megatron.py | 4 +- modelopt/torch/utils/distributed.py | 6 ++- .../torch_dist/plugins/megatron_common.py | 6 ++- .../torch_quantization/quantize_common.py | 4 ++ tests/gpu/torch/conftest.py | 2 +- .../quantization/plugins/test_megatron.py | 37 ++++++++++--------- 7 files changed, 49 insertions(+), 26 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 6d1b7c86f..03f4936ae 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -628,10 +628,14 @@ def forward(self, input, *args, **kwargs): def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group): # Sync across Data Parallel (DP) if data_parallel_group.is_initialized(): - dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group) + dist.all_reduce( + module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group + ) # Sync across Context Parallel (CP) if context_parallel_group.is_initialized(): - dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group) + dist.all_reduce( + module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group + ) for name, module in model.named_modules(): if ( @@ -640,8 +644,12 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr and module.awq_lite.num_cache_steps > 0 ): module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps - sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group) - + sync_act_scale_across_dp_cp( + module, + module.parallel_state.data_parallel_group, + module.parallel_state.context_parallel_group, + ) + # Hack: MoEs forward all tokens through all experts if _if_calib is True module._if_calib = True diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 72442526f..96e8c61e1 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,8 +22,8 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch -from 
megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_data_parallel_group +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -221,7 +221,7 @@ def _setup(self): data_parallel_group = None try: data_parallel_group = get_data_parallel_group(with_context_parallel=True) - except: + except AssertionError: data_parallel_group = get_data_parallel_group() self.parallel_state = ParallelState( data_parallel_group, diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index c1a313e48..18c2f40c2 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -249,7 +249,11 @@ def __init__( self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: - return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}, context_parallel_group: {self.context_parallel_group}" + return ( + f"data_parallel_group: {self.data_parallel_group}, " + f"tensor_parallel_group: {self.tensor_parallel_group}, " + f"context_parallel_group: {self.context_parallel_group}" + ) def get_group(ranks: list[int]): diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index ab9d467ea..7c497a895 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -390,7 +390,11 @@ def initialize_for_megatron( NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`. 
""" - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size) + initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + ) model_parallel_cuda_manual_seed(seed) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 65cc39a5e..e8c40af75 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -149,6 +149,7 @@ def forward_loop(model): dist.destroy_process_group() + def data_parallel_test_helper(model, config, dp_group): calib_data = model.get_dummy_input().cuda() @@ -165,6 +166,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def context_parallel_test_helper(model, config, cp_group): calib_data = model.get_dummy_input().cuda() @@ -181,6 +183,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank @@ -203,6 +206,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def auto_quantize_helper(model): model, search_state = mtq.auto_quantize( model, diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index 05de03012..f32065bce 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -33,13 +33,13 @@ def need_2_gpus(): if torch.cuda.device_count() < 2: pytest.skip("Need at least 2 GPUs to run this test") + @pytest.fixture def need_8_gpus(): if torch.cuda.device_count() < 8: pytest.skip("Need at least 8 GPUs to run this test") - @pytest.fixture(scope="module") def set_torch_dtype(request): orig_dtype = torch.get_default_dtype() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 05ba40f7c..486756e2a 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -31,10 +31,10 @@ from _test_utils.torch_quantization.quant_utils import get_model_size from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, - tensor_parallel_test_helper, - data_parallel_test_helper, context_parallel_test_helper, + data_parallel_test_helper, data_tensor_context_parallel_test_helper, + tensor_parallel_test_helper, ) from packaging.version import Version @@ -43,8 +43,8 @@ import megatron.core from megatron.core.parallel_state import ( destroy_model_parallel, - get_data_parallel_group, get_context_parallel_group, + get_data_parallel_group, get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -95,14 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1): # Clean up since this is not a spawned process destroy_model_parallel() + # 1. 
Tensor Parallel Test def _test_tensor_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED) model = MegatronModel(tp_size=size).cuda() - tensor_parallel_test_helper( - model, config, get_tensor_model_parallel_group() - ) + tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group()) @pytest.mark.parametrize( @@ -122,15 +121,14 @@ def test_tensor_parallel(need_2_gpus, config): size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl" ) + # 2. Data Parallel Test def _test_data_parallel_helper(config, rank, size): # TODO does this model automatically get copied to both DP ranks? initialize_for_megatron(seed=SEED) model = MegatronModel().cuda() - data_parallel_test_helper( - model, config, get_data_parallel_group() - ) + data_parallel_test_helper(model, config, get_data_parallel_group()) @pytest.mark.parametrize( @@ -146,18 +144,16 @@ def _test_data_parallel_helper(config, rank, size): ], ) def test_data_parallel(need_2_gpus, config): - spawn_multiprocess_job( - size=2, job=partial(_test_data_parallel_helper, config), backend="nccl" - ) + spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl") + # 3. Context Parallel Test def _test_context_parallel_helper(config, rank, size): initialize_for_megatron(context_parallel_size=size, seed=SEED) model = MegatronModel(cp_size=size).cuda() - context_parallel_test_helper( - model, config, get_context_parallel_group() - ) + context_parallel_test_helper(model, config, get_context_parallel_group()) + @pytest.mark.parametrize( "config", @@ -176,15 +172,21 @@ def test_context_parallel(need_2_gpus, config): size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" ) + # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) def _test_data_tensor_context_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) model = MegatronModel(tp_size=2, cp_size=2).cuda() data_tensor_context_parallel_test_helper( - model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group() + model, + config, + get_data_parallel_group(), + get_tensor_model_parallel_group(), + get_context_parallel_group(), ) + @pytest.mark.parametrize( "config", [ @@ -199,9 +201,10 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) + def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" From 264adbb1b13221bfaec2c59b161e6061bafcafd9 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 25 Sep 2025 21:46:30 +0000 Subject: [PATCH 03/16] test weight quantizer too Signed-off-by: Jennifer Chen --- examples/speculative_decoding/README.md | 195 ++++++++++-------- examples/speculative_decoding/launch_train.sh | 2 +- modelopt/torch/distill/plugins/megatron.py | 29 ++- modelopt/torch/quantization/export_onnx.py | 115 ++++++----- .../nn/modules/tensor_quantizer.py | 6 +- .../torch/quantization/plugins/diffusers.py | 23 ++- .../torch_quantization/quantize_common.py | 62 ++++-- .../speculative_decoding/test_eagle.py | 3 +- .../speculative_decoding/test_medusa.py | 3 +- 9 files changed, 266 insertions(+), 172 deletions(-) diff --git 
a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 503cf303f..aeea25adb 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -15,8 +15,11 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM, | **Section** | **Description** | **Jump To** | | :------------: | :------------: | :------------: | | Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] | -| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] | -| Complete Workflow | Full example with configurable training pipeline | \[[Link](#complete-workflow)\] | +| Simplified Workflow | Train, evaluate, and export EAGLE model with one-line command | \[[Link](#getting-started-simplified-workflow)\] | +| Online Training | Train draft model alongside base model in GPU memory | \[[Link](#training-draft-model-with-online-base-model)\] | +| Offline Training | Train draft model using pre-computed hidden states | \[[Link](#training-draft-model-with-offline-base-model)\] | +| After Training | Evaluation, export and deployment | \[[Link](#model-validation)\] | +| Advanced Usage | Data synthesis, vocab compression, and configuration | \[[Link](#advanced-usage)\] | | Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] | | Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] | | Resources | Extra links to relevant resources | \[[Link](#resources)\] | @@ -61,13 +64,113 @@ This one-line command runs a minimal example workflow of training and exporting - Evaluates the acceptance rate on [MT-Bench](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts) - Exports a checkpoint ready for deployment -## Complete Workflow +## Training Draft Model with Online Base Model -This section presents a more comprehensive example for customizing speculative decoding training with Modelopt, including optional steps to enhance training quality and efficiency. +For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command: -### (Optional) Data Synthesis +```bash +./launch_train.sh --model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --data Daring-Anteater/train.jsonl \ + --num_gpu $NUM_GPU \ + --num_epochs $NUM_EPOCH \ + --eagle_config eagle_config.json +``` + +This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details. +The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. + +## Training Draft Model with Offline Base Model + +For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of storage depending on dataset size. 
+ +First, dump the base model's hidden states with the following command: + +```bash +python collect_hidden_states/compute_hidden_states_hf.py \ + --model $BASE_MODEL \ + --input-file Daring-Anteater/train.jsonl \ + --output-dir $HIDDEN_STATES_DIR +``` + +See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation. + +Then, train draft model with `--offline-data` argument: + +```bash +./launch_train.sh --model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --data $DATA \ + --num_gpu $NUM_GPU \ + --num_epochs $NUM_EPOCH \ + --eagle_config eagle_config.json \ + --offline-data $HIDDEN_STATES_DIR +``` + +## Model Validation + +After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by: + +```bash +python ar_validate.py --model_path $OUTPUT_DIR +``` + +Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below. + +## Export + +```bash +python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH +``` + +This exports the model from a ModelOpt checkpoint to a deployment-compatible format. + +## Deployment + +The exported checkpoint can be deployed on TRT-LLM or SGLang. + +### TRT-LLM + +To serve the checkpoint with TRT-LLM, run trtllm-serve with: + +```bash +trtllm-serve --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml +``` + +, with `extra-llm-api-config.yml` being + +```yaml +enable_attention_dp: false +disable_overlap_scheduler: true +enable_autotuner: false + +cuda_graph_config: + max_batch_size: 1 + +speculative_config: + decoding_type: Eagle + max_draft_len: 3 + speculative_model_dir: + +kv_cache_config: + enable_block_reuse: false +``` + +Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage. + +### SGLang -To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data, ensuring that the draft model’s output distribution closely aligns with that of the base model. +Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage. + +### Deploying Quantized model + +See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md). + +## Advanced Usage + +### Data Synthesis + +To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model. To prepare such data, we launch an inference server with the base model: @@ -78,7 +181,7 @@ vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000 Note: Add `--quantization=modelopt` flag for quantized models. -Then, we generate conversations with base model and prompts from Daring-Anteater: +Then, we generate conversations with the base model using prompts from Daring-Anteater: ```bash python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl @@ -88,7 +191,7 @@ To add a system prompt, use the `--system_prompt ` argument. 
For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support. -### (Optional) Draft Vocabulary Compression +### Draft Vocabulary Compression We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set: @@ -98,7 +201,7 @@ python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. -### (Optional) Configuring Draft Model +### Configuring Draft Model For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`: @@ -108,7 +211,7 @@ For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](htt } ``` -### Training Draft Model with Modelopt +### Interact with `modelopt.torch.speculative` `main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps: First, load the base model and tokenizer from Hugging Face: @@ -162,78 +265,6 @@ trainer.save_state() trainer.save_model("") ``` -We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch training with Hugging Face Accelerate in `launch_train.sh`, which can be run by: - -```bash -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - --data $DATA \ - --num_gpu $NUM_GPU \ - --num_epochs 10 \ - --eagle_config eagle_config.json #This is where we optionally overwrite default eagle configs -``` - -The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. - -### Model Validation - -After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by: - -```bash -python ar_validate.py --model_path $OUTPUT_DIR -``` - -Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below. - -### Export - -```bash -python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH -``` - -This exports the model from a ModelOpt checkpoint to a deployment‑compatible format. - -### Deployment - -The exported checkpoint can be deployed on TRT-LLM or SGLang. 
- -#### TRT-LLM - -To serve the checkpoint with trtllm, run trtllm-serve with: - -```bash -trtllm-serve --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml -``` - -, with `extra-llm-api-config.yml` being - -```yaml -enable_attention_dp: false -disable_overlap_scheduler: true -enable_autotuner: false - -cuda_graph_config: - max_batch_size: 1 - -speculative_config: - decoding_type: Eagle - max_draft_len: 3 - speculative_model_dir: - -kv_cache_config: - enable_block_reuse: false -``` - -Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage. - -#### SGLang - -Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage. - -#### Deploying Quantized model - -See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md). - ## Support Matrix | Model | Medusa | EAGLE1/2 | EAGLE3 | diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 3ecd4238a..2d0a4abe7 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -129,7 +129,7 @@ if [[ "$OFFLINE_DATA_PATH" != "" ]]; then echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory." exit 1 else - OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH" + OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1" fi else OFFLINE_TRAINING_ARGS="" diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 7078cca36..c1fa45f6b 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -59,7 +59,7 @@ class DistillationConfig: logit_kl_temperature: Temperature for the logit KL-divergence loss. 
""" - intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list) + intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list) logit_layers: tuple[str, str] = ("output_layer", "output_layer") skip_lm_loss: bool = True kd_loss_scale: float = 1.0 @@ -69,12 +69,28 @@ class DistillationConfig: def __post_init__(self): assert len(self.logit_layers) == 2, f"{self.logit_layers=}" - assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), ( + assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), ( f"{self.intermediate_layer_pairs=}" ) assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}" assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}" + @staticmethod + def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]: + """Parse an intermediate entry into a student layer, teacher layer, and loss function.""" + if len(entry) == 3: + student_layer, teacher_layer, loss_fn_name = entry + if loss_fn_name == "cosine": + loss_fn = HiddenStateCosineLoss + elif loss_fn_name == "mse": + loss_fn = MSELoss + else: + raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}") + else: + student_layer, teacher_layer = entry + loss_fn = HiddenStateCosineLoss # default to cosine loss + return student_layer, teacher_layer, loss_fn + def load_distillation_config( config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig" @@ -105,7 +121,8 @@ def load_distillation_config( # NOTE: Projection layer shared among intermediate layer pairs. projection_layer = ProjectionLayer(student_cfg, teacher_cfg) - for student_layer, teacher_layer in cfg.intermediate_layer_pairs: + for entry in cfg.intermediate_layer_pairs: + student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry) if parallel_state.get_tensor_and_context_parallel_rank() == 0: logger.info( "Distillation: Adding intermediate loss between" @@ -114,7 +131,7 @@ def load_distillation_config( ) student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg) teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg) - criterion[(student_layer, teacher_layer)] = HiddenStateCosineLoss( + criterion[(student_layer, teacher_layer)] = loss_fn( student_cfg, projection_layer=projection_layer ) @@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: predictions, targets = self.pre_forward(predictions, targets) loss = F.mse_loss(predictions, targets, reduction="none") - loss = loss.sum(dim=-1) + loss = loss.mean(dim=-1) - return self.post_forward(loss) + return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel) class HiddenStateCosineLoss(BaseLoss): diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index fe9bd927b..e8a38b162 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -103,13 +103,18 @@ """Utility to export a quantized torch model to quantized ONNX.""" import contextlib +from typing import TYPE_CHECKING import onnx import torch from torch.onnx import symbolic_helper from torch.onnx import symbolic_helper as sym_help -from torch.onnx._internal import jit_utils -from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask + +if TYPE_CHECKING: + if hasattr(torch.onnx._internal, "jit_utils"): + from torch.onnx._internal.jit_utils import GraphContext + else: # torch >= 2.9 + from 
torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext onnx_dtype_map = { "BFloat16": onnx.TensorProto.BFLOAT16, @@ -125,7 +130,7 @@ def export_int8( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, amax: torch.Tensor, num_bits: int, @@ -184,7 +189,7 @@ def export_int8( def export_int4( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, amax: torch.Tensor, num_bits: int, @@ -208,7 +213,7 @@ def export_int4( def _fp8_quantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale_inv: float, trt_high_precision_dtype: str, @@ -236,7 +241,7 @@ def _fp8_quantize( def _fp8_dequantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale_inv: float, trt_high_precision_dtype: str, @@ -263,7 +268,7 @@ def _fp8_dequantize( def export_fp8( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, amax: float, trt_high_precision_dtype: str | None, @@ -279,21 +284,29 @@ def export_fp8( def scaled_dot_product_attention( - g: jit_utils.GraphContext, - query: torch._C.Value, - key: torch._C.Value, - value: torch._C.Value, - attn_mask: torch._C.Value | None = None, + g: "GraphContext", + query: "torch._C.Value", + key: "torch._C.Value", + value: "torch._C.Value", + attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, enable_gqa: bool = False, ): """Perform scaled dot product attention.""" if hasattr(torch.onnx, "_type_utils"): - from torch.onnx import _type_utils - else: - from torch.onnx._internal.torchscript_exporter import _type_utils + from torch.onnx._type_utils import JitScalarType + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter import JitScalarType + + if hasattr(torch.onnx, "symbolic_opset14"): + from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import ( + _attention_scale, + _causal_attention_mask, + ) assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( "is_causal and attn_mask cannot be set at the same time" @@ -327,22 +340,20 @@ def scaled_dot_product_attention( if symbolic_helper._is_none(attn_mask): mul_qk_add = mul_qk - elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL: + elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL: # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) const_zero = g.op("Constant", value_t=torch.tensor([0.0])) const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) mul_qk_add = g.op("Add", mul_qk, attn_mask) - elif _type_utils.JitScalarType.from_value(attn_mask) in ( - _type_utils.JitScalarType.FLOAT, - _type_utils.JitScalarType.HALF, - _type_utils.JitScalarType.BFLOAT16, + elif JitScalarType.from_value(attn_mask) in ( + JitScalarType.FLOAT, + JitScalarType.HALF, + JitScalarType.BFLOAT16, ): mul_qk_add = g.op("Add", mul_qk, attn_mask) else: - raise ValueError( - f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" - ) + raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}") attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) @@ -357,14 
+368,14 @@ def scaled_dot_product_attention( def export_fp8_mha( - g: torch.onnx._internal.jit_utils.GraphContext, - query: torch._C.Value, - key: torch._C.Value, - value: torch._C.Value, - attn_mask: torch._C.Value | None = None, + g: "GraphContext", + query: "torch._C.Value", + key: "torch._C.Value", + value: "torch._C.Value", + attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, q_quantized_scale: float = 1.0, k_quantized_scale: float = 1.0, v_quantized_scale: float = 1.0, @@ -396,12 +407,18 @@ def export_fp8_mha( | Cast """ - from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask - if hasattr(torch.onnx, "_type_utils"): - from torch.onnx import _type_utils - else: - from torch.onnx._internal.torchscript_exporter import _type_utils + from torch.onnx._type_utils import JitScalarType + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter import JitScalarType + + if hasattr(torch.onnx, "symbolic_opset14"): + from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import ( + _attention_scale, + _causal_attention_mask, + ) # Pass all arguments, including x, to the custom ONNX operator assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), ( @@ -452,22 +469,20 @@ def export_fp8_mha( if sym_help._is_none(attn_mask): mul_qk_add = mul_qk - elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL: + elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL: # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) const_zero = g.op("Constant", value_t=torch.tensor([0.0])) const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) mul_qk_add = g.op("Add", mul_qk, attn_mask) - elif _type_utils.JitScalarType.from_value(attn_mask) in ( - _type_utils.JitScalarType.FLOAT, - _type_utils.JitScalarType.HALF, - _type_utils.JitScalarType.BFLOAT16, + elif JitScalarType.from_value(attn_mask) in ( + JitScalarType.FLOAT, + JitScalarType.HALF, + JitScalarType.BFLOAT16, ): mul_qk_add = g.op("Add", mul_qk, attn_mask) else: - raise ValueError( - f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" - ) + raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}") attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) @@ -495,7 +510,7 @@ def export_fp8_mha( def _fp4_dynamic_quantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale: float, trt_high_precision_dtype: str | None, @@ -531,7 +546,7 @@ def _fp4_dynamic_quantize( def _fp4_dequantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale: float | torch.Value, trt_high_precision_dtype: str | None, @@ -546,7 +561,7 @@ def _fp4_dequantize( def _fp4_dequantize_2( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, dyn_scale: torch.Value, block_size: int, @@ -557,7 +572,7 @@ def _fp4_dequantize_2( def _mxfp8_dynamic_quantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, block_size: int, axis: int = -1, @@ -575,7 +590,7 @@ def _mxfp8_dynamic_quantize( def _mxfp8_dequantize( - g: 
torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale: torch.Value, block_size: int, @@ -593,7 +608,7 @@ def _mxfp8_dequantize( def export_mxfp8( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Tensor, onnx_quantizer_type: str, block_size: int, @@ -611,7 +626,7 @@ def export_mxfp8( def export_fp4( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, block_size: int, amax: torch.Value, diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 0635b7c9b..6e431dce9 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -548,7 +548,7 @@ def _get_amax(self, inputs): def _validate_amax(self, amax): # Dynamic control flow is not supported by torch dynamo - if not is_torch_export_mode() and not torch._dynamo.is_compiling(): + if not is_torch_export_mode() and not torch.compiler.is_compiling(): assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), ( f"Got invalid amax: {amax}" ) @@ -880,7 +880,7 @@ def forward(self, inputs): """ if hasattr(torch.onnx, "_globals"): from torch.onnx._globals import GLOBALS - else: + else: # torch >= 2.9 from torch.onnx._internal.torchscript_exporter._globals import GLOBALS if DTensor is not None and isinstance(inputs, DTensor): @@ -914,7 +914,7 @@ def forward(self, inputs): if ( not is_torch_export_mode() - and not torch._dynamo.is_compiling() + and not torch.compiler.is_compiling() and GLOBALS.in_onnx_export ): # GLOBALS could break TorchDynamo for some Pytorch versions (i.e., 2.3.0) diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index 5f1ab5db1..7c018e1bb 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -15,10 +15,10 @@ """Support quantization of diffusers layers.""" -import functools from collections.abc import Callable, Iterator from functools import partial from types import ModuleType +from typing import TYPE_CHECKING import onnx import torch @@ -27,7 +27,12 @@ from torch.autograd import Function from torch.nn import functional as F from torch.onnx import symbolic_helper -from torch.onnx._internal import jit_utils, registration + +if TYPE_CHECKING: + if hasattr(torch.onnx._internal, "jit_utils"): + from torch.onnx._internal.jit_utils import GraphContext + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext from ..export_onnx import export_fp8_mha from ..nn import ( @@ -40,8 +45,6 @@ ) from .custom import _QuantFunctionalMixin -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) - onnx_dtype_map = { "BFloat16": onnx.TensorProto.BFLOAT16, "Float": onnx.TensorProto.FLOAT, @@ -205,14 +208,14 @@ def forward( @staticmethod @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b") def symbolic( - g: jit_utils.GraphContext, - query: torch._C.Value, - key: torch._C.Value, - value: torch._C.Value, - attn_mask: torch._C.Value | None = None, + g: "GraphContext", + query: "torch._C.Value", + key: "torch._C.Value", + value: "torch._C.Value", + attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, q_quantized_scale: float = 1.0, 
k_quantized_scale: float = 1.0, v_quantized_scale: float = 1.0, diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index e8c40af75..ad01233cc 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -158,13 +158,22 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + + # Weight quantizer amax + fc1_amax = model.fc1.weight_quantizer.amax.clone() dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - - fc2_amax = model.fc2.input_quantizer.amax.clone() + assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) + fc2_amax = model.fc2.weight_quantizer.amax.clone() dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) def context_parallel_test_helper(model, config, cp_group): @@ -175,13 +184,22 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + # Weight quantizer amax + fc1_weight_amax = model.fc1.weight_quantizer.amax.clone() + dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax) + fc2_weight_amax = model.fc2.weight_quantizer.amax.clone() + dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -194,17 +212,29 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, 
group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + + fc1_amax = model.fc1.weight_quantizer.amax.clone() dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - - fc2_amax = model.fc2.input_quantizer.amax.clone() + assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) + fc2_amax = model.fc2.weight_quantizer.amax.clone() dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) def auto_quantize_helper(model): diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index c65089114..c81bc9363 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -36,12 +36,11 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_ run_example_command( [ - "./launch.sh", + "./launch_train.sh", "--model", tiny_llama_path, "--data", tiny_daring_anteater_path, "--num_epochs", "1", "--lr", "1e-5", - "--do_eval", "False", "--num_gpu", str(num_gpus), "--mode", "eagle3", "--eagle_config", str(config_file), diff --git a/tests/examples/speculative_decoding/test_medusa.py b/tests/examples/speculative_decoding/test_medusa.py index c11a2e707..c8a7616b9 100644 --- a/tests/examples/speculative_decoding/test_medusa.py +++ b/tests/examples/speculative_decoding/test_medusa.py @@ -38,12 +38,11 @@ def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_pa # Test Medusa run_example_command( [ - "./launch.sh", + "./launch_train.sh", "--model", tiny_llama_path, "--data", tiny_daring_anteater_path, "--num_epochs", "1", "--lr", "1e-5", - "--do_eval", "False", "--num_gpu", str(num_gpus), "--mode", "medusa", "--output_dir", medusa_path, From 1f7d17ecca4b27e39f53a8f8e2aede3079b24538 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 26 Sep 2025 01:07:13 +0000 Subject: [PATCH 04/16] fix test Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 98 ++++++++----------- .../quantization/plugins/test_megatron.py | 7 +- 2 files changed, 42 insertions(+), 63 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index ad01233cc..abcd39caf 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -23,6 +23,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm +from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer from modelopt.torch.quantization.utils import is_quantized_linear from modelopt.torch.utils import torch_to @@ 
-150,7 +151,7 @@ def forward_loop(model): dist.destroy_process_group() -def data_parallel_test_helper(model, config, dp_group): +def dp_cp_parallel_test_helper(model, config, group): calib_data = model.get_dummy_input().cuda() def forward_loop(model): @@ -158,48 +159,27 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - # Input quantizer amax - if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - - # Weight quantizer amax - fc1_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) - fc2_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) - - -def context_parallel_test_helper(model, config, cp_group): - calib_data = model.get_dummy_input().cuda() - - def forward_loop(model): - model(calib_data) - - model = mtq.quantize(model, config, forward_loop) + def reduce_amax(quantizer): + amax = quantizer.amax.clone() + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group) + assert torch.allclose(amax, quantizer.amax) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + reduce_amax(model.fc1.input_quantizer) + reduce_amax(model.fc2.input_quantizer) # Weight quantizer amax - fc1_weight_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax) - fc2_weight_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax) + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc1.weight_quantizer) + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc2.weight_quantizer) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -212,29 +192,29 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) + def reduce_amax(quantizer): + amax = quantizer.amax.clone() + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(amax, quantizer.amax) + # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = 
model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - - fc1_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) - fc2_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) + reduce_amax(model.fc1.input_quantizer) + reduce_amax(model.fc2.input_quantizer) + + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc1.weight_quantizer) + + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc2.weight_quantizer) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 486756e2a..1ea51b800 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -31,9 +31,8 @@ from _test_utils.torch_quantization.quant_utils import get_model_size from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, - context_parallel_test_helper, - data_parallel_test_helper, data_tensor_context_parallel_test_helper, + dp_cp_parallel_test_helper, tensor_parallel_test_helper, ) from packaging.version import Version @@ -128,7 +127,7 @@ def _test_data_parallel_helper(config, rank, size): initialize_for_megatron(seed=SEED) model = MegatronModel().cuda() - data_parallel_test_helper(model, config, get_data_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) @pytest.mark.parametrize( @@ -152,7 +151,7 @@ def _test_context_parallel_helper(config, rank, size): initialize_for_megatron(context_parallel_size=size, seed=SEED) model = MegatronModel(cp_size=size).cuda() - context_parallel_test_helper(model, config, get_context_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) @pytest.mark.parametrize( From d02365c5f31d3614db9957b26d0d59677cca7bbe Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 29 Sep 2025 21:38:48 +0000 Subject: [PATCH 05/16] awq test Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 18 +++++----- .../torch_dist/plugins/megatron_common.py | 5 ++- .../torch_quantization/quantize_common.py | 8 +++-- .../quantization/plugins/test_megatron.py | 7 ++-- .../torch/quantization/test_model_calib.py | 33 +++++++++++++++++++ 5 files changed, 
56 insertions(+), 15 deletions(-) create mode 100644 tests/gpu/torch/quantization/test_model_calib.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index b683f265a..3e974c1f0 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -617,20 +617,20 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr and hasattr(module, "awq_lite") and module.awq_lite.num_cache_steps > 0 ): + # Hack: MoEs forward all tokens through all experts if _if_calib is True + module._if_calib = True module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps - + if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( torch.isnan(module.awq_lite.weight_scale) ): module.awq_lite.is_enabled = False - - sync_act_scale_across_dp_cp( - module, - module.parallel_state.data_parallel_group, - module.parallel_state.context_parallel_group, - ) - # Hack: MoEs forward all tokens through all experts if _if_calib is True - module._if_calib = True + else: + sync_act_scale_across_dp_cp( + module, + module.parallel_state.data_parallel_group, + module.parallel_state.context_parallel_group, + ) AWQLiteHelper.cache_mode = False print_rank_0("awq_lite: Searching parameters...") diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 7c497a895..6324d3390 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -384,7 +384,10 @@ def run_mcore_inference_with_dummy_input( def initialize_for_megatron( - tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, seed=1234 + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + seed=1234, + context_parallel_size=1, ): """Initialize Megatron model parallelism. diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index abcd39caf..c4f1dc2b5 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -194,9 +194,13 @@ def forward_loop(model): def reduce_amax(quantizer): amax = quantizer.amax.clone() - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + print("amax before reduce", amax) + print("quantizer.amax before reduce", quantizer.amax) dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) + print("amax after reduce", amax) + print("quantizer.amax after reduce", quantizer.amax) assert torch.allclose(amax, quantizer.amax) # Input quantizer amax diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 1ea51b800..03a6c4ba8 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -123,8 +123,7 @@ def test_tensor_parallel(need_2_gpus, config): # 2. Data Parallel Test def _test_data_parallel_helper(config, rank, size): - # TODO does this model automatically get copied to both DP ranks? 
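Editor's note (illustrative sketch, not part of the patch): the seed change just below offsets the RNG seed by rank so that each data-parallel rank calibrates on different data. With a shared seed, every DP rank would draw identical calibration tensors and compute identical amax values, so the test could not tell whether the DP amax sync actually ran. A minimal sketch of the idea, assuming only `torch` and reusing the test's `SEED` constant (value assumed here):

```python
import torch

SEED = 1234  # assumed to match the test's SEED constant


def make_calib_data(rank: int) -> torch.Tensor:
    """Draw per-rank calibration data so pre-sync amax differs across DP ranks."""
    gen = torch.Generator().manual_seed(SEED + rank)  # distinct RNG stream per rank
    return torch.randn(1, 4, 32, generator=gen)  # same shape as the dummy input


# The per-rank max statistics generally differ before any synchronization:
print(make_calib_data(0).abs().amax(), make_calib_data(1).abs().amax())
```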
- initialize_for_megatron(seed=SEED) + initialize_for_megatron(seed=SEED + rank) # modify seed so data is different across ranks model = MegatronModel().cuda() dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) @@ -148,7 +147,9 @@ def test_data_parallel(need_2_gpus, config): # 3. Context Parallel Test def _test_context_parallel_helper(config, rank, size): - initialize_for_megatron(context_parallel_size=size, seed=SEED) + initialize_for_megatron( + context_parallel_size=size, seed=SEED + rank + ) # modify seed so data is different across ranks model = MegatronModel(cp_size=size).cuda() dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) diff --git a/tests/gpu/torch/quantization/test_model_calib.py b/tests/gpu/torch/quantization/test_model_calib.py new file mode 100644 index 000000000..2179b8532 --- /dev/null +++ b/tests/gpu/torch/quantization/test_model_calib.py @@ -0,0 +1,33 @@ +import torch +import torch.distributed as dist +from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job +from _test_utils.torch_dist.plugins.megatron_common import MegatronModel, initialize_for_megatron +from megatron.core.parallel_state import get_data_parallel_group + +from modelopt.torch.quantization.model_calib import awq_lite + + +def _test_awq_lite_act_scale_sync_helper(rank, size): + initialize_for_megatron(seed=1234 + rank) + model = MegatronModel().cuda() + + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = awq_lite(model, forward_loop) + # Sanity check + forward_loop(model) + + act_scale = model.fc1.weight_quantizer.awq_lite.act_scale.clone() + dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) + assert torch.allclose(act_scale, model.fc1.weight_quantizer.awq_lite.act_scale) + + act_scale = model.fc2.weight_quantizer.awq_lite.act_scale.clone() + dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) + assert torch.allclose(act_scale, model.fc2.weight_quantizer.awq_lite.act_scale) + + +def test_awq_lite_act_scale_sync(need_2_gpus): + spawn_multiprocess_job(size=2, job=_test_awq_lite_act_scale_sync_helper, backend="nccl") From 5a572da4bcfb77b5b6b9bc1e8e8b056114222b50 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 29 Sep 2025 23:17:00 +0000 Subject: [PATCH 06/16] move awq test inside megatron tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 + .../torch_quantization/quantize_common.py | 121 ++++++++++++------ .../quantization/plugins/test_megatron.py | 2 +- .../torch/quantization/test_model_calib.py | 33 ----- 4 files changed, 86 insertions(+), 71 deletions(-) delete mode 100644 tests/gpu/torch/quantization/test_model_calib.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3e974c1f0..7a1d9791a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -581,6 +581,7 @@ def forward(self, input, *args, **kwargs): return out_actual for name, module in model.named_modules(): + print(name, module, module.weight_quantizer.is_enabled) if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model): module.awq_lite = AWQLiteHelper(module, name) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index c4f1dc2b5..e0fabb84a 100644 --- 
a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -117,6 +117,12 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N mto.restore_from_modelopt_state(model_ref, state_dict) +def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None): + quantizer_attr = getattr(quantizer, attr).clone() + dist.all_reduce(quantizer_attr, op=op, group=group) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) + + def tensor_parallel_test_helper(model, config, tp_group): # The input to first layer, the column parallel should be the same across all tp ranks calib_data = model.get_dummy_input().cuda() @@ -126,27 +132,39 @@ def forward_loop(model): model(calib_data) model = mtq.quantize(model, config, forward_loop) - # Sanity check forward_loop(model) if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks - activation_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) # Lets check the row parallel weight amax; it should be the same across all tp ranks - weight_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) + _reduce_quantizer_attr( + model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group + ) if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks input_quantizer = model.fc1.input_quantizer - pre_quant_scale = input_quantizer.pre_quant_scale.clone() - dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) + _reduce_quantizer_attr( + input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group + ) + + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + # Check act scale + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) dist.destroy_process_group() @@ -159,27 +177,37 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - def reduce_amax(quantizer): - amax = quantizer.amax.clone() - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group) - assert torch.allclose(amax, quantizer.amax) - # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - reduce_amax(model.fc1.input_quantizer) - reduce_amax(model.fc2.input_quantizer) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) # Weight quantizer amax if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) else: - reduce_amax(model.fc1.weight_quantizer) 
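Editor's note (illustrative sketch, not part of the patch): the `_reduce_quantizer_attr` helper this hunk switches to boils down to a reduce-and-compare check: if calibration already synchronized a statistic across the group, re-reducing it must be a no-op. `amax` is a maximum statistic, so `ReduceOp.MAX` is the right no-op test; `awq_lite.act_scale` is averaged over cache steps, which is why those checks use `ReduceOp.AVG`. A self-contained sketch of the idiom, assuming an initialized process group:

```python
import torch
import torch.distributed as dist


def assert_synced(value: torch.Tensor, op=dist.ReduceOp.MAX, group=None) -> None:
    """Re-reduce a copy and compare: a correct sync makes the all-reduce a no-op."""
    reduced = value.clone()
    dist.all_reduce(reduced, op=op, group=group)
    assert torch.allclose(reduced, value), "statistic not synchronized across the group"


# Usage mirroring the helper (hypothetical quantizer objects):
# assert_synced(model.fc1.weight_quantizer.amax, dist.ReduceOp.MAX, group)
# assert_synced(model.fc1.awq_lite.act_scale, dist.ReduceOp.AVG, group)
```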
+ _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) else: - reduce_amax(model.fc2.weight_quantizer) + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) + + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + # Check act scale + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -192,33 +220,52 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - def reduce_amax(quantizer): - amax = quantizer.amax.clone() - print("amax before reduce", amax) - print("quantizer.amax before reduce", quantizer.amax) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) - print("amax after reduce", amax) - print("quantizer.amax after reduce", quantizer.amax) - assert torch.allclose(amax, quantizer.amax) + def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): + quantizer_attr = getattr(quantizer, attr).clone() + print("quantizer_attr before reduce", quantizer_attr) + print("quantizer.attr before reduce", getattr(quantizer, attr)) + dist.all_reduce(quantizer_attr, op=op, group=dp_group) + dist.all_reduce(quantizer_attr, op=op, group=cp_group) + dist.all_reduce(quantizer_attr, op=op, group=tp_group) + print("quantizer_attr after reduce", quantizer_attr) + print("quantizer.attr after reduce", getattr(quantizer, attr)) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - reduce_amax(model.fc1.input_quantizer) - reduce_amax(model.fc2.input_quantizer) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) else: - reduce_amax(model.fc1.weight_quantizer) + _reduce_quantizer_attr( + model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group + ) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) else: - reduce_amax(model.fc2.weight_quantizer) + _reduce_quantizer_attr( + model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group + ) + + # Check act scale + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) def auto_quantize_helper(model): diff --git 
a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 03a6c4ba8..84f8bca63 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -175,7 +175,7 @@ def test_context_parallel(need_2_gpus, config): # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) def _test_data_tensor_context_parallel_helper(config, rank, size): - initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) + initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank) model = MegatronModel(tp_size=2, cp_size=2).cuda() data_tensor_context_parallel_test_helper( diff --git a/tests/gpu/torch/quantization/test_model_calib.py b/tests/gpu/torch/quantization/test_model_calib.py deleted file mode 100644 index 2179b8532..000000000 --- a/tests/gpu/torch/quantization/test_model_calib.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import torch.distributed as dist -from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job -from _test_utils.torch_dist.plugins.megatron_common import MegatronModel, initialize_for_megatron -from megatron.core.parallel_state import get_data_parallel_group - -from modelopt.torch.quantization.model_calib import awq_lite - - -def _test_awq_lite_act_scale_sync_helper(rank, size): - initialize_for_megatron(seed=1234 + rank) - model = MegatronModel().cuda() - - calib_data = model.get_dummy_input().cuda() - - def forward_loop(model): - model(calib_data) - - model = awq_lite(model, forward_loop) - # Sanity check - forward_loop(model) - - act_scale = model.fc1.weight_quantizer.awq_lite.act_scale.clone() - dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) - assert torch.allclose(act_scale, model.fc1.weight_quantizer.awq_lite.act_scale) - - act_scale = model.fc2.weight_quantizer.awq_lite.act_scale.clone() - dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) - assert torch.allclose(act_scale, model.fc2.weight_quantizer.awq_lite.act_scale) - - -def test_awq_lite_act_scale_sync(need_2_gpus): - spawn_multiprocess_job(size=2, job=_test_awq_lite_act_scale_sync_helper, backend="nccl") From fc0bb884b07839a2ae8f2a9ffe731e4d089f4822 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 00:20:17 +0000 Subject: [PATCH 07/16] fix amax tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 - .../torch_quantization/quantize_common.py | 47 +++++++++++++------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7a1d9791a..3e974c1f0 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -581,7 +581,6 @@ def forward(self, input, *args, **kwargs): return out_actual for name, module in model.named_modules(): - print(name, module, module.weight_quantizer.is_enabled) if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model): module.awq_lite = AWQLiteHelper(module, name) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index e0fabb84a..d2b4b60e5 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -13,6 +13,7 @@ # See the License for the specific 
language governing permissions and # limitations under the License. import copy +from unittest.mock import patch import pytest import torch @@ -119,11 +120,26 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None): quantizer_attr = getattr(quantizer, attr).clone() + print("quantizer.attr before reduce", getattr(quantizer, attr)) dist.all_reduce(quantizer_attr, op=op, group=group) + print("quantizer.attr after reduce", getattr(quantizer, attr)) + print("quantizer_attr after reduce", quantizer_attr) assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) -def tensor_parallel_test_helper(model, config, tp_group): +# Store the original function before patching +import modelopt.torch.quantization.model_calib as model_calib_module + +original_awq_lite = model_calib_module.awq_lite + + +def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): + """Function to mock awq_lite function to always use debug=True for testing""" + return original_awq_lite(model, forward_loop, alpha_step, debug=True) + + +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite): # The input to first layer, the column parallel should be the same across all tp ranks calib_data = model.get_dummy_input().cuda() dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -138,7 +154,6 @@ def forward_loop(model): if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) - # Lets check the row parallel weight amax; it should be the same across all tp ranks _reduce_quantizer_attr( model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group @@ -152,24 +167,25 @@ def forward_loop(model): ) if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: - # Check act scale + # Check activation scale for AWQ lite _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) + # TODO fc2 assert is failing + """ _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, - "act_scale", - dist.ReduceOp.AVG, - group=tp_group, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) + """ dist.destroy_process_group() -def dp_cp_parallel_test_helper(model, config, group): +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): calib_data = model.get_dummy_input().cuda() def forward_loop(model): @@ -197,20 +213,23 @@ def forward_loop(model): if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Check act scale _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, + model.fc2.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) -def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def data_tensor_context_parallel_test_helper( + model, config, dp_group, 
tp_group, cp_group, mock_awq_lite +): calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -255,13 +274,13 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, + model.fc2.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, From 95da8329e2c7ed9497542d1126e5c31370313160 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 00:21:38 +0000 Subject: [PATCH 08/16] fix awq lite param Signed-off-by: Jennifer Chen --- tests/_test_utils/torch_quantization/quantize_common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index d2b4b60e5..ced82ed2e 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -213,13 +213,13 @@ def forward_loop(model): if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Check act scale _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite, + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, @@ -274,13 +274,13 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite, + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, From 34c11ef46cb4d736be0b43e8c59f90387c05ef35 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 19:04:19 +0000 Subject: [PATCH 09/16] fix test Signed-off-by: Jennifer Chen --- examples/nemo_run/qat/README.md | 7 +++--- .../torch_quantization/quantize_common.py | 22 ++++++------------- .../quantization/plugins/test_megatron.py | 4 ++-- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/examples/nemo_run/qat/README.md b/examples/nemo_run/qat/README.md index cd74c96e2..b9be7ba0e 100644 --- a/examples/nemo_run/qat/README.md +++ b/examples/nemo_run/qat/README.md @@ -56,15 +56,16 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar You can run the example either locally or on a [Slurm cluster](ADVANCED.md). -To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container. +To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container. 
- `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git` -- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a` +- `git clone https://github.com/NVIDIA-NeMo/NeMo.git` +- `git clone https://github.com/NVIDIA/Megatron-LM.git` Example docker command: ```bash -docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash +docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt -v /home/user/Megatron-LM:/opt/megatron-lm --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash ``` You will also need to set your Huggingface token with `export HF_TOKEN=`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written. diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index ced82ed2e..c00fd5b33 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -23,6 +23,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer from modelopt.torch.quantization.utils import is_quantized_linear @@ -127,9 +128,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) -# Store the original function before patching -import modelopt.torch.quantization.model_calib as model_calib_module - original_awq_lite = model_calib_module.awq_lite @@ -252,24 +250,20 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) else: - _reduce_quantizer_attr( - model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group - ) + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) else: - _reduce_quantizer_attr( - model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group - ) + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) # Check act scale if config in 
[mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: @@ -277,13 +271,11 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, - group=tp_group, ) _reduce_quantizer_attr( model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, - group=tp_group, ) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 84f8bca63..07c026513 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -214,7 +214,7 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic tensor_model_parallel_size=tp_size, num_layers=4, ffn_hidden_size=None, - num_attention_heads=4, + num_attention_heads=8, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, @@ -226,7 +226,7 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic tensor_model_parallel_size=tp_size, num_layers=4, ffn_hidden_size=None, - num_attention_heads=4, + num_attention_heads=8, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, From 9f0691f5a60fff1d97497983b6cad5219bb78101 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 19:23:31 +0000 Subject: [PATCH 10/16] uncomment test Signed-off-by: Jennifer Chen --- tests/_test_utils/torch_quantization/quantize_common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index c00fd5b33..712a13171 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -172,12 +172,12 @@ def forward_loop(model): dist.ReduceOp.AVG, group=tp_group, ) - # TODO fc2 assert is failing - """ _reduce_quantizer_attr( - model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, + model.fc2.awq_lite, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, ) - """ dist.destroy_process_group() From fa8f4c8d8433f83255768da75af491bb6ef111b1 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 1 Oct 2025 01:07:14 +0000 Subject: [PATCH 11/16] add print Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 83 ++++++++++++++++--- .../quantization/plugins/test_megatron.py | 6 +- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 712a13171..bc643ad01 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -119,7 +119,7 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N mto.restore_from_modelopt_state(model_ref, state_dict) -def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None): +def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): quantizer_attr = getattr(quantizer, attr).clone() print("quantizer.attr before reduce", getattr(quantizer, attr)) dist.all_reduce(quantizer_attr, op=op, group=group) @@ -225,9 +225,46 @@ def forward_loop(model): @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) -def data_tensor_context_parallel_test_helper( - model, config, dp_group, tp_group, cp_group, mock_awq_lite -): +def 
data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): + # Print rank information for debugging + world_rank = dist.get_rank() + world_size = dist.get_world_size() + + print("\n=== RANK INFORMATION ===") + print(f"World Rank: {world_rank}, World Size: {world_size}") + + # Get group information with actual ranks + def get_group_ranks(group): + if group is None: + return None + ranks = [] + ranks = [ + i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group) + ] + return ranks + + if dp_group is not None: + dp_rank = dist.get_rank(group=dp_group) + dp_size = dist.get_world_size(group=dp_group) + print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}") + + if tp_group is not None: + tp_rank = dist.get_rank(group=tp_group) + tp_size = dist.get_world_size(group=tp_group) + print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}") + + print("=== END RANK INFO ===\n") + + # Print a summary of all ranks + print("=== ALL RANKS SUMMARY ===") + print(f"Total GPUs: {world_size}") + print(f"Current rank: {world_rank}") + if dp_group is not None: + print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each") + if tp_group is not None: + print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each") + print("=== END SUMMARY ===\n") + calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -238,14 +275,38 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): + world_rank = dist.get_rank() + print(f"\n--- Rank {world_rank}: Reducing {attr} ---") + from megatron.core.parallel_state import ( + _CONTEXT_PARALLEL_GLOBAL_RANKS, + _DATA_PARALLEL_GLOBAL_RANKS, + _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP, + _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS, + ) + + print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}") + print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}") + print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}") + print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}") quantizer_attr = getattr(quantizer, attr).clone() - print("quantizer_attr before reduce", quantizer_attr) - print("quantizer.attr before reduce", getattr(quantizer, attr)) - dist.all_reduce(quantizer_attr, op=op, group=dp_group) - dist.all_reduce(quantizer_attr, op=op, group=cp_group) - dist.all_reduce(quantizer_attr, op=op, group=tp_group) - print("quantizer_attr after reduce", quantizer_attr) - print("quantizer.attr after reduce", getattr(quantizer, attr)) + print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr) + print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr)) + + # Perform all-reduce operations + if tp_group is not None: + tp_rank = dist.get_rank(group=tp_group) + print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})") + dist.all_reduce(quantizer_attr, op=op, group=tp_group) + + if dp_group is not None: + dp_rank = dist.get_rank(group=dp_group) + print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})") + dist.all_reduce(quantizer_attr, op=op, group=dp_group) + + print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr) + print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr)) + print(f"--- End Rank {world_rank} ---\n") + assert torch.allclose(quantizer_attr, 
getattr(quantizer, attr)) # Input quantizer amax diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 07c026513..b45c191bc 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -42,7 +42,6 @@ import megatron.core from megatron.core.parallel_state import ( destroy_model_parallel, - get_context_parallel_group, get_data_parallel_group, get_tensor_model_parallel_group, ) @@ -152,7 +151,7 @@ def _test_context_parallel_helper(config, rank, size): ) # modify seed so data is different across ranks model = MegatronModel(cp_size=size).cuda() - dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_data_parallel_group(with_context_parallel=True)) @pytest.mark.parametrize( @@ -181,9 +180,8 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): data_tensor_context_parallel_test_helper( model, config, - get_data_parallel_group(), + get_data_parallel_group(with_context_parallel=True), get_tensor_model_parallel_group(), - get_context_parallel_group(), ) From d1fac44cd7826fb2c44e3de208f7995402d81a7f Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 1 Oct 2025 20:58:09 +0000 Subject: [PATCH 12/16] docstring Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3e974c1f0..8e10df7df 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -117,6 +117,7 @@ def sync_quantizer_amax_across_tp( ): if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: + "Syncing amax across TP for sequential quantizer" sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) From 22b8b73cc82ea2e5c285cddf0b26a3ff5b593025 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 2 Oct 2025 00:09:31 +0000 Subject: [PATCH 13/16] fix tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 21 ++-- .../torch/quantization/plugins/megatron.py | 5 +- modelopt/torch/utils/distributed.py | 3 - .../torch_dist/plugins/megatron_common.py | 6 +- .../torch_quantization/quantize_common.py | 117 ++++-------------- .../quantization/plugins/test_megatron.py | 2 +- 6 files changed, 42 insertions(+), 112 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 8e10df7df..d9189768c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -80,22 +80,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel and context parallel groups.""" + def sync_quantizer_amax_across_dp(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel group.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp_cp(_q, parallel_state) + sync_quantizer_amax_across_dp(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) - 
quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp_cp(child, module.parallel_state) + sync_quantizer_amax_across_dp(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -600,17 +599,12 @@ def forward(self, input, *args, **kwargs): # This will also perform distributed amax sync for input_quantizers max_calibrate(model, lambda model: None) - def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group): - # Sync across Data Parallel (DP) + def sync_act_scale_across_dp(module, data_parallel_group): + """Sync activation scale across Data Parallel (DP).""" if data_parallel_group.is_initialized(): dist.all_reduce( module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group ) - # Sync across Context Parallel (CP) - if context_parallel_group.is_initialized(): - dist.all_reduce( - module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group - ) for name, module in model.named_modules(): if ( @@ -627,10 +621,9 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr ): module.awq_lite.is_enabled = False else: - sync_act_scale_across_dp_cp( + sync_act_scale_across_dp( module, module.parallel_state.data_parallel_group, - module.parallel_state.context_parallel_group, ) AWQLiteHelper.cache_mode = False diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index b41bf7a58..85784d2fe 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -15,6 +15,7 @@ """Support quantization for megatron linear layers.""" +import logging import warnings from typing import Any @@ -39,6 +40,8 @@ from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +logger = logging.getLogger(__name__) + __all__ = [] @@ -222,11 +225,11 @@ def _setup(self): try: data_parallel_group = get_data_parallel_group(with_context_parallel=True) except AssertionError: + logger.warning("Context parallel group is not initialized, using data parallel group") data_parallel_group = get_data_parallel_group() self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), - mcore_parallel.get_context_parallel_group(), ) super()._setup() diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index f6987a3f6..f11a736db 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,18 +241,15 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, - context_parallel_group: torch.distributed.ProcessGroup | int | None = -1, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) - self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: return ( f"data_parallel_group: {self.data_parallel_group}, " f"tensor_parallel_group: 
{self.tensor_parallel_group}, " - f"context_parallel_group: {self.context_parallel_group}" ) diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 6324d3390..9d2b0c047 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -127,7 +127,11 @@ def forward(self, x): x = x[0] return x - def get_dummy_input(self) -> torch.Tensor: + def get_dummy_input(self, seed: int | None = None) -> torch.Tensor: + if seed is not None: + gen = torch.Generator() + gen.manual_seed(seed) + return torch.randn(1, 4, 32, generator=gen) return torch.randn(1, 4, 32) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index bc643ad01..6dbb5b213 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -172,12 +172,6 @@ def forward_loop(model): dist.ReduceOp.AVG, group=tp_group, ) - _reduce_quantizer_attr( - model.fc2.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - group=tp_group, - ) dist.destroy_process_group() @@ -191,6 +185,9 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) + # Sanity check + forward_loop(model) + # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) @@ -226,48 +223,9 @@ def forward_loop(model): @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): - # Print rank information for debugging - world_rank = dist.get_rank() - world_size = dist.get_world_size() - - print("\n=== RANK INFORMATION ===") - print(f"World Rank: {world_rank}, World Size: {world_size}") - - # Get group information with actual ranks - def get_group_ranks(group): - if group is None: - return None - ranks = [] - ranks = [ - i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group) - ] - return ranks - - if dp_group is not None: - dp_rank = dist.get_rank(group=dp_group) - dp_size = dist.get_world_size(group=dp_group) - print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}") - - if tp_group is not None: - tp_rank = dist.get_rank(group=tp_group) - tp_size = dist.get_world_size(group=tp_group) - print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}") - - print("=== END RANK INFO ===\n") - - # Print a summary of all ranks - print("=== ALL RANKS SUMMARY ===") - print(f"Total GPUs: {world_size}") - print(f"Current rank: {world_rank}") - if dp_group is not None: - print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each") - if tp_group is not None: - print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each") - print("=== END SUMMARY ===\n") - - calib_data = model.get_dummy_input().cuda() - # data should be same across each TP rank - dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) + # Calib data should be same across each DP rank + dp_rank = dist.get_rank(group=dp_group) + calib_data = model.get_dummy_input(seed=dp_rank).cuda() def forward_loop(model): model(calib_data) @@ -275,56 +233,36 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): - world_rank = dist.get_rank() - 
print(f"\n--- Rank {world_rank}: Reducing {attr} ---") - from megatron.core.parallel_state import ( - _CONTEXT_PARALLEL_GLOBAL_RANKS, - _DATA_PARALLEL_GLOBAL_RANKS, - _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP, - _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS, - ) - - print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}") - print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}") - print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}") - print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}") quantizer_attr = getattr(quantizer, attr).clone() - print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr) - print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr)) # Perform all-reduce operations - if tp_group is not None: - tp_rank = dist.get_rank(group=tp_group) - print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})") - dist.all_reduce(quantizer_attr, op=op, group=tp_group) + dist.all_reduce(quantizer_attr, op=op, group=tp_group) - if dp_group is not None: - dp_rank = dist.get_rank(group=dp_group) - print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})") - dist.all_reduce(quantizer_attr, op=op, group=dp_group) + dist.all_reduce(quantizer_attr, op=op, group=dp_group) - print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr) - print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr)) - print(f"--- End Rank {world_rank} ---\n") - - assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) - if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) - - if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) + # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks + # Channel-wise (INT8) only expects same amax across row parallel ranks + # Block-wise quantization does not expect same amax across row and column parallel ranks + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) + + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: @@ -333,11 +271,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): 
"act_scale", dist.ReduceOp.AVG, ) - _reduce_quantizer_attr( - model.fc2.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - ) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index b45c191bc..f71359044 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -199,7 +199,7 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=4, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) From 3f857a32b42fd23fbf8ce1efa749e2a2102e73a7 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 2 Oct 2025 21:17:08 +0000 Subject: [PATCH 14/16] fix multiprocess size Signed-off-by: Jennifer Chen --- tests/gpu/torch/quantization/plugins/test_megatron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index f71359044..b45c191bc 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -199,7 +199,7 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=4, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) From 4dc16b0842e74af5517c501bf157195ff1f076fe Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Tue, 7 Oct 2025 02:24:14 +0000 Subject: [PATCH 15/16] Added quantization support for TEGroupedMoE for megatron-lm Signed-off-by: Kinjal Patel --- modelopt/torch/quantization/model_calib.py | 13 +- .../torch/quantization/plugins/megatron.py | 172 +++++++++- modelopt/torch/utils/distributed.py | 14 +- .../torch_dist/plugins/megatron_common.py | 226 ++++++++++++- tests/gpu/torch/conftest.py | 6 + .../quantization/plugins/test_megatron.py | 302 +++++++++++++++++- 6 files changed, 716 insertions(+), 17 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index d9189768c..d3cecc00c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -80,21 +80,26 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel group.""" + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel and context parallel groups.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp(_q, parallel_state) + sync_quantizer_amax_across_dp_ep(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) + if parallel_state.expert_tensor_parallel_group is not None: + quantizer.sync_amax_across_distributed_group( + 
parallel_state.expert_tensor_parallel_group + ) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp(child, module.parallel_state) + sync_quantizer_amax_across_dp_ep(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 85784d2fe..dfd4b42be 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,8 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch +import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear +from megatron.core.extensions import transformer_engine as megatron_te from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule @@ -35,8 +37,8 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModuleRegistry, TensorQuantizer -from ..nn.modules.quant_linear import RealQuantLinear +from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer +from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear @@ -227,9 +229,17 @@ def _setup(self): except AssertionError: logger.warning("Context parallel group is not initialized, using data parallel group") data_parallel_group = get_data_parallel_group() + + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + except AssertionError: + expert_tensor_parallel_group = None + self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_expert_model_parallel_group(), + expert_tensor_parallel_group, ) super()._setup() @@ -472,3 +482,161 @@ class _RealQuantMegatronRowParallelLinear( def forward(self, input, *args, **kwargs): return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) + + +# Register the public te.pytorch.GroupedLinear class +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) +class _QuantTEGroupedLinear(_MegatronParallelLinear): + def _setup(self): + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + data_parallel_group = get_data_parallel_group() + + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + except AssertionError: + expert_tensor_parallel_group = None + self.parallel_state = ParallelState( + data_parallel_group, + mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_context_parallel_group(), + mcore_parallel.get_expert_model_parallel_group(), + expert_tensor_parallel_group, + ) + self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) + self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) + self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) + self.output_quantizer.disable() + + # 
Memorize the original weight.dtype for modelopt_post_restore given that + # the dtype can change later. + self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype + + @property + def functionals_to_replace(self): + original_forward = te_grouped_linear._GroupedLinear.forward + + def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): + num_gemms = len(m_splits) + weights_and_biases = args[-2 * num_gemms :] + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] + quantized_inputs = self.input_quantizer(inp) + quantized_weights = [self.weight_quantizer(weight) for weight in weights] + + output = original_forward( + ctx, + quantized_inputs, + m_splits, + *args[: -2 * num_gemms], + *quantized_weights, + *biases, + ) + return self.output_quantizer(output) + + return [ + ( + te_grouped_linear._GroupedLinear, + "forward", + te_grouped_quantized_linear_fn, + ), + ] + + def modelopt_post_restore(self, prefix: str = ""): + """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. + + ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their + shape before saving. However this is not enough for MCore/distributed frameworks since the tensor parallelism + could change between saving and restoring. If the tensor parallelism changes, the shape of the quantizer + states also changes. So we need to re-calculate the quantizer states. + """ + from modelopt.torch.quantization.model_calib import max_calibrate + + def _check_unsupported_states(quantizer: TensorQuantizer): + for k in quantizer.state_dict(): + if k not in ["_amax", "_pre_quant_scale"]: + warnings.warn( + f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " + f"incorrect. Please implement a custom restore for {k}." + ) + + def _has_state(quantizer, name): + # Handling for SequentialQuantizer + quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer + return hasattr(quantizer, name) + + # weights for TEGroupedLinear are stored in weight0, weight1, etc. 
+ if self.weight0 is None: + return + for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]: + _check_unsupported_states( + quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] + ) + if _has_state(self.weight_quantizer, "_amax"): + self.weight_quantizer.reset_amax() + for i in range(self.num_gemms): + weight = getattr(self, f"weight{i}") + assert weight is not None, "weight is None" + + max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False) + if _has_state(self.input_quantizer, "_pre_quant_scale"): + if hasattr(self.input_quantizer, "_pre_quant_scale"): + delattr(self.input_quantizer, "_pre_quant_scale") + pqs = torch.zeros( + (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype + ) + self.input_quantizer.register_buffer("_pre_quant_scale", pqs) + + if _has_state(self.input_quantizer, "_amax"): + self.input_quantizer.reset_amax() + dummy_input = torch.ones( + (1, 1, self.weight0.shape[1]), + device=self.weight0.device, + dtype=self.original_weight_dtype, + ) + max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False) + if _has_state(self.output_quantizer, "_amax"): + self.output_quantizer.reset_amax() + dummy_input = torch.ones( + (1, 1, self.weight0.shape[0]), + device=self.weight0.device, + dtype=self.original_weight_dtype, + ) + max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False) + # If there are any other states, lets move them to the correct device + + self.weight = None + super().modelopt_post_restore(prefix=prefix) + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in + # sharded_state_dict which is same as _extra_state. 
The _extra_state{gemm_idx} is used for
+        # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms]
+        # for modelopt checkpoint restore
+        filtered_state_dict = {
+            k: v
+            for k, v in state_dict.items()
+            if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms))
+        }
+        return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs)
+
+    def _process_quantizer_amax(self, k, v, quantizer_state_dict):
+        if v.ndim == 4:
+            quantizer_state_dict[k] = v.squeeze(1).squeeze(-1)
+        else:
+            quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1)
+
+
+@QuantModuleRegistry.register(
+    {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"}
+)
+class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear):
+    _is_column_parallel = True
+
+
+@QuantModuleRegistry.register(
+    {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"}
+)
+class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear):
+    _is_row_parallel = True
diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py
index f11a736db..3d80d1464 100644
--- a/modelopt/torch/utils/distributed.py
+++ b/modelopt/torch/utils/distributed.py
@@ -241,16 +241,28 @@ def __init__(
         self,
         data_parallel_group: torch.distributed.ProcessGroup | int | None = None,
         tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1,
+        expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None,
     ):
         """Initialize the parallel state."""
         self.data_parallel_group = DistributedProcessGroup(data_parallel_group)
         self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group)
+        self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group)
+        self.expert_tensor_parallel_group = None
+        if expert_tensor_parallel_group is not None:
+            self.expert_tensor_parallel_group = DistributedProcessGroup(
+                expert_tensor_parallel_group
+            )
 
     def __repr__(self) -> str:
-        return (
+        parallel_groups = (
             f"data_parallel_group: {self.data_parallel_group}, "
             f"tensor_parallel_group: {self.tensor_parallel_group}, "
+            f"expert_model_parallel_group: {self.expert_model_parallel_group}"
         )
+        if self.expert_tensor_parallel_group:
+            parallel_groups += f", expert_tensor_parallel_group: {self.expert_tensor_parallel_group}"
+        return parallel_groups
 
 
 def get_group(ranks: list[int]):
diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py
index 9d2b0c047..68268805e 100644
--- a/tests/_test_utils/torch_dist/plugins/megatron_common.py
+++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py
@@ -37,6 +37,8 @@
 )
 from megatron.core.models.mamba import MambaModel
 from megatron.core.parallel_state import (
+    get_expert_model_parallel_group,
+    get_expert_tensor_parallel_group,
     initialize_model_parallel,
     is_pipeline_first_stage,
     is_pipeline_last_stage,
@@ -48,12 +50,14 @@
 from megatron.core.transformer.module import MegatronModule
 from megatron.core.transformer.transformer_config import TransformerConfig
 
+import modelopt.torch.quantization as mtq
 from modelopt.torch.export.unified_export_megatron import import_mcore_gpt_from_hf
 from modelopt.torch.opt.plugins.mcore_dist_checkpointing import (
     restore_sharded_modelopt_state,
     save_sharded_modelopt_state,
 )
 from
modelopt.torch.utils import to_empty_if_meta_device +from modelopt.torch.utils.distributed import DistributedProcessGroup try: from megatron.core.extensions.transformer_engine import TENorm @@ -138,6 +142,8 @@ def get_dummy_input(self, seed: int | None = None) -> torch.Tensor: def get_mcore_gpt_model( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + expert_tensor_parallel_size: int = 1, initialize_megatron: bool = False, *, num_layers: int = 2, @@ -153,7 +159,10 @@ def get_mcore_gpt_model( normalization: str = "LayerNorm", transformer_impl: str = "modelopt" if HAS_TE else "local", use_cpu_initialization: bool = False, + num_moe_experts: int | None = None, + moe_grouped_gemm: bool = False, bf16: bool = True, + use_te: bool = False, ) -> GPTModel: assert activation_func in ["swiglu", "squared_relu"] assert normalization in ["LayerNorm", "RMSNorm"] @@ -161,7 +170,12 @@ def get_mcore_gpt_model( print(f"Using `{transformer_impl=}` model spec for building GPT Model.") if initialize_megatron: - initialize_for_megatron(tensor_model_parallel_size, pipeline_model_parallel_size) + initialize_for_megatron( + tensor_model_parallel_size, + pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + ) def squared_relu(x): return torch.pow(F.relu(x), 2) @@ -169,7 +183,10 @@ def squared_relu(x): config = TransformerConfig( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, sequence_parallel=False, + moe_grouped_gemm=moe_grouped_gemm, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, @@ -177,6 +194,7 @@ def squared_relu(x): num_attention_heads=num_attention_heads, num_query_groups=num_query_groups, ffn_hidden_size=ffn_hidden_size, + num_moe_experts=num_moe_experts, activation_func=squared_relu if activation_func == "squared_relu" else F.silu, normalization=normalization, gated_linear_unit=(activation_func == "swiglu"), @@ -188,7 +206,12 @@ def squared_relu(x): if transformer_impl == "local": assert HAS_APEX, "Apex not installed" - transformer_layer_spec = get_gpt_layer_local_spec(normalization=normalization) + transformer_layer_spec = get_gpt_layer_local_spec( + num_experts=num_moe_experts, + normalization=normalization, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + ) else: assert HAS_TE, "Transformer Engine not installed" transformer_layer_spec = ( @@ -207,6 +230,7 @@ def squared_relu(x): share_embeddings_and_output_weights=False, position_embedding_type="rope", ) + if bf16: model = model.to(torch.bfloat16) @@ -392,6 +416,8 @@ def initialize_for_megatron( pipeline_model_parallel_size=1, seed=1234, context_parallel_size=1, + expert_model_parallel_size=1, + expert_tensor_parallel_size=None, ): """Initialize Megatron model parallelism. 
@@ -401,6 +427,9 @@ def initialize_for_megatron( tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + order="tp-ep-dp-pp", ) model_parallel_cuda_manual_seed(seed) @@ -467,3 +496,196 @@ def convert_maybe_fp8(v): assert torch.allclose(logits_ref, logits_test), ( f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}" ) + + +def compare_model_outputs(grouped_model, non_grouped_model, forward_fn, tolerance=1e-6): + """Compare outputs of grouped and non-grouped models.""" + # Set both models to eval mode + grouped_model.eval() + non_grouped_model.eval() + + with torch.no_grad(): + # Get outputs from both models + grouped_output = forward_fn(grouped_model) + non_grouped_output = forward_fn(non_grouped_model) + + # Compare outputs + if isinstance(grouped_output, tuple): + grouped_output = grouped_output[0] + if isinstance(non_grouped_output, tuple): + non_grouped_output = non_grouped_output[0] + + output_close = torch.allclose( + grouped_output, non_grouped_output, atol=tolerance, rtol=tolerance + ) + return output_close + + +def sync_amax(model): + amax_dict = { + "linear_fc1.input_quantizer": {}, + "linear_fc1.weight_quantizer": {}, + "linear_fc2.input_quantizer": {}, + "linear_fc2.weight_quantizer": {}, + } + for name, module in model.named_modules(): + if not isinstance(module, mtq.nn.TensorQuantizer): + continue + if not hasattr(module, "_amax"): + continue + if "local_experts" not in name: + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax) + + for name, module in model.named_modules(): + if not isinstance(module, mtq.nn.TensorQuantizer): + continue + if not hasattr(module, "_amax"): + continue + if "local_experts" not in name: + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + module.amax = amax_dict[key][expert_name] + + +def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model): + """Copy weights from grouped MoE model to non-grouped MoE model.""" + grouped_state = grouped_model.state_dict() + non_grouped_state = non_grouped_model.state_dict() + + # Map grouped weights to non-grouped weights + weight_mapping = {} + non_grouped_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight" + for key, value in grouped_state.items(): + if "experts.linear_fc" in key and "weight" in key: + # Extract expert index from grouped weight name + # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ + parts = key.split(".") + layer_idx = parts[2] # X + fc_idx = parts[5] # Y (linear_fc1 or linear_fc2) + weight_idx = parts[6] # Z (weight0, weight1, etc.) 
+ + # Map to non-grouped format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight + expert_idx = weight_idx.replace("weight", "") + non_grouped_key = non_grouped_key_template.format(layer_idx, expert_idx, fc_idx[-1]) + weight_mapping[non_grouped_key] = value + elif isinstance(value, torch.Tensor): + weight_mapping[key] = value + + # Copy weights to non-grouped model + for non_grouped_key in non_grouped_state: + if non_grouped_key in weight_mapping: + non_grouped_state[non_grouped_key] = weight_mapping[non_grouped_key].clone() + + non_grouped_model.load_state_dict(non_grouped_state) + + +def compare_amax_sync_across_expert_parallel(model): + """ + Test if amax values are synchronized across expert parallel groups. + + Returns True if synchronized, False otherwise. + """ + + ep_group = get_expert_model_parallel_group(check_initialized=False) + etp_group = get_expert_tensor_parallel_group(check_initialized=False) + + # Check if we have either expert model parallel or expert tensor parallel + has_expert_parallel = (ep_group is not None and ep_group.size() > 1) or ( + etp_group is not None and etp_group.size() > 1 + ) + + assert has_expert_parallel, "No expert parallelism detected" + # Collect amax values from expert quantizers only + expert_amax_values = {} + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"): + # Check for both grouped and non-grouped MoE patterns + if "local_experts" in name or ("experts" in name and "linear_fc" in name): + expert_amax_values[name] = ( + module.amax.item() if hasattr(module.amax, "item") else module.amax + ) + + # Early return if no expert quantizers found + assert expert_amax_values, "No expert quantizers found" + + # Gather amax values from all ranks + world_size = torch.distributed.get_world_size() + all_amax_values = [None] * world_size + torch.distributed.all_gather_object(all_amax_values, expert_amax_values) + + # Group quantizers by type (ignoring specific expert indices) and check sync + expert_quantizers = {} + for rank_idx, rank_amax in enumerate(all_amax_values): + for name, amax_val in rank_amax.items(): + # Create quantizer type key by normalizing the name + if "local_experts" in name: + # Non-grouped MoE: replace expert index with wildcard + import re + + quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name) + else: + # Grouped MoE: use the name as-is since experts are grouped + quantizer_type = name + + if quantizer_type not in expert_quantizers: + expert_quantizers[quantizer_type] = {} + expert_quantizers[quantizer_type][rank_idx] = amax_val + + # Check synchronization - fail fast on first inconsistency + for quantizer_type, rank_values in expert_quantizers.items(): + if len(rank_values) > 1: # Only check if we have multiple ranks + values = list(rank_values.values()) + max_diff = max(values) - min(values) + + if max_diff > 1e-6: # Allow for small floating point differences + return False + + return True + + +def disable_distributed_parallel_sync(model, expert_parallel_type: str = "tensor"): + """Disable distributed parallel synchronization groups.""" + module_parallel_groups = {} + + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.QuantModule): + # Store original groups + module_parallel_groups[name] = { + "data_parallel_group": module.parallel_state.data_parallel_group, + "expert_tensor_parallel_group": module.parallel_state.expert_tensor_parallel_group, + "expert_model_parallel_group": 
module.parallel_state.expert_model_parallel_group, + } + + # Disable groups + module.parallel_state.data_parallel_group = DistributedProcessGroup(-1) + + if expert_parallel_type in ["tensor", "both"]: + module.parallel_state.expert_tensor_parallel_group = DistributedProcessGroup(-1) + if expert_parallel_type in ["model", "both"]: + module.parallel_state.expert_model_parallel_group = DistributedProcessGroup(-1) + + return module_parallel_groups + + +def enable_distributed_parallel_sync( + model, module_parallel_groups, expert_parallel_type: str = "tensor" +): + """Re-enable distributed parallel synchronization groups.""" + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.QuantModule) and name in module_parallel_groups: + groups = module_parallel_groups[name] + + if expert_parallel_type in ["tensor", "both"]: + module.parallel_state.expert_tensor_parallel_group = groups[ + "expert_tensor_parallel_group" + ] + if expert_parallel_type in ["model", "both"]: + module.parallel_state.expert_model_parallel_group = groups[ + "expert_model_parallel_group" + ] diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index f32065bce..d1ba9dd47 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -40,6 +40,12 @@ def need_8_gpus(): pytest.skip("Need at least 8 GPUs to run this test") +@pytest.fixture +def need_4_gpus(): + if torch.cuda.device_count() < 4: + pytest.skip("Need at least 4 GPUs to run this test") + + @pytest.fixture(scope="module") def set_torch_dtype(request): orig_dtype = torch.get_default_dtype() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index b45c191bc..0993e2565 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -21,10 +21,16 @@ from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job from _test_utils.torch_dist.plugins.megatron_common import ( MegatronModel, + compare_amax_sync_across_expert_parallel, + compare_model_outputs, + copy_weights_from_grouped_to_non_grouped, + disable_distributed_parallel_sync, + enable_distributed_parallel_sync, get_mcore_gpt_model, initialize_for_megatron, run_mcore_inference, sharded_state_dict_test_helper, + sync_amax, ) from _test_utils.torch_misc import set_seed from _test_utils.torch_quantization.models import RegularQuantModelForTP @@ -46,6 +52,7 @@ get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP import modelopt import modelopt.torch.opt as mto @@ -203,13 +210,25 @@ def test_data_tensor_context_parallel(need_8_gpus, config): ) -def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): +def _gpt_model_provider( + tp_size: int, + hidden_size=256, + vocab_size=64, + num_moe_experts=None, + moe_grouped_gemm=False, + meta_device=False, + ep_size=1, + etp_size=None, + use_te=False, +): """Build the model.""" if meta_device: with torch.device("meta"): gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=8, @@ -218,10 +237,15 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic hidden_size=hidden_size, vocab_size=vocab_size, use_cpu_initialization=meta_device, + 
num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ) else: gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=8, @@ -229,12 +253,15 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic transformer_impl="local", hidden_size=hidden_size, vocab_size=vocab_size, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ).cuda() return gpt_model.eval() def _test_sharded_state_dict( - tmp_path, config, hidden_size, modelopt_version, compress, meta_device, rank, size + tmp_path, config, hidden_size, modelopt_version, compress, meta_device, moe_config, rank, size ): # Must disable output_layer quantization since output_layer amax cannot be restore via # sharded_state_dict. All output_layer quantizers state are removed. @@ -244,10 +271,42 @@ def _test_sharded_state_dict( mto.conversion.__version__ = modelopt_version mtq.plugins.megatron.__version__ = modelopt_version - initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED) + tp_size = moe_config.get("tp_size", size) + ep_size = moe_config.get("ep_size", 1) + etp_size = moe_config.get("etp_size", None) + num_moe_experts = moe_config.get("num_moe_experts", None) + moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False) + use_te = moe_config.get("use_te", False) + + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + seed=SEED, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + ) - model_ref = _gpt_model_provider(size, hidden_size, vocab_size=256) - model_test = _gpt_model_provider(size, hidden_size, vocab_size=256, meta_device=meta_device) + model_ref = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + meta_device=meta_device, + ep_size=ep_size, + etp_size=etp_size, + ) + model_test = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + meta_device=meta_device, + ep_size=ep_size, + etp_size=etp_size, + ) prompt_tokens = torch.randint( 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) @@ -328,7 +387,9 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) spawn_multiprocess_job( size=size, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device), + job=partial( + _test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device, {} + ), backend="nccl", ) @@ -347,7 +408,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config): spawn_multiprocess_job( size=2, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False), + job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False, {}), backend="nccl", ) @@ -368,7 +429,7 @@ def test_sharded_state_dict_old_checkpoints(need_2_gpus, tmp_path, config, model spawn_multiprocess_job( size=2, job=partial( - _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False + _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False, {} ), backend="nccl", ) @@ -451,3 +512,228 @@ def forward_fn(model): def test_fp8_real_quantize(): size = 
torch.cuda.device_count() spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl") + + +@pytest.mark.parametrize( + "config", + [ + mtq.FP8_DEFAULT_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config): + size = torch.cuda.device_count() + # TODO: Meta device doesn't work with TE + # TODO: Add support for compress=True for TEGroupedMLP + moe_config = { + "tp_size": 2, + "ep_size": 2, + "etp_size": 2, + "num_moe_experts": 4, + "moe_grouped_gemm": True, + "use_te": True, + } + spawn_multiprocess_job( + size=size, + job=partial( + _test_sharded_state_dict, + tmp_path, + config, + 256, + None, + False, + False, + moe_config, + ), + backend="nccl", + ) + + +def _test_grouped_vs_non_grouped_amax_helper(tp_size, ep_size, etp_size, rank, size): + """Test that grouped and non-grouped MoE models produce similar amax values.""" + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=SEED, + ) + + # Create input + prompt_tokens = torch.randint(0, 64, (2, 16)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Create grouped MoE model + grouped_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=True, + use_te=True, + num_moe_experts=4, + ) + num_grouped_mlp = sum(isinstance(module, TEGroupedMLP) for module in grouped_model.modules()) + assert num_grouped_mlp == 4, ( + f"TEGrupedMoEModel has {num_grouped_mlp} TEGroupedMLP modules, it should have 4" + ) + + # Create non-grouped MoE model + non_grouped_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=False, + num_moe_experts=4, + ) + num_sequential_mlp = sum( + isinstance(module, SequentialMLP) for module in non_grouped_model.modules() + ) + assert num_sequential_mlp == 4, ( + f"SequentialMoEModel has {num_sequential_mlp} SequentialMLP modules, it should have 4" + ) + # Copy weights from grouped to non-grouped model + copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model) + + output_comparison_before = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) + assert output_comparison_before, "Outputs are not close before quantization" + + # Quantize grouped model + mtq.quantize(grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # Quantize non-grouped model + mtq.quantize(non_grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # sync amax across expert parallel + # TODO: Remove once amax sync is enabled by default for SequentialGroupedMLP + sync_amax(non_grouped_model) + + # Compare model outputs after quantization + output_comparison_after = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) + assert output_comparison_after, "Outputs are not close after quantization" + + +def test_grouped_vs_non_grouped_amax(): + """Test that grouped and non-grouped MoE models produce similar amax values.""" + import time + + size = torch.cuda.device_count() + if size < 4: + pytest.skip("Requires at least 4 GPUs for expert parallel test") + + # Add small delay to avoid port conflicts + time.sleep(0.1) + + spawn_multiprocess_job( + size=size, job=partial(_test_grouped_vs_non_grouped_amax_helper, 1, 2, 2), backend="nccl" + ) + + +def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm): + """ + Test that demonstrates the requirement for expert parallel 
sync in model_calib.py + """ + # Create model with expert parallelism + model = _gpt_model_provider( + tp_size=1, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=256, + moe_grouped_gemm=moe_grouped_gemm, + use_te=moe_grouped_gemm, + num_moe_experts=4, + ) + + # Create input and forward function + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Run forward pass and quantize + forward_fn(model) + config = mtq.FP8_DEFAULT_CFG + model = mtq.quantize(model, config, forward_fn) + + # Check initial sync status + initial_sync = compare_amax_sync_across_expert_parallel(model) + assert initial_sync, ( + "Inconsistent amax across expert parallel ranks, Amax should be synchronized across expert parallel ranks" + ) + + # Create inconsistent amax values + rank = torch.distributed.get_rank() + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.TensorQuantizer): + # Check if this is an expert quantizer + is_expert_quantizer = ( + "local_experts" in name # Non-grouped MoE + or ("experts" in name and "linear_fc" in name) # Grouped MoE + ) + + if is_expert_quantizer and hasattr(module, "_amax"): + # Create rank-specific amax values to simulate missing sync + rank_offset = rank * 0.1 + module.amax = module.amax + rank_offset + + # Determine expert parallel type + expert_parallel_type = ( + "both" if ep_size > 1 and etp_size > 1 else ("model" if ep_size > 1 else "tensor") + ) + + # Disable parallel groups and test inconsistency + module_parallel_groups = disable_distributed_parallel_sync(model, expert_parallel_type) + mtq.model_calib.max_calibrate(model, forward_fn) + + inconsistent_sync = compare_amax_sync_across_expert_parallel(model) + assert not inconsistent_sync, ( + "Consistent amax across expert parallel ranks, " + "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" + ) + + # Re-enable parallel groups and test synchronization + enable_distributed_parallel_sync(model, module_parallel_groups, expert_parallel_type) + mtq.model_calib.max_calibrate(model, forward_fn) + + final_sync = compare_amax_sync_across_expert_parallel(model) + assert final_sync, ( + "Inconsistent amax across expert parallel ranks, Amax should be synchronized across expert parallel ranks" + ) + + +def _test_expert_parallel_sync_helper(ep_size, etp_size, moe_grouped_gemm, rank, size): + """Test expert parallel synchronization with different configurations.""" + initialize_for_megatron( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + context_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=42 + rank, + ) + + # Run the actual test + _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm) + + +@pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)]) +@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) +def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): + """Test expert model parallel synchronization.""" + import time + + size = torch.cuda.device_count() + total_size = ep_size * etp_size + if size < total_size: + pytest.skip(f"Requires at least {total_size} GPUs for expert model parallel test") + + # Add small delay to avoid port conflicts + time.sleep(0.1) + + spawn_multiprocess_job( + size=total_size, + job=partial(_test_expert_parallel_sync_helper, ep_size, etp_size, moe_grouped_gemm), + 
backend="nccl", + ) From 22bfe0e9cb1c5baf8afaba1c2ea6e7faabc37bed Mon Sep 17 00:00:00 2001 From: Kinjal Patel Date: Tue, 7 Oct 2025 22:49:53 +0000 Subject: [PATCH 16/16] code cleanup Signed-off-by: Kinjal Patel --- .../torch/quantization/plugins/megatron.py | 72 ++----------------- .../torch_dist/plugins/megatron_common.py | 1 - .../quantization/plugins/test_megatron.py | 45 +++++------- 3 files changed, 26 insertions(+), 92 deletions(-) diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index dfd4b42be..c414f99c8 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -37,7 +37,7 @@ ) from modelopt.torch.utils.distributed import ParallelState -from ..nn import QuantModuleRegistry, SequentialQuantizer, TensorQuantizer +from ..nn import QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear @@ -501,7 +501,6 @@ def _setup(self): self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), - mcore_parallel.get_context_parallel_group(), mcore_parallel.get_expert_model_parallel_group(), expert_tensor_parallel_group, ) @@ -544,70 +543,13 @@ def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): ] def modelopt_post_restore(self, prefix: str = ""): - """Post restore to correctly configure the TensorQuantizer states for MCore/distributed frameworks. - - ModelOpt restores the TensorQuantizer states such as `_amax` and `_pre_quant_scale` to their - shape before saving. However this is not enough for MCore/distributed frameworks since the tensor parallelism - could change between saving and restoring. If the tensor parallelism changes, the shape of the quantizer - states also changes. So we need to re-calculate the quantizer states. - """ - from modelopt.torch.quantization.model_calib import max_calibrate - - def _check_unsupported_states(quantizer: TensorQuantizer): - for k in quantizer.state_dict(): - if k not in ["_amax", "_pre_quant_scale"]: - warnings.warn( - f"Restore of {k} for {prefix} is not supported. The restore of this layer might be " - f"incorrect. Please implement a custom restore for {k}." - ) - - def _has_state(quantizer, name): - # Handling for SequentialQuantizer - quantizer = quantizer[0] if isinstance(quantizer, SequentialQuantizer) else quantizer - return hasattr(quantizer, name) - - # weights for TEGroupedLinear are stored in weight0, weight1, etc. 
- if self.weight0 is None: - return - for quantizer in [self.weight_quantizer, self.input_quantizer, self.output_quantizer]: - _check_unsupported_states( - quantizer if isinstance(quantizer, TensorQuantizer) else quantizer[0] - ) - if _has_state(self.weight_quantizer, "_amax"): - self.weight_quantizer.reset_amax() - for i in range(self.num_gemms): - weight = getattr(self, f"weight{i}") - assert weight is not None, "weight is None" - - max_calibrate(self.weight_quantizer, lambda wq: wq(weight), distributed_sync=False) - if _has_state(self.input_quantizer, "_pre_quant_scale"): - if hasattr(self.input_quantizer, "_pre_quant_scale"): - delattr(self.input_quantizer, "_pre_quant_scale") - pqs = torch.zeros( - (weight.shape[1]), device=weight.device, dtype=self.original_weight_dtype - ) - self.input_quantizer.register_buffer("_pre_quant_scale", pqs) - - if _has_state(self.input_quantizer, "_amax"): - self.input_quantizer.reset_amax() - dummy_input = torch.ones( - (1, 1, self.weight0.shape[1]), - device=self.weight0.device, - dtype=self.original_weight_dtype, - ) - max_calibrate(self.input_quantizer, lambda iq: iq(dummy_input), distributed_sync=False) - if _has_state(self.output_quantizer, "_amax"): - self.output_quantizer.reset_amax() - dummy_input = torch.ones( - (1, 1, self.weight0.shape[0]), - device=self.weight0.device, - dtype=self.original_weight_dtype, - ) - max_calibrate(self.output_quantizer, lambda oq: oq(dummy_input), distributed_sync=False) - # If there are any other states, lets move them to the correct device - - self.weight = None + # GroupedMLP stores the weights as weight0, weight1, etc. To run post_restore in order to + # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning + # self.weight0 to self.weight to run the quantizer states initialization. + self.weight = self.weight0 super().modelopt_post_restore(prefix=prefix) + # Revert the weight to None after post_restore to avoid the weight being None during forward pass. 
+ self.weight = None def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 68268805e..5facf89a7 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -429,7 +429,6 @@ def initialize_for_megatron( context_parallel_size=context_parallel_size, expert_tensor_parallel_size=expert_tensor_parallel_size, expert_model_parallel_size=expert_model_parallel_size, - order="tp-ep-dp-pp", ) model_parallel_cuda_manual_seed(seed) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 0993e2565..309755dbd 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -549,7 +549,7 @@ def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config): ) -def _test_grouped_vs_non_grouped_amax_helper(tp_size, ep_size, etp_size, rank, size): +def _test_grouped_vs_non_grouped_quantize_helper(tp_size, ep_size, etp_size, rank, size): """Test that grouped and non-grouped MoE models produce similar amax values.""" initialize_for_megatron( tensor_model_parallel_size=tp_size, @@ -615,8 +615,8 @@ def forward_fn(model): assert output_comparison_after, "Outputs are not close after quantization" -def test_grouped_vs_non_grouped_amax(): - """Test that grouped and non-grouped MoE models produce similar amax values.""" +def test_grouped_vs_non_grouped_quantize(): + """Test that grouped and non-grouped MoE models produce similar quantized models.""" import time size = torch.cuda.device_count() @@ -627,14 +627,22 @@ def test_grouped_vs_non_grouped_amax(): time.sleep(0.1) spawn_multiprocess_job( - size=size, job=partial(_test_grouped_vs_non_grouped_amax_helper, 1, 2, 2), backend="nccl" + size=size, + job=partial(_test_grouped_vs_non_grouped_quantize_helper, 1, 2, 2), + backend="nccl", ) -def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm): - """ - Test that demonstrates the requirement for expert parallel sync in model_calib.py - """ +def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, rank, size): + """Test expert parallel synchronization with different configurations.""" + initialize_for_megatron( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=SEED, + ) + # Create model with expert parallelism model = _gpt_model_provider( tp_size=1, @@ -664,7 +672,7 @@ def forward_fn(model): ) # Create inconsistent amax values - rank = torch.distributed.get_rank() + cur_rank = torch.distributed.get_rank() for name, module in model.named_modules(): if isinstance(module, mtq.nn.TensorQuantizer): # Check if this is an expert quantizer @@ -675,7 +683,7 @@ def forward_fn(model): if is_expert_quantizer and hasattr(module, "_amax"): # Create rank-specific amax values to simulate missing sync - rank_offset = rank * 0.1 + rank_offset = cur_rank * 0.1 module.amax = module.amax + rank_offset # Determine expert parallel type @@ -703,21 +711,6 @@ def forward_fn(model): ) -def _test_expert_parallel_sync_helper(ep_size, etp_size, moe_grouped_gemm, rank, size): - """Test expert parallel synchronization with different configurations.""" - 
initialize_for_megatron( - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, - context_parallel_size=1, - expert_model_parallel_size=ep_size, - expert_tensor_parallel_size=etp_size, - seed=42 + rank, - ) - - # Run the actual test - _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm) - - @pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)]) @pytest.mark.parametrize("moe_grouped_gemm", [True, False]) def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): @@ -734,6 +727,6 @@ def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): spawn_multiprocess_job( size=total_size, - job=partial(_test_expert_parallel_sync_helper, ep_size, etp_size, moe_grouped_gemm), + job=partial(_test_expert_model_parallel_amax_sync, ep_size, etp_size, moe_grouped_gemm), backend="nccl", )
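
Note on the reduction semantics exercised by these patches and tests: each quantizer's `_amax` is max-reduced over the data-parallel and expert-parallel groups via `sync_amax_across_distributed_group`, while AWQ's `act_scale` (a running mean of activations) is averaged, matching the `dist.ReduceOp.MAX` / `dist.ReduceOp.AVG` choices in the test helpers. The sketch below only illustrates those collectives; it is not ModelOpt's implementation, and the helper name and signature are invented for illustration.

```python
# Illustrative sketch (assumed helper, not ModelOpt code): amax is max-reduced and
# act_scale is mean-reduced over every process group that is actually initialized.
import torch
import torch.distributed as dist


def sync_calibration_stats(
    amax: torch.Tensor,
    act_scale: torch.Tensor | None,
    groups: list[dist.ProcessGroup | None],
) -> None:
    """All-reduce calibration statistics in place over each initialized group."""
    for group in groups:
        if group is None:
            # Parallelism not in use (e.g. no expert-tensor-parallel group); skip it.
            continue
        # MAX keeps the largest observed range so every rank sharing the same
        # weights ends up with an identical quantization scale.
        dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
        if act_scale is not None:
            # act_scale is an average statistic, so AVG is the natural reduction.
            dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=group)
```

Ranks that do not participate in a given group simply pass `None` for it, which mirrors how the patches guard the expert-tensor-parallel sync with an `is not None` / `is_initialized()` check.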