diff --git a/examples/nemo_run/qat/README.md b/examples/nemo_run/qat/README.md index 79715953c..b9be7ba0e 100644 --- a/examples/nemo_run/qat/README.md +++ b/examples/nemo_run/qat/README.md @@ -56,15 +56,16 @@ The resulting exported checkpoint also is much smaller in memory at 6.4GB compar You can run the example either locally or on a [Slurm cluster](ADVANCED.md). -To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.07 or higher. Clone the `TensorRT-Model-Optimizer` repository and `NeMo` repository (checkout a specific commit for NeMo), then mount it onto your docker container. +To run the example locally, launch a [NeMo container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) with version 25.09 or higher. Clone the `TensorRT-Model-Optimizer`, `NeMo`, and `Megatron-LM` repositories, then mount them into your docker container. - `git clone https://github.com/NVIDIA/TensorRT-Model-Optimizer.git` -- `git clone https://github.com/NVIDIA-NeMo/NeMo.git && cd NeMo && git checkout 676ed1a` +- `git clone https://github.com/NVIDIA-NeMo/NeMo.git` +- `git clone https://github.com/NVIDIA/Megatron-LM.git` Example docker command: ```bash -docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.07 bash +docker run -v /home/user/:/home/user/ -v /home/user/NeMo:/opt/NeMo -v /home/user/TensorRT-Model-Optimizer/modelopt/:/usr/local/lib/python3.12/dist-packages/modelopt -v /home/user/Megatron-LM:/opt/megatron-lm --gpus all -it --shm-size 20g --rm nvcr.io/nvidia/nemo:25.09 bash ``` You will also need to set your Huggingface token with `export HF_TOKEN=`. You may also need to enable write access to the docker container to the `examples/nemo_run` folder by doing `chmod 777 nemo_run` so that logs can be written.
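For reference, a minimal sketch of the setup steps described above (the token value and repository paths are placeholders and assume the mounts from the docker command shown earlier):

```bash
# Set the Hugging Face token (needed to download gated models) inside the container
export HF_TOKEN=<your_huggingface_token>
# Allow logs to be written into the example folder from inside the container
cd /home/user/TensorRT-Model-Optimizer/examples && chmod 777 nemo_run
```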
@@ -92,7 +93,7 @@ In order to train using QAD, launch the example with `python qat/nemo_qat_flow.p To perform QAD training, run: ```bash -python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment +python qat/nemo_qat_flow.py --distill --log-dir /my/log/dir --experiment qad_experiment --tensor_parallelism 4 ``` ## Supported models diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 3ef014d65..d3cecc00c 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -80,21 +80,26 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp(quantizer, parallel_state): + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel (with context parallel) and expert parallel groups.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp(_q, parallel_state) + sync_quantizer_amax_across_dp_ep(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) + if parallel_state.expert_tensor_parallel_group is not None: + quantizer.sync_amax_across_distributed_group( + parallel_state.expert_tensor_parallel_group + ) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp(child, module.parallel_state) - + sync_quantizer_amax_across_dp_ep(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -116,6 +121,7 @@ def sync_quantizer_amax_across_tp( ): if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: + # Sync amax across TP for each quantizer in the SequentialQuantizer sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) @@ -598,19 +604,32 @@ def forward(self, input, *args, **kwargs): # This will also perform distributed amax sync for input_quantizers max_calibrate(model, lambda model: None) + def sync_act_scale_across_dp(module, data_parallel_group): + """Sync activation scale across Data Parallel (DP).""" + if data_parallel_group.is_initialized(): + dist.all_reduce( + module.awq_lite.act_scale, op=dist.ReduceOp.AVG, group=data_parallel_group.group + ) + for name, module in model.named_modules(): if ( is_quantized_linear(module) and hasattr(module, "awq_lite") and module.awq_lite.num_cache_steps > 0 ): + # Hack: MoEs forward all tokens through all experts if _if_calib is True + module._if_calib = True module.awq_lite.act_scale = module.awq_lite.act_scale / module.awq_lite.num_cache_steps + if torch.any(torch.isnan(module.awq_lite.act_scale)) or torch.any( torch.isnan(module.awq_lite.weight_scale) ): module.awq_lite.is_enabled = False - # Hack: MoEs forward all tokens through all experts if _if_calib is True - module._if_calib = True + else: + sync_act_scale_across_dp( + module, + module.parallel_state.data_parallel_group, + ) AWQLiteHelper.cache_mode = False print_rank_0("awq_lite: Searching parameters...") diff --git a/modelopt/torch/quantization/plugins/megatron.py
b/modelopt/torch/quantization/plugins/megatron.py index 1cf9416ec..c414f99c8 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -15,6 +15,7 @@ """Support quantization for megatron linear layers.""" +import logging import warnings from typing import Any @@ -22,6 +23,9 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch +import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear +from megatron.core.extensions import transformer_engine as megatron_te +from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint @@ -34,10 +38,12 @@ from modelopt.torch.utils.distributed import ParallelState from ..nn import QuantModuleRegistry, TensorQuantizer -from ..nn.modules.quant_linear import RealQuantLinear +from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +logger = logging.getLogger(__name__) + __all__ = [] @@ -217,9 +223,23 @@ class _MegatronParallelLinear(_ParallelLinear): ] def _setup(self): + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + logger.warning("Context parallel group is not initialized, using data parallel group") + data_parallel_group = get_data_parallel_group() + + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + except AssertionError: + expert_tensor_parallel_group = None + self.parallel_state = ParallelState( - getattr(mcore_parallel, "get_expert_data_parallel_group", "get_data_parallel_group")(), + data_parallel_group, mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_expert_model_parallel_group(), + expert_tensor_parallel_group, ) super()._setup() @@ -462,3 +482,103 @@ class _RealQuantMegatronRowParallelLinear( def forward(self, input, *args, **kwargs): return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs) + + +# Register the public te.pytorch.GroupedLinear class +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) +class _QuantTEGroupedLinear(_MegatronParallelLinear): + def _setup(self): + data_parallel_group = None + try: + data_parallel_group = get_data_parallel_group(with_context_parallel=True) + except AssertionError: + data_parallel_group = get_data_parallel_group() + + try: + expert_tensor_parallel_group = mcore_parallel.get_expert_tensor_parallel_group() + except AssertionError: + expert_tensor_parallel_group = None + self.parallel_state = ParallelState( + data_parallel_group, + mcore_parallel.get_tensor_model_parallel_group(), + mcore_parallel.get_expert_model_parallel_group(), + expert_tensor_parallel_group, + ) + self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) + self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) + self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) + self.output_quantizer.disable() + + # Memorize the original weight.dtype for modelopt_post_restore given that + # the dtype can change later. 
+ self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype + + @property + def functionals_to_replace(self): + original_forward = te_grouped_linear._GroupedLinear.forward + + def te_grouped_quantized_linear_fn(ctx, inp, m_splits, *args): + num_gemms = len(m_splits) + weights_and_biases = args[-2 * num_gemms :] + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] + quantized_inputs = self.input_quantizer(inp) + quantized_weights = [self.weight_quantizer(weight) for weight in weights] + + output = original_forward( + ctx, + quantized_inputs, + m_splits, + *args[: -2 * num_gemms], + *quantized_weights, + *biases, + ) + return self.output_quantizer(output) + + return [ + ( + te_grouped_linear._GroupedLinear, + "forward", + te_grouped_quantized_linear_fn, + ), + ] + + def modelopt_post_restore(self, prefix: str = ""): + # GroupedMLP stores the weights as weight0, weight1, etc. post_restore uses self.weight to + # extract the shape, dtype, etc. needed to initialize the quantizer states, so we temporarily + # assign self.weight0 to self.weight here. + self.weight = self.weight0 + super().modelopt_post_restore(prefix=prefix) + # Revert self.weight to None after post_restore since GroupedLinear only uses weight0, weight1, etc. in its forward pass. + self.weight = None + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx in [1, num_gemms) to + # sharded_state_dict, with the same content as _extra_state. These _extra_state{gemm_idx} keys + # are only needed for TE FP8 checkpoints, so we remove them here when restoring a + # modelopt checkpoint. + filtered_state_dict = { + k: v + for k, v in state_dict.items() + if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) + } + return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) + + def _process_quantizer_amax(self, k, v, quantizer_state_dict): + if v.ndim == 4: + quantizer_state_dict[k] = v.squeeze(1).squeeze(-1) + else: + quantizer_state_dict[k] = v.view(-1, 1) if v.numel() > 1 else v.view(-1) + + +@QuantModuleRegistry.register( + {megatron_te.TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} +) +class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumnParallelLinear): + _is_column_parallel = True + + +@QuantModuleRegistry.register( + {megatron_te.TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} +) +class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear): + _is_row_parallel = True diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index 76965dc0e..3d80d1464 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,13 +241,28 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + expert_tensor_parallel_group: torch.distributed.ProcessGroup | int | None = None, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) + self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group) +
self.expert_tensor_parallel_group = None + if expert_tensor_parallel_group is not None: + self.expert_tensor_parallel_group = DistributedProcessGroup( + expert_tensor_parallel_group + ) def __repr__(self) -> str: - return f"data_parallel_group: {self.data_parallel_group}, tensor_parallel_group: {self.tensor_parallel_group}" + parallel_groups = ( + f"data_parallel_group: {self.data_parallel_group}, " + f"tensor_parallel_group: {self.tensor_parallel_group}, " + f"expert_model_parallel_group: {self.expert_model_parallel_group}" + ) + if self.expert_tensor_parallel_group: + parallel_groups += f", expert_tensor_parallel_group: {self.expert_tensor_parallel_group}" + return parallel_groups def get_group(ranks: list[int]): diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index 9c1dd1bf7..5facf89a7 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -37,6 +37,8 @@ ) from megatron.core.models.mamba import MambaModel from megatron.core.parallel_state import ( + get_expert_model_parallel_group, + get_expert_tensor_parallel_group, initialize_model_parallel, is_pipeline_first_stage, is_pipeline_last_stage, @@ -48,12 +50,14 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig +import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_megatron import import_mcore_gpt_from_hf from modelopt.torch.opt.plugins.mcore_dist_checkpointing import ( restore_sharded_modelopt_state, save_sharded_modelopt_state, ) from modelopt.torch.utils import to_empty_if_meta_device +from modelopt.torch.utils.distributed import DistributedProcessGroup try: from megatron.core.extensions.transformer_engine import TENorm @@ -83,9 +87,10 @@ class MegatronModel(MegatronModule): - def __init__(self, tp_size: int = 1, use_te_norm: bool = False): + def __init__(self, tp_size: int = 1, cp_size: int = 1, use_te_norm: bool = False): config = TransformerConfig( tensor_model_parallel_size=tp_size, + context_parallel_size=cp_size, pipeline_model_parallel_size=1, normalization="LayerNorm", # Unused parameters below are set to avoid ZeroDivisionError in __post_init__ @@ -126,13 +131,19 @@ def forward(self, x): x = x[0] return x - def get_dummy_input(self) -> torch.Tensor: + def get_dummy_input(self, seed: int | None = None) -> torch.Tensor: + if seed is not None: + gen = torch.Generator() + gen.manual_seed(seed) + return torch.randn(1, 4, 32, generator=gen) return torch.randn(1, 4, 32) def get_mcore_gpt_model( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + expert_tensor_parallel_size: int = 1, initialize_megatron: bool = False, *, num_layers: int = 2, @@ -148,7 +159,10 @@ def get_mcore_gpt_model( normalization: str = "LayerNorm", transformer_impl: str = "modelopt" if HAS_TE else "local", use_cpu_initialization: bool = False, + num_moe_experts: int | None = None, + moe_grouped_gemm: bool = False, bf16: bool = True, + use_te: bool = False, ) -> GPTModel: assert activation_func in ["swiglu", "squared_relu"] assert normalization in ["LayerNorm", "RMSNorm"] @@ -156,7 +170,12 @@ print(f"Using `{transformer_impl=}` model spec for building GPT Model.") if initialize_megatron: - initialize_for_megatron(tensor_model_parallel_size, pipeline_model_parallel_size) + initialize_for_megatron( +
tensor_model_parallel_size, + pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + ) def squared_relu(x): return torch.pow(F.relu(x), 2) @@ -164,7 +183,10 @@ def squared_relu(x): config = TransformerConfig( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, sequence_parallel=False, + moe_grouped_gemm=moe_grouped_gemm, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, @@ -172,6 +194,7 @@ def squared_relu(x): num_attention_heads=num_attention_heads, num_query_groups=num_query_groups, ffn_hidden_size=ffn_hidden_size, + num_moe_experts=num_moe_experts, activation_func=squared_relu if activation_func == "squared_relu" else F.silu, normalization=normalization, gated_linear_unit=(activation_func == "swiglu"), @@ -183,7 +206,12 @@ def squared_relu(x): if transformer_impl == "local": assert HAS_APEX, "Apex not installed" - transformer_layer_spec = get_gpt_layer_local_spec(normalization=normalization) + transformer_layer_spec = get_gpt_layer_local_spec( + num_experts=num_moe_experts, + normalization=normalization, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + ) else: assert HAS_TE, "Transformer Engine not installed" transformer_layer_spec = ( @@ -202,6 +230,7 @@ def squared_relu(x): share_embeddings_and_output_weights=False, position_embedding_type="rope", ) + if bf16: model = model.to(torch.bfloat16) @@ -383,13 +412,24 @@ def run_mcore_inference_with_dummy_input( def initialize_for_megatron( - tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234 + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + seed=1234, + context_parallel_size=1, + expert_model_parallel_size=1, + expert_tensor_parallel_size=None, ): """Initialize Megatron model parallelism. NOTE: If used in a non-spawned process, make sure to call `megatron.core.parallel_state.destroy_model_parallel()`. 
""" - initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size) + initialize_model_parallel( + tensor_model_parallel_size, + pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + ) model_parallel_cuda_manual_seed(seed) @@ -455,3 +495,196 @@ def convert_maybe_fp8(v): assert torch.allclose(logits_ref, logits_test), ( f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}" ) + + +def compare_model_outputs(grouped_model, non_grouped_model, forward_fn, tolerance=1e-6): + """Compare outputs of grouped and non-grouped models.""" + # Set both models to eval mode + grouped_model.eval() + non_grouped_model.eval() + + with torch.no_grad(): + # Get outputs from both models + grouped_output = forward_fn(grouped_model) + non_grouped_output = forward_fn(non_grouped_model) + + # Compare outputs + if isinstance(grouped_output, tuple): + grouped_output = grouped_output[0] + if isinstance(non_grouped_output, tuple): + non_grouped_output = non_grouped_output[0] + + output_close = torch.allclose( + grouped_output, non_grouped_output, atol=tolerance, rtol=tolerance + ) + return output_close + + +def sync_amax(model): + amax_dict = { + "linear_fc1.input_quantizer": {}, + "linear_fc1.weight_quantizer": {}, + "linear_fc2.input_quantizer": {}, + "linear_fc2.weight_quantizer": {}, + } + for name, module in model.named_modules(): + if not isinstance(module, mtq.nn.TensorQuantizer): + continue + if not hasattr(module, "_amax"): + continue + if "local_experts" not in name: + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + amax_dict[key][expert_name] = max(amax_dict[key].get(expert_name, 0), module.amax) + + for name, module in model.named_modules(): + if not isinstance(module, mtq.nn.TensorQuantizer): + continue + if not hasattr(module, "_amax"): + continue + if "local_experts" not in name: + continue + expert_name, local_expert_name = name.split("local_experts") + for key in amax_dict: + if key in local_expert_name: + module.amax = amax_dict[key][expert_name] + + +def copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model): + """Copy weights from grouped MoE model to non-grouped MoE model.""" + grouped_state = grouped_model.state_dict() + non_grouped_state = non_grouped_model.state_dict() + + # Map grouped weights to non-grouped weights + weight_mapping = {} + non_grouped_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight" + for key, value in grouped_state.items(): + if "experts.linear_fc" in key and "weight" in key: + # Extract expert index from grouped weight name + # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ + parts = key.split(".") + layer_idx = parts[2] # X + fc_idx = parts[5] # Y (linear_fc1 or linear_fc2) + weight_idx = parts[6] # Z (weight0, weight1, etc.) 
+ + # Map to non-grouped format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight + expert_idx = weight_idx.replace("weight", "") + non_grouped_key = non_grouped_key_template.format(layer_idx, expert_idx, fc_idx[-1]) + weight_mapping[non_grouped_key] = value + elif isinstance(value, torch.Tensor): + weight_mapping[key] = value + + # Copy weights to non-grouped model + for non_grouped_key in non_grouped_state: + if non_grouped_key in weight_mapping: + non_grouped_state[non_grouped_key] = weight_mapping[non_grouped_key].clone() + + non_grouped_model.load_state_dict(non_grouped_state) + + +def compare_amax_sync_across_expert_parallel(model): + """ + Test if amax values are synchronized across expert parallel groups. + + Returns True if synchronized, False otherwise. + """ + + ep_group = get_expert_model_parallel_group(check_initialized=False) + etp_group = get_expert_tensor_parallel_group(check_initialized=False) + + # Check if we have either expert model parallel or expert tensor parallel + has_expert_parallel = (ep_group is not None and ep_group.size() > 1) or ( + etp_group is not None and etp_group.size() > 1 + ) + + assert has_expert_parallel, "No expert parallelism detected" + # Collect amax values from expert quantizers only + expert_amax_values = {} + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"): + # Check for both grouped and non-grouped MoE patterns + if "local_experts" in name or ("experts" in name and "linear_fc" in name): + expert_amax_values[name] = ( + module.amax.item() if hasattr(module.amax, "item") else module.amax + ) + + # Early return if no expert quantizers found + assert expert_amax_values, "No expert quantizers found" + + # Gather amax values from all ranks + world_size = torch.distributed.get_world_size() + all_amax_values = [None] * world_size + torch.distributed.all_gather_object(all_amax_values, expert_amax_values) + + # Group quantizers by type (ignoring specific expert indices) and check sync + expert_quantizers = {} + for rank_idx, rank_amax in enumerate(all_amax_values): + for name, amax_val in rank_amax.items(): + # Create quantizer type key by normalizing the name + if "local_experts" in name: + # Non-grouped MoE: replace expert index with wildcard + import re + + quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name) + else: + # Grouped MoE: use the name as-is since experts are grouped + quantizer_type = name + + if quantizer_type not in expert_quantizers: + expert_quantizers[quantizer_type] = {} + expert_quantizers[quantizer_type][rank_idx] = amax_val + + # Check synchronization - fail fast on first inconsistency + for quantizer_type, rank_values in expert_quantizers.items(): + if len(rank_values) > 1: # Only check if we have multiple ranks + values = list(rank_values.values()) + max_diff = max(values) - min(values) + + if max_diff > 1e-6: # Allow for small floating point differences + return False + + return True + + +def disable_distributed_parallel_sync(model, expert_parallel_type: str = "tensor"): + """Disable distributed parallel synchronization groups.""" + module_parallel_groups = {} + + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.QuantModule): + # Store original groups + module_parallel_groups[name] = { + "data_parallel_group": module.parallel_state.data_parallel_group, + "expert_tensor_parallel_group": module.parallel_state.expert_tensor_parallel_group, + "expert_model_parallel_group": 
module.parallel_state.expert_model_parallel_group, + } + + # Disable groups + module.parallel_state.data_parallel_group = DistributedProcessGroup(-1) + + if expert_parallel_type in ["tensor", "both"]: + module.parallel_state.expert_tensor_parallel_group = DistributedProcessGroup(-1) + if expert_parallel_type in ["model", "both"]: + module.parallel_state.expert_model_parallel_group = DistributedProcessGroup(-1) + + return module_parallel_groups + + +def enable_distributed_parallel_sync( + model, module_parallel_groups, expert_parallel_type: str = "tensor" +): + """Re-enable distributed parallel synchronization groups.""" + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.QuantModule) and name in module_parallel_groups: + groups = module_parallel_groups[name] + + if expert_parallel_type in ["tensor", "both"]: + module.parallel_state.expert_tensor_parallel_group = groups[ + "expert_tensor_parallel_group" + ] + if expert_parallel_type in ["model", "both"]: + module.parallel_state.expert_model_parallel_group = groups[ + "expert_model_parallel_group" + ] diff --git a/tests/_test_utils/torch_quantization/quantize_common.py b/tests/_test_utils/torch_quantization/quantize_common.py index 505eac2b6..6dbb5b213 100644 --- a/tests/_test_utils/torch_quantization/quantize_common.py +++ b/tests/_test_utils/torch_quantization/quantize_common.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +from unittest.mock import patch import pytest import torch @@ -22,7 +23,9 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq +import modelopt.torch.quantization.model_calib as model_calib_module # needed for patching awq_lite from modelopt.torch.quantization.backends.gemm_registry import enable_real_quant_gemm +from modelopt.torch.quantization.nn.modules.tensor_quantizer import SequentialQuantizer from modelopt.torch.quantization.utils import is_quantized_linear from modelopt.torch.utils import torch_to @@ -116,8 +119,26 @@ def save_restore_test(model_cls, device, quant_config, compress=False, version=N mto.restore_from_modelopt_state(model_ref, state_dict) -def tensor_parallel_test_helper(model, config, tp_group, dp_group): - # The input to fist layer, the column parallel should be the same across all tp ranks +def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX, group=None): + quantizer_attr = getattr(quantizer, attr).clone() + print("quantizer.attr before reduce", getattr(quantizer, attr)) + dist.all_reduce(quantizer_attr, op=op, group=group) + print("quantizer.attr after reduce", getattr(quantizer, attr)) + print("quantizer_attr after reduce", quantizer_attr) + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)) + + +original_awq_lite = model_calib_module.awq_lite + + +def _debug_awq_lite(model, forward_loop, alpha_step=0.1, debug=True): + """Function to mock awq_lite function to always use debug=True for testing""" + return original_awq_lite(model, forward_loop, alpha_step, debug=True) + + +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def tensor_parallel_test_helper(model, config, tp_group, mock_awq_lite): + # The input to first layer, the column parallel should be the same across all tp ranks calib_data = model.get_dummy_input().cuda() dist.all_reduce(calib_data, op=dist.ReduceOp.AVG, group=tp_group) @@ -125,31 +146,133 @@ def forward_loop(model): model(calib_data) model = mtq.quantize(model, config, 
forward_loop) - # Sanity check forward_loop(model) if config in [mtq.INT8_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG, mtq.INT8_SMOOTHQUANT_CFG]: # Lets check the amax for row parallel input quantizer; it should be the same across all tp ranks - activation_amax = model.fc2.input_quantizer.amax.clone() - dist.all_reduce(activation_amax, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(activation_amax, model.fc2.input_quantizer.amax) - + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group) # Lets check the row parallel weight amax; it should be the same across all tp ranks - weight_amax = model.fc2.weight_quantizer.amax.clone() - dist.all_reduce(weight_amax, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(weight_amax, model.fc2.weight_quantizer.amax) + _reduce_quantizer_attr( + model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=tp_group + ) if config in [mtq.INT8_SMOOTHQUANT_CFG, mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: # Lets check the column parallel pre_quant_scale; it should be the same across all tp ranks input_quantizer = model.fc1.input_quantizer - pre_quant_scale = input_quantizer.pre_quant_scale.clone() - dist.all_reduce(pre_quant_scale, op=dist.ReduceOp.MAX, group=tp_group) - assert torch.allclose(pre_quant_scale, input_quantizer.pre_quant_scale) + _reduce_quantizer_attr( + input_quantizer, "pre_quant_scale", dist.ReduceOp.MAX, group=tp_group + ) + + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + # Check activation scale for AWQ lite + _reduce_quantizer_attr( + model.fc1.awq_lite, + "act_scale", + dist.ReduceOp.AVG, + group=tp_group, + ) dist.destroy_process_group() +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def dp_cp_parallel_test_helper(model, config, group, mock_awq_lite): + calib_data = model.get_dummy_input().cuda() + + def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + # Sanity check + forward_loop(model) + + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX, group=group) + + # Weight quantizer amax + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) + else: + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX, group=group) + else: + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX, group=group) + + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + # Check act scale + _reduce_quantizer_attr( + model.fc1.awq_lite, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) + _reduce_quantizer_attr( + model.fc2.awq_lite, + "act_scale", + dist.ReduceOp.AVG, + group=group, + ) + + +@patch("modelopt.torch.quantization.model_calib.awq_lite", side_effect=_debug_awq_lite) +def data_tensor_context_parallel_test_helper(model, config, dp_group, tp_group, mock_awq_lite): + # Calib data should be same across each DP rank + dp_rank = dist.get_rank(group=dp_group) + calib_data = model.get_dummy_input(seed=dp_rank).cuda() + + 
def forward_loop(model): + model(calib_data) + + model = mtq.quantize(model, config, forward_loop) + + def _reduce_quantizer_attr(quantizer, attr: str, op=dist.ReduceOp.MAX): + quantizer_attr = getattr(quantizer, attr).clone() + + # Perform all-reduce operations + dist.all_reduce(quantizer_attr, op=op, group=tp_group) + + dist.all_reduce(quantizer_attr, op=op, group=dp_group) + + assert torch.allclose(quantizer_attr, getattr(quantizer, attr)), getattr(quantizer, attr) + + # Input quantizer amax + if config not in [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT4_AWQ_CFG]: + _reduce_quantizer_attr(model.fc1.input_quantizer, "amax", dist.ReduceOp.MAX) + _reduce_quantizer_attr(model.fc2.input_quantizer, "amax", dist.ReduceOp.MAX) + + # Per-tensor quantization (FP8/NVFP4) expects same amax across row and column parallel ranks + # Channel-wise (INT8) only expects same amax across row parallel ranks + # Block-wise quantization does not expect same amax across row and column parallel ranks + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG]: + if isinstance(model.fc1.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc1.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc1.weight_quantizer, "amax", dist.ReduceOp.MAX) + + if config in [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG]: + if isinstance(model.fc2.weight_quantizer, SequentialQuantizer): + for quantizer in model.fc2.weight_quantizer: + _reduce_quantizer_attr(quantizer, "amax", dist.ReduceOp.MAX) + else: + _reduce_quantizer_attr(model.fc2.weight_quantizer, "amax", dist.ReduceOp.MAX) + + # Check act scale + if config in [mtq.INT4_AWQ_CFG, mtq.W4A8_AWQ_BETA_CFG]: + _reduce_quantizer_attr( + model.fc1.awq_lite, + "act_scale", + dist.ReduceOp.AVG, + ) + + def auto_quantize_helper(model): model, search_state = mtq.auto_quantize( model, diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index 208fb2287..d1ba9dd47 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -34,6 +34,18 @@ def need_2_gpus(): pytest.skip("Need at least 2 GPUs to run this test") +@pytest.fixture +def need_8_gpus(): + if torch.cuda.device_count() < 8: + pytest.skip("Need at least 8 GPUs to run this test") + + +@pytest.fixture +def need_4_gpus(): + if torch.cuda.device_count() < 4: + pytest.skip("Need at least 4 GPUs to run this test") + + @pytest.fixture(scope="module") def set_torch_dtype(request): orig_dtype = torch.get_default_dtype() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index c3630e028..309755dbd 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -21,16 +21,24 @@ from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job from _test_utils.torch_dist.plugins.megatron_common import ( MegatronModel, + compare_amax_sync_across_expert_parallel, + compare_model_outputs, + copy_weights_from_grouped_to_non_grouped, + disable_distributed_parallel_sync, + enable_distributed_parallel_sync, get_mcore_gpt_model, initialize_for_megatron, run_mcore_inference, sharded_state_dict_test_helper, + sync_amax, ) from _test_utils.torch_misc import set_seed from _test_utils.torch_quantization.models import RegularQuantModelForTP from _test_utils.torch_quantization.quant_utils import get_model_size from _test_utils.torch_quantization.quantize_common import (
auto_quantize_helper, + data_tensor_context_parallel_test_helper, + dp_cp_parallel_test_helper, tensor_parallel_test_helper, ) from packaging.version import Version @@ -44,6 +52,7 @@ get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP import modelopt import modelopt.torch.opt as mto @@ -92,13 +101,12 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1): destroy_model_parallel() +# 1. Tensor Parallel Test def _test_tensor_parallel_helper(config, rank, size): initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED) - model = MegatronModel(size).cuda() + model = MegatronModel(tp_size=size).cuda() - tensor_parallel_test_helper( - model, config, get_tensor_model_parallel_group(), get_data_parallel_group() - ) + tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group()) @pytest.mark.parametrize( @@ -119,38 +127,141 @@ def test_tensor_parallel(need_2_gpus, config): ) -def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): +# 2. Data Parallel Test +def _test_data_parallel_helper(config, rank, size): + initialize_for_megatron(seed=SEED + rank) # modify seed so data is different across ranks + model = MegatronModel().cuda() + + dp_cp_parallel_test_helper(model, config, get_data_parallel_group()) + + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_parallel(need_2_gpus, config): + spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl") + + +# 3. Context Parallel Test +def _test_context_parallel_helper(config, rank, size): + initialize_for_megatron( + context_parallel_size=size, seed=SEED + rank + ) # modify seed so data is different across ranks + model = MegatronModel(cp_size=size).cuda() + + dp_cp_parallel_test_helper(model, config, get_data_parallel_group(with_context_parallel=True)) + + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_context_parallel(need_2_gpus, config): + spawn_multiprocess_job( + size=2, job=partial(_test_context_parallel_helper, config), backend="nccl" + ) + + +# 4. 
DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs) +def _test_data_tensor_context_parallel_helper(config, rank, size): + initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED + rank) + model = MegatronModel(tp_size=2, cp_size=2).cuda() + + data_tensor_context_parallel_test_helper( + model, + config, + get_data_parallel_group(with_context_parallel=True), + get_tensor_model_parallel_group(), + ) + + +@pytest.mark.parametrize( + "config", + [ + mtq.INT8_DEFAULT_CFG, + mtq.FP8_DEFAULT_CFG, + mtq.W4A8_AWQ_BETA_CFG, + mtq.INT8_SMOOTHQUANT_CFG, + mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, + mtq.INT4_AWQ_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_data_tensor_context_parallel(need_8_gpus, config): + spawn_multiprocess_job( + size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl" + ) + + +def _gpt_model_provider( + tp_size: int, + hidden_size=256, + vocab_size=64, + num_moe_experts=None, + moe_grouped_gemm=False, + meta_device=False, + ep_size=1, + etp_size=None, + use_te=False, +): """Build the model.""" if meta_device: with torch.device("meta"): gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, - num_attention_heads=4, + num_attention_heads=8, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, vocab_size=vocab_size, use_cpu_initialization=meta_device, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ) else: gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, - num_attention_heads=4, + num_attention_heads=8, activation_func="squared_relu", transformer_impl="local", hidden_size=hidden_size, vocab_size=vocab_size, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ).cuda() return gpt_model.eval() def _test_sharded_state_dict( - tmp_path, config, hidden_size, modelopt_version, compress, meta_device, rank, size + tmp_path, config, hidden_size, modelopt_version, compress, meta_device, moe_config, rank, size ): # Must disable output_layer quantization since output_layer amax cannot be restore via # sharded_state_dict. All output_layer quantizers state are removed. 
@@ -160,10 +271,42 @@ def _test_sharded_state_dict( mto.conversion.__version__ = modelopt_version mtq.plugins.megatron.__version__ = modelopt_version - initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED) + tp_size = moe_config.get("tp_size", size) + ep_size = moe_config.get("ep_size", 1) + etp_size = moe_config.get("etp_size", None) + num_moe_experts = moe_config.get("num_moe_experts", None) + moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False) + use_te = moe_config.get("use_te", False) + + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + seed=SEED, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + ) - model_ref = _gpt_model_provider(size, hidden_size, vocab_size=256) - model_test = _gpt_model_provider(size, hidden_size, vocab_size=256, meta_device=meta_device) + model_ref = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + meta_device=meta_device, + ep_size=ep_size, + etp_size=etp_size, + ) + model_test = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + meta_device=meta_device, + ep_size=ep_size, + etp_size=etp_size, + ) prompt_tokens = torch.randint( 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) @@ -244,7 +387,9 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) spawn_multiprocess_job( size=size, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device), + job=partial( + _test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device, {} + ), backend="nccl", ) @@ -263,7 +408,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config): spawn_multiprocess_job( size=2, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False), + job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False, {}), backend="nccl", ) @@ -284,7 +429,7 @@ def test_sharded_state_dict_old_checkpoints(need_2_gpus, tmp_path, config, model spawn_multiprocess_job( size=2, job=partial( - _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False + _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False, {} ), backend="nccl", ) @@ -367,3 +512,221 @@ def forward_fn(model): def test_fp8_real_quantize(): size = torch.cuda.device_count() spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl") + + +@pytest.mark.parametrize( + "config", + [ + mtq.FP8_DEFAULT_CFG, + mtq.NVFP4_DEFAULT_CFG, + ], +) +def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config): + size = torch.cuda.device_count() + # TODO: Meta device doesn't work with TE + # TODO: Add support for compress=True for TEGroupedMLP + moe_config = { + "tp_size": 2, + "ep_size": 2, + "etp_size": 2, + "num_moe_experts": 4, + "moe_grouped_gemm": True, + "use_te": True, + } + spawn_multiprocess_job( + size=size, + job=partial( + _test_sharded_state_dict, + tmp_path, + config, + 256, + None, + False, + False, + moe_config, + ), + backend="nccl", + ) + + +def _test_grouped_vs_non_grouped_quantize_helper(tp_size, ep_size, etp_size, rank, size): + """Test that grouped and non-grouped MoE models produce similar amax values.""" + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + 
expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=SEED, + ) + + # Create input + prompt_tokens = torch.randint(0, 64, (2, 16)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Create grouped MoE model + grouped_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=True, + use_te=True, + num_moe_experts=4, + ) + num_grouped_mlp = sum(isinstance(module, TEGroupedMLP) for module in grouped_model.modules()) + assert num_grouped_mlp == 4, ( + f"Grouped MoE model has {num_grouped_mlp} TEGroupedMLP modules; it should have 4" + ) + + # Create non-grouped MoE model + non_grouped_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=False, + num_moe_experts=4, + ) + num_sequential_mlp = sum( + isinstance(module, SequentialMLP) for module in non_grouped_model.modules() + ) + assert num_sequential_mlp == 4, ( + f"Non-grouped MoE model has {num_sequential_mlp} SequentialMLP modules; it should have 4" + ) + # Copy weights from grouped to non-grouped model + copy_weights_from_grouped_to_non_grouped(grouped_model, non_grouped_model) + + output_comparison_before = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) + assert output_comparison_before, "Outputs are not close before quantization" + + # Quantize grouped model + mtq.quantize(grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # Quantize non-grouped model + mtq.quantize(non_grouped_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # sync amax across expert parallel + # TODO: Remove once amax sync is enabled by default for SequentialMLP + sync_amax(non_grouped_model) + + # Compare model outputs after quantization + output_comparison_after = compare_model_outputs(grouped_model, non_grouped_model, forward_fn) + assert output_comparison_after, "Outputs are not close after quantization" + + +def test_grouped_vs_non_grouped_quantize(): + """Test that grouped and non-grouped MoE models produce similar quantized models.""" + import time + + size = torch.cuda.device_count() + if size < 4: + pytest.skip("Requires at least 4 GPUs for expert parallel test") + + # Add small delay to avoid port conflicts + time.sleep(0.1) + + spawn_multiprocess_job( + size=size, + job=partial(_test_grouped_vs_non_grouped_quantize_helper, 1, 2, 2), + backend="nccl", + ) + + +def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, rank, size): + """Test expert parallel synchronization with different configurations.""" + initialize_for_megatron( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=SEED, + ) + + # Create model with expert parallelism + model = _gpt_model_provider( + tp_size=1, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=256, + moe_grouped_gemm=moe_grouped_gemm, + use_te=moe_grouped_gemm, + num_moe_experts=4, + ) + + # Create input and forward function + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Run forward pass and quantize + forward_fn(model) + config = mtq.FP8_DEFAULT_CFG + model = mtq.quantize(model, config, forward_fn) + + # Check initial sync status + initial_sync = compare_amax_sync_across_expert_parallel(model) + assert initial_sync, ( + "Inconsistent amax across expert parallel
ranks, Amax should be synchronized across expert parallel ranks" + ) + + # Create inconsistent amax values + cur_rank = torch.distributed.get_rank() + for name, module in model.named_modules(): + if isinstance(module, mtq.nn.TensorQuantizer): + # Check if this is an expert quantizer + is_expert_quantizer = ( + "local_experts" in name # Non-grouped MoE + or ("experts" in name and "linear_fc" in name) # Grouped MoE + ) + + if is_expert_quantizer and hasattr(module, "_amax"): + # Create rank-specific amax values to simulate missing sync + rank_offset = cur_rank * 0.1 + module.amax = module.amax + rank_offset + + # Determine expert parallel type + expert_parallel_type = ( + "both" if ep_size > 1 and etp_size > 1 else ("model" if ep_size > 1 else "tensor") + ) + + # Disable parallel groups and test inconsistency + module_parallel_groups = disable_distributed_parallel_sync(model, expert_parallel_type) + mtq.model_calib.max_calibrate(model, forward_fn) + + inconsistent_sync = compare_amax_sync_across_expert_parallel(model) + assert not inconsistent_sync, ( + "Consistent amax across expert parallel ranks, " + "Amax should not be synchronized across expert parallel ranks since expert parallel is disabled" + ) + + # Re-enable parallel groups and test synchronization + enable_distributed_parallel_sync(model, module_parallel_groups, expert_parallel_type) + mtq.model_calib.max_calibrate(model, forward_fn) + + final_sync = compare_amax_sync_across_expert_parallel(model) + assert final_sync, ( + "Inconsistent amax across expert parallel ranks, Amax should be synchronized across expert parallel ranks" + ) + + +@pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)]) +@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) +def test_expert_parallel_sync(need_4_gpus, ep_size, etp_size, moe_grouped_gemm): + """Test expert model parallel synchronization.""" + import time + + size = torch.cuda.device_count() + total_size = ep_size * etp_size + if size < total_size: + pytest.skip(f"Requires at least {total_size} GPUs for expert model parallel test") + + # Add small delay to avoid port conflicts + time.sleep(0.1) + + spawn_multiprocess_job( + size=total_size, + job=partial(_test_expert_model_parallel_amax_sync, ep_size, etp_size, moe_grouped_gemm), + backend="nccl", + )
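As a usage note, the new multi-GPU tests added above could be exercised roughly as follows (a sketch; it assumes a node with at least 8 GPUs and that `pytest` is run from the repository root — the `-k` selection expression is an assumption, while the file path and GPU requirements come from the fixtures in this diff):

```bash
# Run the expert-parallel amax sync, grouped-vs-non-grouped MoE, and DP+TP+CP quantization tests
pytest tests/gpu/torch/quantization/plugins/test_megatron.py \
  -k "expert_parallel_sync or grouped_vs_non_grouped or data_tensor_context_parallel"
```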