From f17131f82d2e4cd41fcda0c07129b7c67f605eb4 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 24 Sep 2025 00:27:11 +0000 Subject: [PATCH 01/14] sync amax in context parallel and awq act scale Signed-off-by: Jennifer Chen --- examples/nemo_run/qat/README.md | 2 +- modelopt/torch/quantization/model_calib.py | 19 +++- .../torch/quantization/plugins/megatron.py | 9 +- modelopt/torch/utils/distributed.py | 4 +- .../torch_dist/plugins/megatron_common.py | 7 +- .../torch_quantization/quantize_common.py | 57 +++++++++++- tests/gpu/torch/conftest.py | 6 ++ .../quantization/plugins/test_megatron.py | 89 ++++++++++++++++++- 8 files changed, 178 insertions(+), 15 deletions(-) diff --git a/examples/nemo_run/qat/README.md b/examples/nemo_run/qat/README.md index 79715953c..cd74c96e2 100644 --- a/examples/nemo_run/qat/README.md +++ b/examples/nemo_run/qat/README.md @@ -92,7 +92,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p To perform QAD training, run: ```bash -python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment +python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4 ``` ## Supported models diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5276b1334..6d1b7c86f 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -79,21 +79,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp(quantizer, parallel_state): + def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel and context parallel groups.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp(_q, parallel_state) + sync_quantizer_amax_across_dp_cp(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp(child, module.parallel_state) - + sync_quantizer_amax_across_dp_cp(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -624,6 +625,14 @@ def forward(self, input, *args, **kwargs): # This will also perform distributed amax sync for input_quantizers max_calibrate(model, lambda model: None) + def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group): + # Sync across Data Parallel (DP) + if data_parallel_group.is_initialized(): + dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group) + # Sync across Context Parallel (CP) + if context_parallel_group.is_initialized(): + dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group) + for name, module in model.named_modules(): if ( is_quantized_linear(module) @@ -631,6 +640,8 @@ def forward(self, input, *args, **kwargs): and module.awq_lite.num_cache_steps > 0 ): module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps + sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group) + # Hack: MoEs forward all tokens through all experts if _if_calib is True module._if_calib = True diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index ab64a795a..72442526f 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -23,6 +23,7 @@ import megatron.core.transformer.mlp as megatron_mlp import torch from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region +from megatron.core.parallel_state import get_data_parallel_group from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -217,9 +218,15 @@ class _MegatronParallelLinear(_ParallelLinear): ] def _setup(self): + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except: + data_parallel_group = get_data_parallel_group() self.parallel_state = ParallelState( - getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(), + data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_context_parallel_group(), ) super()._setup() diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 7aebc992d..c1a313e48 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,13 +241,15 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + context_parallel_group: torch.distributed.ProcessGroup | int | None = -1, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) + self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: - return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}" + return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}, context_parallel_group: {self.context_parallel_group}" def get_group(ranks: list[int]): diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 9c1dd1bf7..ab9d467ea 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -83,9 +83,10 @@ class MegatronModel(MegatronModule): - def __init__(self, tp_size: int = 1, use_te_norm: bool = False): + def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False): config = TransformerConfig( tensor_model_parallel_size=tp_size, + context_parallel_size=cp_size, pipeline_model_parallel_size=1, normalization="LayerNorm", # Unused parameters below are set to avoid ZeroDivisionError in __post_init__ @@ -383,13 +384,13 @@ def run_mcore_inference_with_dummy_input( def initialize_for_megatron( - tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234 + tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, seed=1234 ): """Initialize Megatron model parallelism. NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`. """ - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) + initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size) model_parallel_cuda_manual_seed(seed) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 505eac2b6..65cc39a5e 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -116,8 +116,8 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N mto.restore_from_modelopt_state(model_ref, state_dict) -def tensor_parallel_test_helper(model, config, tp_group, dp_group): - # The input to fist layer, the column parallel should be the same across all tp ranks +def tensor_parallel_test_helper(model, config, tp_group): + # The input to first layer, the column parallel should be the same across all tp ranks calib_data = model.get_dummy_input().cuda() dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -149,6 +149,59 @@ def forward_loop(model): dist.destroy_process_group() +def data_parallel_test_helper(model, config, dp_group): + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + +def context_parallel_test_helper(model, config, cp_group): + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + +def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): + calib_data = model.get_dummy_input().cuda() + # data should be same across each TP rank + dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) def auto_quantize_helper(model): model, search_state = mtq.auto_quantize( diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index 208fb2287..05de03012 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -33,6 +33,12 @@ def need_2_gpus(): if torch.cuda.device_count() < 2: pytest.skip("Need at least 2 GPUs to run this test") +@pytest.fixture +def need_8_gpus(): + if torch.cuda.device_count() < 8: + pytest.skip("Need at least 8 GPUs to run this test") + + @pytest.fixture(scope="module") def set_torch_dtype(request): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index c3630e028..05ba40f7c 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -32,6 +32,9 @@ from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, tensor_parallel_test_helper, + data_parallel_test_helper, + context_parallel_test_helper, + data_tensor_context_parallel_test_helper, ) from packaging.version import Version @@ -41,6 +44,7 @@ from megatron.core.parallel_state import ( destroy_model_parallel, get_data_parallel_group, + get_context_parallel_group, get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -91,13 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1): # Clean up since this is not a spawned process destroy_model_parallel() - +# 1. Tensor Parallel Test def _test_tensor_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED) - model = MegatronModel(size).cuda() + model = MegatronModel(tp_size=size).cuda() tensor_parallel_test_helper( - model, config, get_tensor_model_parallel_group(), get_data_parallel_group() + model, config, get_tensor_model_parallel_group() ) @@ -118,6 +122,85 @@ def test_tensor_parallel(need_2_gpus, config): size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl" ) +# 2. Data Parallel Test +def _test_data_parallel_helper(config, rank, size): + # TODO does this model automatically get copied to both DP ranks? + initialize_for_megatron(seed=SEED) + model = MegatronModel().cuda() + + data_parallel_test_helper( + model, config, get_data_parallel_group() + ) + + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_parallel(need_2_gpus, config): + spawn_multiprocess_job( + size=2, job=partial(_test_data_parallel_helper, config), backend="nccl" + ) + +# 3. Context Parallel Test +def _test_context_parallel_helper(config, rank, size): + initialize_for_megatron(context_parallel_size=size, seed=SEED) + model = MegatronModel(cp_size=size).cuda() + + context_parallel_test_helper( + model, config, get_context_parallel_group() + ) + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_context_parallel(need_2_gpus, config): + spawn_multiprocess_job( + size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" + ) + +# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) +def _test_data_tensor_context_parallel_helper(config, rank, size): + initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) + model = MegatronModel(tp_size=2, cp_size=2).cuda() + + data_tensor_context_parallel_test_helper( + model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group() + ) + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_tensor_context_parallel(need_8_gpus, config): + spawn_multiprocess_job( + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + ) def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" From 42519ccd809c50a73b85bd59d87b5859a885c3fe Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 25 Sep 2025 18:54:28 +0000 Subject: [PATCH 02/14] lint Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 16 ++++++-- .../torch/quantization/plugins/megatron.py | 4 +- modelopt/torch/utils/distributed.py | 6 ++- .../torch_dist/plugins/megatron_common.py | 6 ++- .../torch_quantization/quantize_common.py | 4 ++ tests/gpu/torch/conftest.py | 2 +- .../quantization/plugins/test_megatron.py | 37 ++++++++++--------- 7 files changed, 49 insertions(+), 26 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 6d1b7c86f..03f4936ae 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -628,10 +628,14 @@ def forward(self, input, *args, **kwargs): def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group): # Sync across Data Parallel (DP) if data_parallel_group.is_initialized(): - dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group) + dist.all_reduce( + module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group + ) # Sync across Context Parallel (CP) if context_parallel_group.is_initialized(): - dist.all_reduce(module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group) + dist.all_reduce( + module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group + ) for name, module in model.named_modules(): if ( @@ -640,8 +644,12 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr and module.awq_lite.num_cache_steps > 0 ): module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps - sync_act_scale_across_dp_cp(module, module.parallel_state.data_parallel_group, module.parallel_state.context_parallel_group) - + sync_act_scale_across_dp_cp( + module, + module.parallel_state.data_parallel_group, + module.parallel_state.context_parallel_group, + ) + # Hack: MoEs forward all tokens through all experts if _if_calib is True module._if_calib = True diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 72442526f..96e8c61e1 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,8 +22,8 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch -from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.parallel_state import get_data_parallel_group +from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none @@ -221,7 +221,7 @@ def _setup(self): data_parallel_group = None try: data_parallel_group = get_data_parallel_group(with_context_parallel=True) - except: + except AssertionError: data_parallel_group = get_data_parallel_group() self.parallel_state = ParallelState( data_parallel_group, diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index c1a313e48..18c2f40c2 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -249,7 +249,11 @@ def __init__( self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: - return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}, context_parallel_group: {self.context_parallel_group}" + return ( + f"data_parallel_group: {self.data_parallel_group}, " + f"tensor_parallel_group: {self.tensor_parallel_group}, " + f"context_parallel_group: {self.context_parallel_group}" + ) def get_group(ranks: list[int]): diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index ab9d467ea..7c497a895 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -390,7 +390,11 @@ def initialize_for_megatron( NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`. """ - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size) + initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + ) model_parallel_cuda_manual_seed(seed) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 65cc39a5e..e8c40af75 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -149,6 +149,7 @@ def forward_loop(model): dist.destroy_process_group() + def data_parallel_test_helper(model, config, dp_group): calib_data = model.get_dummy_input().cuda() @@ -165,6 +166,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def context_parallel_test_helper(model, config, cp_group): calib_data = model.get_dummy_input().cuda() @@ -181,6 +183,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank @@ -203,6 +206,7 @@ def forward_loop(model): dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + def auto_quantize_helper(model): model, search_state = mtq.auto_quantize( model, diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index 05de03012..f32065bce 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -33,13 +33,13 @@ def need_2_gpus(): if torch.cuda.device_count() < 2: pytest.skip("Need at least 2 GPUs to run this test") + @pytest.fixture def need_8_gpus(): if torch.cuda.device_count() < 8: pytest.skip("Need at least 8 GPUs to run this test") - @pytest.fixture(scope="module") def set_torch_dtype(request): orig_dtype = torch.get_default_dtype() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 05ba40f7c..486756e2a 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -31,10 +31,10 @@ from _test_utils.torch_quantization.quant_utils import get_model_size from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, - tensor_parallel_test_helper, - data_parallel_test_helper, context_parallel_test_helper, + data_parallel_test_helper, data_tensor_context_parallel_test_helper, + tensor_parallel_test_helper, ) from packaging.version import Version @@ -43,8 +43,8 @@ import megatron.core from megatron.core.parallel_state import ( destroy_model_parallel, - get_data_parallel_group, get_context_parallel_group, + get_data_parallel_group, get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -95,14 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1): # Clean up since this is not a spawned process destroy_model_parallel() + # 1. Tensor Parallel Test def _test_tensor_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED) model = MegatronModel(tp_size=size).cuda() - tensor_parallel_test_helper( - model, config, get_tensor_model_parallel_group() - ) + tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group()) @pytest.mark.parametrize( @@ -122,15 +121,14 @@ def test_tensor_parallel(need_2_gpus, config): size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl" ) + # 2. Data Parallel Test def _test_data_parallel_helper(config, rank, size): # TODO does this model automatically get copied to both DP ranks? initialize_for_megatron(seed=SEED) model = MegatronModel().cuda() - data_parallel_test_helper( - model, config, get_data_parallel_group() - ) + data_parallel_test_helper(model, config, get_data_parallel_group()) @pytest.mark.parametrize( @@ -146,18 +144,16 @@ def _test_data_parallel_helper(config, rank, size): ], ) def test_data_parallel(need_2_gpus, config): - spawn_multiprocess_job( - size=2, job=partial(_test_data_parallel_helper, config), backend="nccl" - ) + spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl") + # 3. Context Parallel Test def _test_context_parallel_helper(config, rank, size): initialize_for_megatron(context_parallel_size=size, seed=SEED) model = MegatronModel(cp_size=size).cuda() - context_parallel_test_helper( - model, config, get_context_parallel_group() - ) + context_parallel_test_helper(model, config, get_context_parallel_group()) + @pytest.mark.parametrize( "config", @@ -176,15 +172,21 @@ def test_context_parallel(need_2_gpus, config): size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" ) + # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) def _test_data_tensor_context_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) model = MegatronModel(tp_size=2, cp_size=2).cuda() data_tensor_context_parallel_test_helper( - model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group() + model, + config, + get_data_parallel_group(), + get_tensor_model_parallel_group(), + get_context_parallel_group(), ) + @pytest.mark.parametrize( "config", [ @@ -199,9 +201,10 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) + def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): """Build the model.""" From 264adbb1b13221bfaec2c59b161e6061bafcafd9 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 25 Sep 2025 21:46:30 +0000 Subject: [PATCH 03/14] test weight quantizer too Signed-off-by: Jennifer Chen --- examples/speculative_decoding/README.md | 195 ++++++++++-------- examples/speculative_decoding/launch_train.sh | 2 +- modelopt/torch/distill/plugins/megatron.py | 29 ++- modelopt/torch/quantization/export_onnx.py | 115 ++++++----- .../nn/modules/tensor_quantizer.py | 6 +- .../torch/quantization/plugins/diffusers.py | 23 ++- .../torch_quantization/quantize_common.py | 62 ++++-- .../speculative_decoding/test_eagle.py | 3 +- .../speculative_decoding/test_medusa.py | 3 +- 9 files changed, 266 insertions(+), 172 deletions(-) diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 503cf303f..aeea25adb 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -15,8 +15,11 @@ This example focuses on training with Hugging Face. To train with Megatron‑LM, | **Section** | **Description** | **Jump To** | | :------------: | :------------: | :------------: | | Pre-Requisites | Required & optional dependencies | \[[Link](#pre-requisites)\] | -| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] | -| Complete Workflow | Full example with configurable training pipeline | \[[Link](#complete-workflow)\] | +| Simplified Workflow | Train, evaluate, and export EAGLE model with one-line command | \[[Link](#getting-started-simplified-workflow)\] | +| Online Training | Train draft model alongside base model in GPU memory | \[[Link](#training-draft-model-with-online-base-model)\] | +| Offline Training | Train draft model using pre-computed hidden states | \[[Link](#training-draft-model-with-offline-base-model)\] | +| After Training | Evaluation, export and deployment | \[[Link](#model-validation)\] | +| Advanced Usage | Data synthesis, vocab compression, and configuration | \[[Link](#advanced-usage)\] | | Support Matrix | Supported models for speculative decoding training | \[[Link](#support-matrix)\] | | Speculation Module Checkpoints | View pre-trained speculation modules ready to deploy! | \[[Link](#speculation-module-checkpoints)\] | | Resources | Extra links to relevant resources | \[[Link](#resources)\] | @@ -61,13 +64,113 @@ This one-line command runs a minimal example workflow of training and exporting - Evaluates the acceptance rate on [MT-Bench](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts) - Exports a checkpoint ready for deployment -## Complete Workflow +## Training Draft Model with Online Base Model -This section presents a more comprehensive example for customizing speculative decoding training with Modelopt, including optional steps to enhance training quality and efficiency. +For small base models that fit in GPU memory, we can collocate them with draft models and train with the following command: -### (Optional) Data Synthesis +```bash +./launch_train.sh --model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --data Daring-Anteater/train.jsonl \ + --num_gpu $NUM_GPU \ + --num_epochs $NUM_EPOCH \ + --eagle_config eagle_config.json +``` + +This command will launch `main.py` with `accelerate`. See [section: interact with modelopt.torch.speculative](#interact-with-modelopttorchspeculative) for more details. +The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. + +## Training Draft Model with Offline Base Model + +For large models, you can export intermediate hidden states to disk and train only the draft model. This significantly reduces GPU memory requirements, but requires several to tens of terabytes of storage depending on dataset size. + +First, dump the base model's hidden states with the following command: + +```bash +python collect_hidden_states/compute_hidden_states_hf.py \ + --model $BASE_MODEL \ + --input-file Daring-Anteater/train.jsonl \ + --output-dir $HIDDEN_STATES_DIR +``` + +See [`run_hf_compute_hiddens_dp.sh`](./collect_hidden_states/run_hf_compute_hiddens_dp.sh) for a simple example using data parallelism (DP) to accelerate hidden state generation. + +Then, train draft model with `--offline-data` argument: + +```bash +./launch_train.sh --model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --data $DATA \ + --num_gpu $NUM_GPU \ + --num_epochs $NUM_EPOCH \ + --eagle_config eagle_config.json \ + --offline-data $HIDDEN_STATES_DIR +``` + +## Model Validation + +After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by: + +```bash +python ar_validate.py --model_path $OUTPUT_DIR +``` + +Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below. + +## Export + +```bash +python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH +``` + +This exports the model from a ModelOpt checkpoint to a deployment-compatible format. + +## Deployment + +The exported checkpoint can be deployed on TRT-LLM or SGLang. + +### TRT-LLM + +To serve the checkpoint with TRT-LLM, run trtllm-serve with: + +```bash +trtllm-serve --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml +``` + +, with `extra-llm-api-config.yml` being + +```yaml +enable_attention_dp: false +disable_overlap_scheduler: true +enable_autotuner: false + +cuda_graph_config: + max_batch_size: 1 + +speculative_config: + decoding_type: Eagle + max_draft_len: 3 + speculative_model_dir: + +kv_cache_config: + enable_block_reuse: false +``` + +Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage. + +### SGLang -To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data, ensuring that the draft model’s output distribution closely aligns with that of the base model. +Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage. + +### Deploying Quantized model + +See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md). + +## Advanced Usage + +### Data Synthesis + +To achieve higher acceptance rates during speculative decoding, it is beneficial to use conversations generated by the base model as training data. This ensures that the draft model's output distribution closely aligns with that of the base model. To prepare such data, we launch an inference server with the base model: @@ -78,7 +181,7 @@ vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000 Note: Add `--quantization=modelopt` flag for quantized models. -Then, we generate conversations with base model and prompts from Daring-Anteater: +Then, we generate conversations with the base model using prompts from Daring-Anteater: ```bash python server_generate.py --data_path Daring-Anteater/train.jsonl --output_path synthetic/train.jsonl @@ -88,7 +191,7 @@ To add a system prompt, use the `--system_prompt ` argument. For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support. -### (Optional) Draft Vocabulary Compression +### Draft Vocabulary Compression We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set: @@ -98,7 +201,7 @@ python calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct --data This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. -### (Optional) Configuring Draft Model +### Configuring Draft Model For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`: @@ -108,7 +211,7 @@ For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](htt } ``` -### Training Draft Model with Modelopt +### Interact with `modelopt.torch.speculative` `main.py` provides an example for converting a HF base model for speculative decoding and training it. It consists of a few simple steps: First, load the base model and tokenizer from Hugging Face: @@ -162,78 +265,6 @@ trainer.save_state() trainer.save_model("") ``` -We omitted details like tokenizer initialization for simplicity. A complete training example is provided in `main.py`, along with a bash script to launch training with Hugging Face Accelerate in `launch_train.sh`, which can be run by: - -```bash -./launch_train.sh --model $BASE_MODEL \ - --output_dir $OUTPUT_DIR \ - --data $DATA \ - --num_gpu $NUM_GPU \ - --num_epochs 10 \ - --eagle_config eagle_config.json #This is where we optionally overwrite default eagle configs -``` - -The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. - -### Model Validation - -After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by: - -```bash -python ar_validate.py --model_path $OUTPUT_DIR -``` - -Alternatively, we can export the checkpoint and run evaluation on serving frameworks. See sections below. - -### Export - -```bash -python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH -``` - -This exports the model from a ModelOpt checkpoint to a deployment‑compatible format. - -### Deployment - -The exported checkpoint can be deployed on TRT-LLM or SGLang. - -#### TRT-LLM - -To serve the checkpoint with trtllm, run trtllm-serve with: - -```bash -trtllm-serve --host 0.0.0.0 --port 8000 --backend pytorch --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 --extra_llm_api_options extra-llm-api-config.yml -``` - -, with `extra-llm-api-config.yml` being - -```yaml -enable_attention_dp: false -disable_overlap_scheduler: true -enable_autotuner: false - -cuda_graph_config: - max_batch_size: 1 - -speculative_config: - decoding_type: Eagle - max_draft_len: 3 - speculative_model_dir: - -kv_cache_config: - enable_block_reuse: false -``` - -Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage. - -#### SGLang - -Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage. - -#### Deploying Quantized model - -See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md). - ## Support Matrix | Model | Medusa | EAGLE1/2 | EAGLE3 | diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index 3ecd4238a..2d0a4abe7 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -129,7 +129,7 @@ if [[ "$OFFLINE_DATA_PATH" != "" ]]; then echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory." exit 1 else - OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH" + OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH --ar_validate_steps -1" fi else OFFLINE_TRAINING_ARGS="" diff --git a/modelopt/torch/distill/plugins/megatron.py b/modelopt/torch/distill/plugins/megatron.py index 7078cca36..c1fa45f6b 100644 --- a/modelopt/torch/distill/plugins/megatron.py +++ b/modelopt/torch/distill/plugins/megatron.py @@ -59,7 +59,7 @@ class DistillationConfig: logit_kl_temperature: Temperature for the logit KL-divergence loss. """ - intermediate_layer_pairs: list[tuple[str, str]] = field(default_factory=list) + intermediate_layer_pairs: list[tuple[str, ...]] = field(default_factory=list) logit_layers: tuple[str, str] = ("output_layer", "output_layer") skip_lm_loss: bool = True kd_loss_scale: float = 1.0 @@ -69,12 +69,28 @@ class DistillationConfig: def __post_init__(self): assert len(self.logit_layers) == 2, f"{self.logit_layers=}" - assert all(len(pair) == 2 for pair in self.intermediate_layer_pairs), ( + assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), ( f"{self.intermediate_layer_pairs=}" ) assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}" assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}" + @staticmethod + def parse_intermediate_entry(entry: tuple[str, ...]) -> tuple[str, str, Callable]: + """Parse an intermediate entry into a student layer, teacher layer, and loss function.""" + if len(entry) == 3: + student_layer, teacher_layer, loss_fn_name = entry + if loss_fn_name == "cosine": + loss_fn = HiddenStateCosineLoss + elif loss_fn_name == "mse": + loss_fn = MSELoss + else: + raise ValueError(f"Unknown intermediate loss function: {loss_fn_name}") + else: + student_layer, teacher_layer = entry + loss_fn = HiddenStateCosineLoss # default to cosine loss + return student_layer, teacher_layer, loss_fn + def load_distillation_config( config_path: str | None, student_cfg: "TransformerConfig", teacher_cfg: "TransformerConfig" @@ -105,7 +121,8 @@ def load_distillation_config( # NOTE: Projection layer shared among intermediate layer pairs. projection_layer = ProjectionLayer(student_cfg, teacher_cfg) - for student_layer, teacher_layer in cfg.intermediate_layer_pairs: + for entry in cfg.intermediate_layer_pairs: + student_layer, teacher_layer, loss_fn = cfg.parse_intermediate_entry(entry) if parallel_state.get_tensor_and_context_parallel_rank() == 0: logger.info( "Distillation: Adding intermediate loss between" @@ -114,7 +131,7 @@ def load_distillation_config( ) student_layer = _adjust_layer_index_for_pp(student_layer, student_cfg) teacher_layer = _adjust_layer_index_for_pp(teacher_layer, teacher_cfg) - criterion[(student_layer, teacher_layer)] = HiddenStateCosineLoss( + criterion[(student_layer, teacher_layer)] = loss_fn( student_cfg, projection_layer=projection_layer ) @@ -202,9 +219,9 @@ def forward(self, predictions: Tensor, targets: Tensor) -> Tensor: predictions, targets = self.pre_forward(predictions, targets) loss = F.mse_loss(predictions, targets, reduction="none") - loss = loss.sum(dim=-1) + loss = loss.mean(dim=-1) - return self.post_forward(loss) + return self.post_forward(loss, is_sequence_parallel=self._config.sequence_parallel) class HiddenStateCosineLoss(BaseLoss): diff --git a/modelopt/torch/quantization/export_onnx.py b/modelopt/torch/quantization/export_onnx.py index fe9bd927b..e8a38b162 100644 --- a/modelopt/torch/quantization/export_onnx.py +++ b/modelopt/torch/quantization/export_onnx.py @@ -103,13 +103,18 @@ """Utility to export a quantized torch model to quantized ONNX.""" import contextlib +from typing import TYPE_CHECKING import onnx import torch from torch.onnx import symbolic_helper from torch.onnx import symbolic_helper as sym_help -from torch.onnx._internal import jit_utils -from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask + +if TYPE_CHECKING: + if hasattr(torch.onnx._internal, "jit_utils"): + from torch.onnx._internal.jit_utils import GraphContext + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext onnx_dtype_map = { "BFloat16": onnx.TensorProto.BFLOAT16, @@ -125,7 +130,7 @@ def export_int8( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, amax: torch.Tensor, num_bits: int, @@ -184,7 +189,7 @@ def export_int8( def export_int4( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, amax: torch.Tensor, num_bits: int, @@ -208,7 +213,7 @@ def export_int4( def _fp8_quantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale_inv: float, trt_high_precision_dtype: str, @@ -236,7 +241,7 @@ def _fp8_quantize( def _fp8_dequantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale_inv: float, trt_high_precision_dtype: str, @@ -263,7 +268,7 @@ def _fp8_dequantize( def export_fp8( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, amax: float, trt_high_precision_dtype: str | None, @@ -279,21 +284,29 @@ def export_fp8( def scaled_dot_product_attention( - g: jit_utils.GraphContext, - query: torch._C.Value, - key: torch._C.Value, - value: torch._C.Value, - attn_mask: torch._C.Value | None = None, + g: "GraphContext", + query: "torch._C.Value", + key: "torch._C.Value", + value: "torch._C.Value", + attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, enable_gqa: bool = False, ): """Perform scaled dot product attention.""" if hasattr(torch.onnx, "_type_utils"): - from torch.onnx import _type_utils - else: - from torch.onnx._internal.torchscript_exporter import _type_utils + from torch.onnx._type_utils import JitScalarType + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter import JitScalarType + + if hasattr(torch.onnx, "symbolic_opset14"): + from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import ( + _attention_scale, + _causal_attention_mask, + ) assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( "is_causal and attn_mask cannot be set at the same time" @@ -327,22 +340,20 @@ def scaled_dot_product_attention( if symbolic_helper._is_none(attn_mask): mul_qk_add = mul_qk - elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL: + elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL: # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) const_zero = g.op("Constant", value_t=torch.tensor([0.0])) const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) mul_qk_add = g.op("Add", mul_qk, attn_mask) - elif _type_utils.JitScalarType.from_value(attn_mask) in ( - _type_utils.JitScalarType.FLOAT, - _type_utils.JitScalarType.HALF, - _type_utils.JitScalarType.BFLOAT16, + elif JitScalarType.from_value(attn_mask) in ( + JitScalarType.FLOAT, + JitScalarType.HALF, + JitScalarType.BFLOAT16, ): mul_qk_add = g.op("Add", mul_qk, attn_mask) else: - raise ValueError( - f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" - ) + raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}") attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) @@ -357,14 +368,14 @@ def scaled_dot_product_attention( def export_fp8_mha( - g: torch.onnx._internal.jit_utils.GraphContext, - query: torch._C.Value, - key: torch._C.Value, - value: torch._C.Value, - attn_mask: torch._C.Value | None = None, + g: "GraphContext", + query: "torch._C.Value", + key: "torch._C.Value", + value: "torch._C.Value", + attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, q_quantized_scale: float = 1.0, k_quantized_scale: float = 1.0, v_quantized_scale: float = 1.0, @@ -396,12 +407,18 @@ def export_fp8_mha( | Cast """ - from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask - if hasattr(torch.onnx, "_type_utils"): - from torch.onnx import _type_utils - else: - from torch.onnx._internal.torchscript_exporter import _type_utils + from torch.onnx._type_utils import JitScalarType + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter import JitScalarType + + if hasattr(torch.onnx, "symbolic_opset14"): + from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import ( + _attention_scale, + _causal_attention_mask, + ) # Pass all arguments, including x, to the custom ONNX operator assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), ( @@ -452,22 +469,20 @@ def export_fp8_mha( if sym_help._is_none(attn_mask): mul_qk_add = mul_qk - elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL: + elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL: # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) const_zero = g.op("Constant", value_t=torch.tensor([0.0])) const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) mul_qk_add = g.op("Add", mul_qk, attn_mask) - elif _type_utils.JitScalarType.from_value(attn_mask) in ( - _type_utils.JitScalarType.FLOAT, - _type_utils.JitScalarType.HALF, - _type_utils.JitScalarType.BFLOAT16, + elif JitScalarType.from_value(attn_mask) in ( + JitScalarType.FLOAT, + JitScalarType.HALF, + JitScalarType.BFLOAT16, ): mul_qk_add = g.op("Add", mul_qk, attn_mask) else: - raise ValueError( - f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}" - ) + raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}") attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) @@ -495,7 +510,7 @@ def export_fp8_mha( def _fp4_dynamic_quantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale: float, trt_high_precision_dtype: str | None, @@ -531,7 +546,7 @@ def _fp4_dynamic_quantize( def _fp4_dequantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale: float | torch.Value, trt_high_precision_dtype: str | None, @@ -546,7 +561,7 @@ def _fp4_dequantize( def _fp4_dequantize_2( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, dyn_scale: torch.Value, block_size: int, @@ -557,7 +572,7 @@ def _fp4_dequantize_2( def _mxfp8_dynamic_quantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, block_size: int, axis: int = -1, @@ -575,7 +590,7 @@ def _mxfp8_dynamic_quantize( def _mxfp8_dequantize( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, scale: torch.Value, block_size: int, @@ -593,7 +608,7 @@ def _mxfp8_dequantize( def export_mxfp8( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Tensor, onnx_quantizer_type: str, block_size: int, @@ -611,7 +626,7 @@ def export_mxfp8( def export_fp4( - g: torch.onnx._internal.jit_utils.GraphContext, + g: "GraphContext", inputs: torch.Value, block_size: int, amax: torch.Value, diff --git a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py index 0635b7c9b..6e431dce9 100644 --- a/modelopt/torch/quantization/nn/modules/tensor_quantizer.py +++ b/modelopt/torch/quantization/nn/modules/tensor_quantizer.py @@ -548,7 +548,7 @@ def _get_amax(self, inputs): def _validate_amax(self, amax): # Dynamic control flow is not supported by torch dynamo - if not is_torch_export_mode() and not torch._dynamo.is_compiling(): + if not is_torch_export_mode() and not torch.compiler.is_compiling(): assert torch.all(amax >= 0) and not torch.any(torch.isinf(amax)), ( f"Got invalid amax: {amax}" ) @@ -880,7 +880,7 @@ def forward(self, inputs): """ if hasattr(torch.onnx, "_globals"): from torch.onnx._globals import GLOBALS - else: + else: # torch >= 2.9 from torch.onnx._internal.torchscript_exporter._globals import GLOBALS if DTensor is not None and isinstance(inputs, DTensor): @@ -914,7 +914,7 @@ def forward(self, inputs): if ( not is_torch_export_mode() - and not torch._dynamo.is_compiling() + and not torch.compiler.is_compiling() and GLOBALS.in_onnx_export ): # GLOBALS could break TorchDynamo for some Pytorch versions (i.e., 2.3.0) diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py index 5f1ab5db1..7c018e1bb 100644 --- a/modelopt/torch/quantization/plugins/diffusers.py +++ b/modelopt/torch/quantization/plugins/diffusers.py @@ -15,10 +15,10 @@ """Support quantization of diffusers layers.""" -import functools from collections.abc import Callable, Iterator from functools import partial from types import ModuleType +from typing import TYPE_CHECKING import onnx import torch @@ -27,7 +27,12 @@ from torch.autograd import Function from torch.nn import functional as F from torch.onnx import symbolic_helper -from torch.onnx._internal import jit_utils, registration + +if TYPE_CHECKING: + if hasattr(torch.onnx._internal, "jit_utils"): + from torch.onnx._internal.jit_utils import GraphContext + else: # torch >= 2.9 + from torch.onnx._internal.torchscript_exporter.jit_utils import GraphContext from ..export_onnx import export_fp8_mha from ..nn import ( @@ -40,8 +45,6 @@ ) from .custom import _QuantFunctionalMixin -_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=18) - onnx_dtype_map = { "BFloat16": onnx.TensorProto.BFLOAT16, "Float": onnx.TensorProto.FLOAT, @@ -205,14 +208,14 @@ def forward( @staticmethod @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b") def symbolic( - g: jit_utils.GraphContext, - query: torch._C.Value, - key: torch._C.Value, - value: torch._C.Value, - attn_mask: torch._C.Value | None = None, + g: "GraphContext", + query: "torch._C.Value", + key: "torch._C.Value", + value: "torch._C.Value", + attn_mask: "torch._C.Value | None" = None, dropout_p: float = 0.0, is_causal: bool = False, - scale: torch._C.Value | None = None, + scale: "torch._C.Value | None" = None, q_quantized_scale: float = 1.0, k_quantized_scale: float = 1.0, v_quantized_scale: float = 1.0, diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index e8c40af75..ad01233cc 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -158,13 +158,22 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + + # Weight quantizer amax + fc1_amax = model.fc1.weight_quantizer.amax.clone() dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - - fc2_amax = model.fc2.input_quantizer.amax.clone() + assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) + fc2_amax = model.fc2.weight_quantizer.amax.clone() dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) def context_parallel_test_helper(model, config, cp_group): @@ -175,13 +184,22 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + # Weight quantizer amax + fc1_weight_amax = model.fc1.weight_quantizer.amax.clone() + dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax) + fc2_weight_amax = model.fc2.weight_quantizer.amax.clone() + dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) + assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -194,17 +212,29 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - fc1_amax = model.fc1.input_quantizer.amax.clone() + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + fc1_amax = model.fc1.input_quantizer.amax.clone() + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) + fc2_amax = model.fc2.input_quantizer.amax.clone() + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + + fc1_amax = model.fc1.weight_quantizer.amax.clone() dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - - fc2_amax = model.fc2.input_quantizer.amax.clone() + assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) + fc2_amax = model.fc2.weight_quantizer.amax.clone() dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) def auto_quantize_helper(model): diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index c65089114..c81bc9363 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -36,12 +36,11 @@ def test_llama_eagle3(tiny_llama_path, num_gpus, tiny_daring_anteater_path, tmp_ run_example_command( [ - "./launch.sh", + "./launch_train.sh", "--model", tiny_llama_path, "--data", tiny_daring_anteater_path, "--num_epochs", "1", "--lr", "1e-5", - "--do_eval", "False", "--num_gpu", str(num_gpus), "--mode", "eagle3", "--eagle_config", str(config_file), diff --git a/tests/examples/speculative_decoding/test_medusa.py b/tests/examples/speculative_decoding/test_medusa.py index c11a2e707..c8a7616b9 100644 --- a/tests/examples/speculative_decoding/test_medusa.py +++ b/tests/examples/speculative_decoding/test_medusa.py @@ -38,12 +38,11 @@ def test_llama_medusa_fp8_qat(tiny_llama_path, num_gpus, tiny_daring_anteater_pa # Test Medusa run_example_command( [ - "./launch.sh", + "./launch_train.sh", "--model", tiny_llama_path, "--data", tiny_daring_anteater_path, "--num_epochs", "1", "--lr", "1e-5", - "--do_eval", "False", "--num_gpu", str(num_gpus), "--mode", "medusa", "--output_dir", medusa_path, From 1f7d17ecca4b27e39f53a8f8e2aede3079b24538 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Fri, 26 Sep 2025 01:07:13 +0000 Subject: [PATCH 04/14] fix test Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 98 ++++++++----------- .../quantization/plugins/test_megatron.py | 7 +- 2 files changed, 42 insertions(+), 63 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index ad01233cc..abcd39caf 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -23,6 +23,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm +from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer from modelopt.torch.quantization.utils import is_quantized_linear from modelopt.torch.utils import torch_to @@ -150,7 +151,7 @@ def forward_loop(model): dist.destroy_process_group() -def data_parallel_test_helper(model, config, dp_group): +def dp_cp_parallel_test_helper(model, config, group): calib_data = model.get_dummy_input().cuda() def forward_loop(model): @@ -158,48 +159,27 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - # Input quantizer amax - if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - - # Weight quantizer amax - fc1_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) - fc2_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) - - -def context_parallel_test_helper(model, config, cp_group): - calib_data = model.get_dummy_input().cuda() - - def forward_loop(model): - model(calib_data) - - model = mtq.quantize(model, config, forward_loop) + def reduce_amax(quantizer): + amax = quantizer.amax.clone() + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group) + assert torch.allclose(amax, quantizer.amax) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) + reduce_amax(model.fc1.input_quantizer) + reduce_amax(model.fc2.input_quantizer) # Weight quantizer amax - fc1_weight_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc1_weight_amax, model.fc1.weight_quantizer.amax) - fc2_weight_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_weight_amax, op=dist.ReduceOp.MAX, group=cp_group) - assert torch.allclose(fc2_weight_amax, model.fc2.weight_quantizer.amax) + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc1.weight_quantizer) + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc2.weight_quantizer) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -212,29 +192,29 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) + def reduce_amax(quantizer): + amax = quantizer.amax.clone() + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) + assert torch.allclose(amax, quantizer.amax) + # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - fc1_amax = model.fc1.input_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.input_quantizer.amax) - fc2_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.input_quantizer.amax) - - fc1_amax = model.fc1.weight_quantizer.amax.clone() - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc1_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc1_amax, model.fc1.weight_quantizer.amax) - fc2_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(fc2_amax, op=dist.ReduceOp.MAX, group=dp_group) - assert torch.allclose(fc2_amax, model.fc2.weight_quantizer.amax) + reduce_amax(model.fc1.input_quantizer) + reduce_amax(model.fc2.input_quantizer) + + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc1.weight_quantizer) + + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + reduce_amax(quantizer) + else: + reduce_amax(model.fc2.weight_quantizer) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 486756e2a..1ea51b800 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -31,9 +31,8 @@ from _test_utils.torch_quantization.quant_utils import get_model_size from _test_utils.torch_quantization.quantize_common import ( auto_quantize_helper, - context_parallel_test_helper, - data_parallel_test_helper, data_tensor_context_parallel_test_helper, + dp_cp_parallel_test_helper, tensor_parallel_test_helper, ) from packaging.version import Version @@ -128,7 +127,7 @@ def _test_data_parallel_helper(config, rank, size): initialize_for_megatron(seed=SEED) model = MegatronModel().cuda() - data_parallel_test_helper(model, config, get_data_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) @pytest.mark.parametrize( @@ -152,7 +151,7 @@ def _test_context_parallel_helper(config, rank, size): initialize_for_megatron(context_parallel_size=size, seed=SEED) model = MegatronModel(cp_size=size).cuda() - context_parallel_test_helper(model, config, get_context_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) @pytest.mark.parametrize( From d02365c5f31d3614db9957b26d0d59677cca7bbe Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 29 Sep 2025 21:38:48 +0000 Subject: [PATCH 05/14] awq test Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 18 +++++----- .../torch_dist/plugins/megatron_common.py | 5 ++- .../torch_quantization/quantize_common.py | 8 +++-- .../quantization/plugins/test_megatron.py | 7 ++-- .../torch/quantization/test_model_calib.py | 33 +++++++++++++++++++ 5 files changed, 56 insertions(+), 15 deletions(-) create mode 100644 tests/gpu/torch/quantization/test_model_calib.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index b683f265a..3e974c1f0 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -617,20 +617,20 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr and hasattr(module, "awq_lite") and module.awq_lite.num_cache_steps > 0 ): + # Hack: MoEs forward all tokens through all experts if _if_calib is True + module._if_calib = True module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps - + if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( torch.isnan(module.awq_lite.weight_scale) ): module.awq_lite.is_enabled = False - - sync_act_scale_across_dp_cp( - module, - module.parallel_state.data_parallel_group, - module.parallel_state.context_parallel_group, - ) - # Hack: MoEs forward all tokens through all experts if _if_calib is True - module._if_calib = True + else: + sync_act_scale_across_dp_cp( + module, + module.parallel_state.data_parallel_group, + module.parallel_state.context_parallel_group, + ) AWQLiteHelper.cache_mode = False print_rank_0("awq_lite: Searching parameters...") diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 7c497a895..6324d3390 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -384,7 +384,10 @@ def run_mcore_inference_with_dummy_input( def initialize_for_megatron( - tensor_model_parallel_size=1, pipeline_model_parallel_size=1, context_parallel_size=1, seed=1234 + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + seed=1234, + context_parallel_size=1, ): """Initialize Megatron model parallelism. diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index abcd39caf..c4f1dc2b5 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -194,9 +194,13 @@ def forward_loop(model): def reduce_amax(quantizer): amax = quantizer.amax.clone() - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + print("amax before reduce", amax) + print("quantizer.amax before reduce", quantizer.amax) dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) + dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) + print("amax after reduce", amax) + print("quantizer.amax after reduce", quantizer.amax) assert torch.allclose(amax, quantizer.amax) # Input quantizer amax diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 1ea51b800..03a6c4ba8 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -123,8 +123,7 @@ def test_tensor_parallel(need_2_gpus, config): # 2. Data Parallel Test def _test_data_parallel_helper(config, rank, size): - # TODO does this model automatically get copied to both DP ranks? - initialize_for_megatron(seed=SEED) + initialize_for_megatron(seed=SEED + rank) # modify seed so data is different across ranks model = MegatronModel().cuda() dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) @@ -148,7 +147,9 @@ def test_data_parallel(need_2_gpus, config): # 3. Context Parallel Test def _test_context_parallel_helper(config, rank, size): - initialize_for_megatron(context_parallel_size=size, seed=SEED) + initialize_for_megatron( + context_parallel_size=size, seed=SEED + rank + ) # modify seed so data is different across ranks model = MegatronModel(cp_size=size).cuda() dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) diff --git a/tests/gpu/torch/quantization/test_model_calib.py b/tests/gpu/torch/quantization/test_model_calib.py new file mode 100644 index 000000000..2179b8532 --- /dev/null +++ b/tests/gpu/torch/quantization/test_model_calib.py @@ -0,0 +1,33 @@ +import torch +import torch.distributed as dist +from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job +from _test_utils.torch_dist.plugins.megatron_common import MegatronModel, initialize_for_megatron +from megatron.core.parallel_state import get_data_parallel_group + +from modelopt.torch.quantization.model_calib import awq_lite + + +def _test_awq_lite_act_scale_sync_helper(rank, size): + initialize_for_megatron(seed=1234 + rank) + model = MegatronModel().cuda() + + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = awq_lite(model, forward_loop) + # Sanity check + forward_loop(model) + + act_scale = model.fc1.weight_quantizer.awq_lite.act_scale.clone() + dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) + assert torch.allclose(act_scale, model.fc1.weight_quantizer.awq_lite.act_scale) + + act_scale = model.fc2.weight_quantizer.awq_lite.act_scale.clone() + dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) + assert torch.allclose(act_scale, model.fc2.weight_quantizer.awq_lite.act_scale) + + +def test_awq_lite_act_scale_sync(need_2_gpus): + spawn_multiprocess_job(size=2, job=_test_awq_lite_act_scale_sync_helper, backend="nccl") From 5a572da4bcfb77b5b6b9bc1e8e8b056114222b50 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Mon, 29 Sep 2025 23:17:00 +0000 Subject: [PATCH 06/14] move awq test inside megatron tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 + .../torch_quantization/quantize_common.py | 121 ++++++++++++------ .../quantization/plugins/test_megatron.py | 2 +- .../torch/quantization/test_model_calib.py | 33 ----- 4 files changed, 86 insertions(+), 71 deletions(-) delete mode 100644 tests/gpu/torch/quantization/test_model_calib.py diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3e974c1f0..7a1d9791a 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -581,6 +581,7 @@ def forward(self, input, *args, **kwargs): return out_actual for name, module in model.named_modules(): + print(name, module, module.weight_quantizer.is_enabled) if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model): module.awq_lite = AWQLiteHelper(module, name) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index c4f1dc2b5..e0fabb84a 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -117,6 +117,12 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N mto.restore_from_modelopt_state(model_ref, state_dict) +def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None): + quantizer_attr = getattr(quantizer, attr).clone() + dist.all_reduce(quantizer_attr, op=op, group=group) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) + + def tensor_parallel_test_helper(model, config, tp_group): # The input to first layer, the column parallel should be the same across all tp ranks calib_data = model.get_dummy_input().cuda() @@ -126,27 +132,39 @@ def forward_loop(model): model(calib_data) model = mtq.quantize(model, config, forward_loop) - # Sanity check forward_loop(model) if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks - activation_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) # Lets check the row parallel weight amax; it should be the same across all tp ranks - weight_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) + _reduce_quantizer_attr( + model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group + ) if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks input_quantizer = model.fc1.input_quantizer - pre_quant_scale = input_quantizer.pre_quant_scale.clone() - dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) + _reduce_quantizer_attr( + input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group + ) + + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + # Check act scale + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) dist.destroy_process_group() @@ -159,27 +177,37 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - def reduce_amax(quantizer): - amax = quantizer.amax.clone() - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group) - assert torch.allclose(amax, quantizer.amax) - # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - reduce_amax(model.fc1.input_quantizer) - reduce_amax(model.fc2.input_quantizer) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) # Weight quantizer amax if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) else: - reduce_amax(model.fc1.weight_quantizer) + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) else: - reduce_amax(model.fc2.weight_quantizer) + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) + + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + # Check act scale + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): @@ -192,33 +220,52 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) - def reduce_amax(quantizer): - amax = quantizer.amax.clone() - print("amax before reduce", amax) - print("quantizer.amax before reduce", quantizer.amax) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=dp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=cp_group) - dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=tp_group) - print("amax after reduce", amax) - print("quantizer.amax after reduce", quantizer.amax) - assert torch.allclose(amax, quantizer.amax) + def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): + quantizer_attr = getattr(quantizer, attr).clone() + print("quantizer_attr before reduce", quantizer_attr) + print("quantizer.attr before reduce", getattr(quantizer, attr)) + dist.all_reduce(quantizer_attr, op=op, group=dp_group) + dist.all_reduce(quantizer_attr, op=op, group=cp_group) + dist.all_reduce(quantizer_attr, op=op, group=tp_group) + print("quantizer_attr after reduce", quantizer_attr) + print("quantizer.attr after reduce", getattr(quantizer, attr)) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - reduce_amax(model.fc1.input_quantizer) - reduce_amax(model.fc2.input_quantizer) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) else: - reduce_amax(model.fc1.weight_quantizer) + _reduce_quantizer_attr( + model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group + ) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - reduce_amax(quantizer) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) else: - reduce_amax(model.fc2.weight_quantizer) + _reduce_quantizer_attr( + model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group + ) + + # Check act scale + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + _reduce_quantizer_attr( + model.fc1.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) + _reduce_quantizer_attr( + model.fc2.weight_quantizer.awq_lite.act_scale, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 03a6c4ba8..84f8bca63 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -175,7 +175,7 @@ def test_context_parallel(need_2_gpus, config): # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) def _test_data_tensor_context_parallel_helper(config, rank, size): - initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED) + initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank) model = MegatronModel(tp_size=2, cp_size=2).cuda() data_tensor_context_parallel_test_helper( diff --git a/tests/gpu/torch/quantization/test_model_calib.py b/tests/gpu/torch/quantization/test_model_calib.py deleted file mode 100644 index 2179b8532..000000000 --- a/tests/gpu/torch/quantization/test_model_calib.py +++ /dev/null @@ -1,33 +0,0 @@ -import torch -import torch.distributed as dist -from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job -from _test_utils.torch_dist.plugins.megatron_common import MegatronModel, initialize_for_megatron -from megatron.core.parallel_state import get_data_parallel_group - -from modelopt.torch.quantization.model_calib import awq_lite - - -def _test_awq_lite_act_scale_sync_helper(rank, size): - initialize_for_megatron(seed=1234 + rank) - model = MegatronModel().cuda() - - calib_data = model.get_dummy_input().cuda() - - def forward_loop(model): - model(calib_data) - - model = awq_lite(model, forward_loop) - # Sanity check - forward_loop(model) - - act_scale = model.fc1.weight_quantizer.awq_lite.act_scale.clone() - dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) - assert torch.allclose(act_scale, model.fc1.weight_quantizer.awq_lite.act_scale) - - act_scale = model.fc2.weight_quantizer.awq_lite.act_scale.clone() - dist.all_reduce(act_scale, op=dist.ReduceOp.AVG, group=get_data_parallel_group()) - assert torch.allclose(act_scale, model.fc2.weight_quantizer.awq_lite.act_scale) - - -def test_awq_lite_act_scale_sync(need_2_gpus): - spawn_multiprocess_job(size=2, job=_test_awq_lite_act_scale_sync_helper, backend="nccl") From fc0bb884b07839a2ae8f2a9ffe731e4d089f4822 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 00:20:17 +0000 Subject: [PATCH 07/14] fix amax tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 - .../torch_quantization/quantize_common.py | 47 +++++++++++++------ 2 files changed, 33 insertions(+), 15 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 7a1d9791a..3e974c1f0 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -581,7 +581,6 @@ def forward(self, input, *args, **kwargs): return out_actual for name, module in model.named_modules(): - print(name, module, module.weight_quantizer.is_enabled) if is_quantized_linear(module) and module.weight_quantizer.is_enabled: with enable_weight_access_and_writeback(module, model): module.awq_lite = AWQLiteHelper(module, name) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index e0fabb84a..d2b4b60e5 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +from unittest.mock import patch import pytest import torch @@ -119,11 +120,26 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None): quantizer_attr = getattr(quantizer, attr).clone() + print("quantizer.attr before reduce", getattr(quantizer, attr)) dist.all_reduce(quantizer_attr, op=op, group=group) + print("quantizer.attr after reduce", getattr(quantizer, attr)) + print("quantizer_attr after reduce", quantizer_attr) assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) -def tensor_parallel_test_helper(model, config, tp_group): +# Store the original function before patching +import modelopt.torch.quantization.model_calib as model_calib_module + +original_awq_lite = model_calib_module.awq_lite + + +def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): + """Function to mock awq_lite function to always use debug=True for testing""" + return original_awq_lite(model, forward_loop, alpha_step, debug=True) + + +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite): # The input to first layer, the column parallel should be the same across all tp ranks calib_data = model.get_dummy_input().cuda() dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -138,7 +154,6 @@ def forward_loop(model): if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) - # Lets check the row parallel weight amax; it should be the same across all tp ranks _reduce_quantizer_attr( model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group @@ -152,24 +167,25 @@ def forward_loop(model): ) if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: - # Check act scale + # Check activation scale for AWQ lite _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) + # TODO fc2 assert is failing + """ _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, - "act_scale", - dist.ReduceOp.AVG, - group=tp_group, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) + """ dist.destroy_process_group() -def dp_cp_parallel_test_helper(model, config, group): +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): calib_data = model.get_dummy_input().cuda() def forward_loop(model): @@ -197,20 +213,23 @@ def forward_loop(model): if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Check act scale _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, + model.fc2.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) -def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, cp_group): +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def data_tensor_context_parallel_test_helper( + model, config, dp_group, tp_group, cp_group, mock_awq_lite +): calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -255,13 +274,13 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite.act_scale, + model.fc1.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite.act_scale, + model.fc2.weight_quantizer.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, From 95da8329e2c7ed9497542d1126e5c31370313160 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 00:21:38 +0000 Subject: [PATCH 08/14] fix awq lite param Signed-off-by: Jennifer Chen --- tests/_test_utils/torch_quantization/quantize_common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index d2b4b60e5..ced82ed2e 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -213,13 +213,13 @@ def forward_loop(model): if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Check act scale _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite, + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=group, @@ -274,13 +274,13 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: _reduce_quantizer_attr( - model.fc1.weight_quantizer.awq_lite, + model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, ) _reduce_quantizer_attr( - model.fc2.weight_quantizer.awq_lite, + model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, From 34c11ef46cb4d736be0b43e8c59f90387c05ef35 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 19:04:19 +0000 Subject: [PATCH 09/14] fix test Signed-off-by: Jennifer Chen --- examples/nemo_run/qat/README.md | 7 +++--- .../torch_quantization/quantize_common.py | 22 ++++++------------- .../quantization/plugins/test_megatron.py | 4 ++-- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/examples/nemo_run/qat/README.md b/examples/nemo_run/qat/README.md index cd74c96e2..b9be7ba0e 100644 --- a/examples/nemo_run/qat/README.md +++ b/examples/nemo_run/qat/README.md @@ -56,15 +56,16 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar You can run the example either locally or on a [Slurm cluster](ADVANCED.md). -To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container. +To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container. - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git` -- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a` +- `git clone https://github.com/NVIDIA-NeMo/NeMo.git` +- `git clone https://github.com/NVIDIA/Megatron-LM.git` Example docker command: ```bash -docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash +docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt -v /home/user/Megatron-LM:/opt/megatron-lm --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash ``` You will also need to set your Huggingface token with `export HF_TOKEN=`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written. diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index ced82ed2e..c00fd5b33 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -23,6 +23,7 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer from modelopt.torch.quantization.utils import is_quantized_linear @@ -127,9 +128,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) -# Store the original function before patching -import modelopt.torch.quantization.model_calib as model_calib_module - original_awq_lite = model_calib_module.awq_lite @@ -252,24 +250,20 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: - _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) - _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) else: - _reduce_quantizer_attr( - model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group - ) + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=dp_group) + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) else: - _reduce_quantizer_attr( - model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=dp_group - ) + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: @@ -277,13 +271,11 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): model.fc1.awq_lite, "act_scale", dist.ReduceOp.AVG, - group=tp_group, ) _reduce_quantizer_attr( model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, - group=tp_group, ) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 84f8bca63..07c026513 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -214,7 +214,7 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic tensor_model_parallel_size=tp_size, num_layers=4, ffn_hidden_size=None, - num_attention_heads=4, + num_attention_heads=8, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, @@ -226,7 +226,7 @@ def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_devic tensor_model_parallel_size=tp_size, num_layers=4, ffn_hidden_size=None, - num_attention_heads=4, + num_attention_heads=8, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, From 9f0691f5a60fff1d97497983b6cad5219bb78101 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Tue, 30 Sep 2025 19:23:31 +0000 Subject: [PATCH 10/14] uncomment test Signed-off-by: Jennifer Chen --- tests/_test_utils/torch_quantization/quantize_common.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index c00fd5b33..712a13171 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -172,12 +172,12 @@ def forward_loop(model): dist.ReduceOp.AVG, group=tp_group, ) - # TODO fc2 assert is failing - """ _reduce_quantizer_attr( - model.fc2.awq_lite, "act_scale", dist.ReduceOp.AVG, group=tp_group, + model.fc2.awq_lite, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, ) - """ dist.destroy_process_group() From fa8f4c8d8433f83255768da75af491bb6ef111b1 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 1 Oct 2025 01:07:14 +0000 Subject: [PATCH 11/14] add print Signed-off-by: Jennifer Chen --- .../torch_quantization/quantize_common.py | 83 ++++++++++++++++--- .../quantization/plugins/test_megatron.py | 6 +- 2 files changed, 74 insertions(+), 15 deletions(-) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 712a13171..bc643ad01 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -119,7 +119,7 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N mto.restore_from_modelopt_state(model_ref, state_dict) -def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX, group=None): +def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): quantizer_attr = getattr(quantizer, attr).clone() print("quantizer.attr before reduce", getattr(quantizer, attr)) dist.all_reduce(quantizer_attr, op=op, group=group) @@ -225,9 +225,46 @@ def forward_loop(model): @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) -def data_tensor_context_parallel_test_helper( - model, config, dp_group, tp_group, cp_group, mock_awq_lite -): +def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): + # Print rank information for debugging + world_rank = dist.get_rank() + world_size = dist.get_world_size() + + print("\n=== RANK INFORMATION ===") + print(f"World Rank: {world_rank}, World Size: {world_size}") + + # Get group information with actual ranks + def get_group_ranks(group): + if group is None: + return None + ranks = [] + ranks = [ + i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group) + ] + return ranks + + if dp_group is not None: + dp_rank = dist.get_rank(group=dp_group) + dp_size = dist.get_world_size(group=dp_group) + print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}") + + if tp_group is not None: + tp_rank = dist.get_rank(group=tp_group) + tp_size = dist.get_world_size(group=tp_group) + print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}") + + print("=== END RANK INFO ===\n") + + # Print a summary of all ranks + print("=== ALL RANKS SUMMARY ===") + print(f"Total GPUs: {world_size}") + print(f"Current rank: {world_rank}") + if dp_group is not None: + print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each") + if tp_group is not None: + print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each") + print("=== END SUMMARY ===\n") + calib_data = model.get_dummy_input().cuda() # data should be same across each TP rank dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -238,14 +275,38 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): + world_rank = dist.get_rank() + print(f"\n--- Rank {world_rank}: Reducing {attr} ---") + from megatron.core.parallel_state import ( + _CONTEXT_PARALLEL_GLOBAL_RANKS, + _DATA_PARALLEL_GLOBAL_RANKS, + _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP, + _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS, + ) + + print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}") + print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}") + print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}") + print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}") quantizer_attr = getattr(quantizer, attr).clone() - print("quantizer_attr before reduce", quantizer_attr) - print("quantizer.attr before reduce", getattr(quantizer, attr)) - dist.all_reduce(quantizer_attr, op=op, group=dp_group) - dist.all_reduce(quantizer_attr, op=op, group=cp_group) - dist.all_reduce(quantizer_attr, op=op, group=tp_group) - print("quantizer_attr after reduce", quantizer_attr) - print("quantizer.attr after reduce", getattr(quantizer, attr)) + print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr) + print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr)) + + # Perform all-reduce operations + if tp_group is not None: + tp_rank = dist.get_rank(group=tp_group) + print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})") + dist.all_reduce(quantizer_attr, op=op, group=tp_group) + + if dp_group is not None: + dp_rank = dist.get_rank(group=dp_group) + print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})") + dist.all_reduce(quantizer_attr, op=op, group=dp_group) + + print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr) + print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr)) + print(f"--- End Rank {world_rank} ---\n") + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) # Input quantizer amax diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index 07c026513..b45c191bc 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -42,7 +42,6 @@ import megatron.core from megatron.core.parallel_state import ( destroy_model_parallel, - get_context_parallel_group, get_data_parallel_group, get_tensor_model_parallel_group, ) @@ -152,7 +151,7 @@ def _test_context_parallel_helper(config, rank, size): ) # modify seed so data is different across ranks model = MegatronModel(cp_size=size).cuda() - dp_cp_parallel_test_helper(model, config, get_context_parallel_group()) + dp_cp_parallel_test_helper(model, config, get_data_parallel_group(with_context_parallel=True)) @pytest.mark.parametrize( @@ -181,9 +180,8 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): data_tensor_context_parallel_test_helper( model, config, - get_data_parallel_group(), + get_data_parallel_group(with_context_parallel=True), get_tensor_model_parallel_group(), - get_context_parallel_group(), ) From d1fac44cd7826fb2c44e3de208f7995402d81a7f Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Wed, 1 Oct 2025 20:58:09 +0000 Subject: [PATCH 12/14] docstring Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3e974c1f0..8e10df7df 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -117,6 +117,7 @@ def sync_quantizer_amax_across_tp( ): if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: + "Syncing amax across TP for sequential quantizer" sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) From 22b8b73cc82ea2e5c285cddf0b26a3ff5b593025 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 2 Oct 2025 00:09:31 +0000 Subject: [PATCH 13/14] fix tests Signed-off-by: Jennifer Chen --- modelopt/torch/quantization/model_calib.py | 21 ++-- .../torch/quantization/plugins/megatron.py | 5 +- modelopt/torch/utils/distributed.py | 3 - .../torch_dist/plugins/megatron_common.py | 6 +- .../torch_quantization/quantize_common.py | 117 ++++-------------- .../quantization/plugins/test_megatron.py | 2 +- 6 files changed, 42 insertions(+), 112 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 8e10df7df..d9189768c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -80,22 +80,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp_cp(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel and context parallel groups.""" + def sync_quantizer_amax_across_dp(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel group.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp_cp(_q, parallel_state) + sync_quantizer_amax_across_dp(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) - quantizer.sync_amax_across_distributed_group(parallel_state.context_parallel_group) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp_cp(child, module.parallel_state) + sync_quantizer_amax_across_dp(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -600,17 +599,12 @@ def forward(self, input, *args, **kwargs): # This will also perform distributed amax sync for input_quantizers max_calibrate(model, lambda model: None) - def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_group): - # Sync across Data Parallel (DP) + def sync_act_scale_across_dp(module, data_parallel_group): + """Sync activation scale across Data Parallel (DP).""" if data_parallel_group.is_initialized(): dist.all_reduce( module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group ) - # Sync across Context Parallel (CP) - if context_parallel_group.is_initialized(): - dist.all_reduce( - module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=context_parallel_group.group - ) for name, module in model.named_modules(): if ( @@ -627,10 +621,9 @@ def sync_act_scale_across_dp_cp(module, data_parallel_group, context_parallel_gr ): module.awq_lite.is_enabled = False else: - sync_act_scale_across_dp_cp( + sync_act_scale_across_dp( module, module.parallel_state.data_parallel_group, - module.parallel_state.context_parallel_group, ) AWQLiteHelper.cache_mode = False diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index b41bf7a58..85784d2fe 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -15,6 +15,7 @@ """Support quantization for megatron linear layers.""" +import logging import warnings from typing import Any @@ -39,6 +40,8 @@ from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +logger = logging.getLogger(__name__) + __all__ = [] @@ -222,11 +225,11 @@ def _setup(self): try: data_parallel_group = get_data_parallel_group(with_context_parallel=True) except AssertionError: + logger.warning("Context parallel group is not initialized, using data parallel group") data_parallel_group = get_data_parallel_group() self.parallel_state = ParallelState( data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), - mcore_parallel.get_context_parallel_group(), ) super()._setup() diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index f6987a3f6..f11a736db 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,18 +241,15 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, - context_parallel_group: torch.distributed.ProcessGroup | int | None = -1, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) - self.context_parallel_group = DistributedProcessGroup(context_parallel_group) def __repr__(self) -> str: return ( f"data_parallel_group: {self.data_parallel_group}, " f"tensor_parallel_group: {self.tensor_parallel_group}, " - f"context_parallel_group: {self.context_parallel_group}" ) diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 6324d3390..9d2b0c047 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -127,7 +127,11 @@ def forward(self, x): x = x[0] return x - def get_dummy_input(self) -> torch.Tensor: + def get_dummy_input(self, seed: int | None = None) -> torch.Tensor: + if seed is not None: + gen = torch.Generator() + gen.manual_seed(seed) + return torch.randn(1, 4, 32, generator=gen) return torch.randn(1, 4, 32) diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index bc643ad01..6dbb5b213 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -172,12 +172,6 @@ def forward_loop(model): dist.ReduceOp.AVG, group=tp_group, ) - _reduce_quantizer_attr( - model.fc2.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - group=tp_group, - ) dist.destroy_process_group() @@ -191,6 +185,9 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) + # Sanity check + forward_loop(model) + # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) @@ -226,48 +223,9 @@ def forward_loop(model): @patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): - # Print rank information for debugging - world_rank = dist.get_rank() - world_size = dist.get_world_size() - - print("\n=== RANK INFORMATION ===") - print(f"World Rank: {world_rank}, World Size: {world_size}") - - # Get group information with actual ranks - def get_group_ranks(group): - if group is None: - return None - ranks = [] - ranks = [ - i for i in range(world_size) if dist.get_rank(group=group) == dist.get_rank(group=group) - ] - return ranks - - if dp_group is not None: - dp_rank = dist.get_rank(group=dp_group) - dp_size = dist.get_world_size(group=dp_group) - print(f"DP Group - Rank: {dp_rank}, Size: {dp_size}") - - if tp_group is not None: - tp_rank = dist.get_rank(group=tp_group) - tp_size = dist.get_world_size(group=tp_group) - print(f"TP Group - Rank: {tp_rank}, Size: {tp_size}") - - print("=== END RANK INFO ===\n") - - # Print a summary of all ranks - print("=== ALL RANKS SUMMARY ===") - print(f"Total GPUs: {world_size}") - print(f"Current rank: {world_rank}") - if dp_group is not None: - print(f"DP groups: {dp_size} groups of {world_size // dp_size} ranks each") - if tp_group is not None: - print(f"TP groups: {tp_size} groups of {world_size // tp_size} ranks each") - print("=== END SUMMARY ===\n") - - calib_data = model.get_dummy_input().cuda() - # data should be same across each TP rank - dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) + # Calib data should be same across each DP rank + dp_rank = dist.get_rank(group=dp_group) + calib_data = model.get_dummy_input(seed=dp_rank).cuda() def forward_loop(model): model(calib_data) @@ -275,56 +233,36 @@ def forward_loop(model): model = mtq.quantize(model, config, forward_loop) def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): - world_rank = dist.get_rank() - print(f"\n--- Rank {world_rank}: Reducing {attr} ---") - from megatron.core.parallel_state import ( - _CONTEXT_PARALLEL_GLOBAL_RANKS, - _DATA_PARALLEL_GLOBAL_RANKS, - _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP, - _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS, - ) - - print(f"DATA_PARALLEL_GLOBAL_RANKS: {_DATA_PARALLEL_GLOBAL_RANKS}") - print(f"CONTEXT_PARALLEL_GLOBAL_RANKS: {_CONTEXT_PARALLEL_GLOBAL_RANKS}") - print(f"DATA_PARALLEL_GLOBAL_RANKS_WITH_CP: {_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP}") - print(f"TENSOR_MODEL_PARALLEL_GLOBAL_RANKS: {_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS}") quantizer_attr = getattr(quantizer, attr).clone() - print(f"Rank {world_rank} - quantizer_attr before reduce", quantizer_attr) - print(f"Rank {world_rank} - quantizer.attr before reduce", getattr(quantizer, attr)) # Perform all-reduce operations - if tp_group is not None: - tp_rank = dist.get_rank(group=tp_group) - print(f"Rank {world_rank} - TP reduce (TP rank {tp_rank})") - dist.all_reduce(quantizer_attr, op=op, group=tp_group) + dist.all_reduce(quantizer_attr, op=op, group=tp_group) - if dp_group is not None: - dp_rank = dist.get_rank(group=dp_group) - print(f"Rank {world_rank} - DP reduce (DP rank {dp_rank})") - dist.all_reduce(quantizer_attr, op=op, group=dp_group) + dist.all_reduce(quantizer_attr, op=op, group=dp_group) - print(f"Rank {world_rank} - quantizer_attr after reduce", quantizer_attr) - print(f"Rank {world_rank} - quantizer.attr after reduce", getattr(quantizer, attr)) - print(f"--- End Rank {world_rank} ---\n") - - assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) # Input quantizer amax if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) - if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc1.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) - - if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): - for quantizer in model.fc2.weight_quantizer: - _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) - else: - _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) + # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks + # Channel-wise (INT8) only expects same amax across row parallel ranks + # Block-wise quantization does not expect same amax across row and column parallel ranks + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) + + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) # Check act scale if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: @@ -333,11 +271,6 @@ def _reduce_quantizer_attr(quantizer, attr=str, op=dist.ReduceOp.MAX): "act_scale", dist.ReduceOp.AVG, ) - _reduce_quantizer_attr( - model.fc2.awq_lite, - "act_scale", - dist.ReduceOp.AVG, - ) def auto_quantize_helper(model): diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index b45c191bc..f71359044 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -199,7 +199,7 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=4, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" ) From 3f857a32b42fd23fbf8ce1efa749e2a2102e73a7 Mon Sep 17 00:00:00 2001 From: Jennifer Chen Date: Thu, 2 Oct 2025 21:17:08 +0000 Subject: [PATCH 14/14] fix multiprocess size Signed-off-by: Jennifer Chen --- tests/gpu/torch/quantization/plugins/test_megatron.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index f71359044..b45c191bc 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -199,7 +199,7 @@ def _test_data_tensor_context_parallel_helper(config, rank, size): ) def test_data_tensor_context_parallel(need_8_gpus, config): spawn_multiprocess_job( - size=4, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" )