diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index f987efcd6..c1e6feb06 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -80,21 +80,22 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis if not distributed_sync: return - def sync_quantizer_amax_across_dp(quantizer, parallel_state): - """Synchronize the amax across all ranks in the data parallel group.""" + def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state): + """Synchronize the amax across all ranks in the data parallel and expert parallel groups.""" if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: - sync_quantizer_amax_across_dp(_q, parallel_state) + sync_quantizer_amax_across_dp_ep(_q, parallel_state) return if getattr(quantizer, "_amax", None) is not None: quantizer.sync_amax_across_distributed_group(parallel_state.data_parallel_group) + quantizer.sync_amax_across_distributed_group(parallel_state.expert_model_parallel_group) # TODO: create sync_bias_across_distributed_group for name, module in model.named_modules(): if isinstance(module, QuantModule): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): - sync_quantizer_amax_across_dp(child, module.parallel_state) + sync_quantizer_amax_across_dp_ep(child, module.parallel_state) # TP sync: # Objective: the quantization parameters when TP = 8 then changed to TP=4 then back to TP=8 should be the same @@ -117,6 +118,7 @@ def sync_quantizer_amax_across_tp( # Syncing amax across TP for sequential quantizer if isinstance(quantizer, SequentialQuantizer): for _q in quantizer: + # Syncing amax across TP for sequential quantizer sync_quantizer_amax_across_tp( _q, linear_name, quantizer_type, axes_for_sync, parallel_state ) @@ -174,6 +176,10 @@ def sync_quantizer_amax_across_tp( parallel_state=module.parallel_state, ) + for name, module in model.named_modules(): + if hasattr(module, "sync_moe_local_experts_amax"): + module.sync_moe_local_experts_amax() + def enable_stats_collection(model: nn.Module): """Enable stats collection for all quantizers in the model.""" diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 85784d2fe..fd6b0660d 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,6 +22,7 @@ import megatron.core.parallel_state as mcore_parallel import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp +import megatron.core.transformer.moe.experts as megatron_moe import torch from megatron.core.parallel_state import get_data_parallel_group from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region @@ -40,6 +41,18 @@ from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TERowParallelGroupedLinear, + ) + + from .transformer_engine import _QuantTEGroupedLinear + + HAS_TE = True +except ImportError: + HAS_TE = False + logger = logging.getLogger(__name__) __all__ = [] @@ -221,16 +234,19 @@ class _MegatronParallelLinear(_ParallelLinear): ] def _setup(self): - data_parallel_group = None - try: - data_parallel_group = get_data_parallel_group(with_context_parallel=True) - except AssertionError: - logger.warning("Context parallel 
group is not initialized, using data parallel group")
-            data_parallel_group = get_data_parallel_group()
-        self.parallel_state = ParallelState(
-            data_parallel_group,
-            mcore_parallel.get_tensor_model_parallel_group(),
-        )
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            data_parallel_group = None
+            try:
+                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+            except AssertionError:
+                logger.warning(
+                    "Context parallel group is not initialized, using data parallel group"
+                )
+                data_parallel_group = get_data_parallel_group()
+            self.parallel_state = ParallelState(
+                data_parallel_group,
+                mcore_parallel.get_tensor_model_parallel_group(),
+            )
         super()._setup()
 
     def _process_quantizer_amax(self, k, v, quantizer_state_dict):
@@ -472,3 +488,95 @@ class _RealQuantMegatronRowParallelLinear(
 
     def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
+
+
+@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
+class _MegatronSequentialMLP(_MegatronMLP):
+    def _setup(self):
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            self.parallel_state = ParallelState(
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+            )
+
+        # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
+        for expert in self.local_experts:
+            expert.linear_fc1.parallel_state = self.parallel_state
+            expert.linear_fc2.parallel_state = self.parallel_state
+
+    def sync_moe_local_experts_amax(self):
+        """Sync amax across local experts in a SequentialMLP.
+
+        The amax values across EP and ETP (for RowParallel) are synchronized as part of
+        model_calib.max_calibrate(). This method additionally synchronizes amax across the
+        local experts so that all local experts share the same amax.
+        """
+        torch.distributed.barrier()
+        # Collect amax from all local experts
+        amax_dict = {}
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    stored_amax = amax_dict.get(name)
+                    amax_tensor = module.amax.detach().clone()
+                    amax_dict[name] = (
+                        amax_tensor
+                        if stored_amax is None
+                        else torch.maximum(stored_amax, amax_tensor)
+                    )
+
+        # Apply synchronized amax values back to all local experts
+        for expert in self.local_experts:
+            for name, module in expert.named_modules():
+                if isinstance(module, TensorQuantizer) and module.amax is not None:
+                    module.amax = amax_dict[name].detach().clone().to(module.amax.device)
+
+
+if HAS_TE:
+    # Quantized subclasses to support TEGroupedMLP quantization
+    class _QuantMegatronTEGroupedLinear(_QuantTEGroupedLinear, _MegatronParallelLinear):
+        def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+            # _sharded_state_dict_grouped adds _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] in
+            # sharded_state_dict, which is the same as _extra_state. 
The _extra_state{gemm_idx} is used for + # TE Fp8 checkpoint, we need to remove the _extra_state{gemm_idx} for gemm_idx:[1, num_gemms] + # for modelopt checkpoint restore + filtered_state_dict = { + k: v + for k, v in state_dict.items() + if not any(k.endswith(f"_extra_state{num}") for num in range(1, self.num_gemms)) + } + return super()._load_from_state_dict(filtered_state_dict, prefix, *args, **kwargs) + + def _process_quantizer_amax(self, k, v, quantizer_state_dict): + assert v.numel() == 1, "TEGroupedLinear only supports per-tensor quantization" + quantizer_state_dict[k] = v.view(-1) + + @QuantModuleRegistry.register( + {TEColumnParallelGroupedLinear: "megatron_TEColumnParallelGroupedLinear"} + ) + class _MegatronTEGroupedColumnParallelLinear( + _QuantMegatronTEGroupedLinear, _MegatronColumnParallelLinear + ): + pass + + @QuantModuleRegistry.register( + {TERowParallelGroupedLinear: "megatron_TERowParallelGroupedLinear"} + ) + class _MegatronTEGroupedRowParallelLinear( + _QuantMegatronTEGroupedLinear, _MegatronRowParallelLinear + ): + pass + + @QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"}) + class _MegatronTEGroupedMLP(_MegatronMLP): + def _setup(self): + if not hasattr(self, "parallel_state") or self.parallel_state is None: + self.parallel_state = ParallelState( + mcore_parallel.get_expert_data_parallel_group(), + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), + ) + # initialize parallel state for submodules linear_fc1 and linear_fc2 + self.linear_fc1.parallel_state = self.parallel_state + self.linear_fc2.parallel_state = self.parallel_state diff --git a/modelopt/torch/quantization/plugins/transformer_engine.py b/modelopt/torch/quantization/plugins/transformer_engine.py index b068ebca7..5199bbf34 100644 --- a/modelopt/torch/quantization/plugins/transformer_engine.py +++ b/modelopt/torch/quantization/plugins/transformer_engine.py @@ -17,6 +17,7 @@ import torch import transformer_engine as te +import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear import transformer_engine.pytorch.module.linear as te_linear from ..nn import QuantModuleRegistry @@ -58,3 +59,60 @@ def te_quantized_linear_fn(package, func_name, self, *args, **kwargs): # Override the quantized linear function _quantized_linear_fn = te_quantized_linear_fn + + +# Register the public te.pytorch.GroupedLinear class +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) +class _QuantTEGroupedLinear(_ParallelLinear): + _functionals_to_replace = [ + (te_grouped_linear._GroupedLinear, "forward"), + (te_grouped_linear._GroupedLinear, "apply"), + ] + + def _setup(self): + # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to + # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning + # self.weight0 to self.weight to run the quantizer states initialization. + assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear" + self.weight = self.weight0 + # Memorize the original weight.dtype for modelopt_post_restore given that + # the dtype can change later. + super()._setup() + # Remove self.weight after setup. + delattr(self, "weight") + + def modelopt_post_restore(self, prefix: str = ""): + # GroupedMLP stores the weights as weight0, weight1, etc. 
To run post_restore in order to + # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning + # self.weight0 to self.weight to run the quantizer states initialization. + assert not hasattr(self, "weight"), "self.weight should not exist for TEGroupedLinear" + self.weight = self.weight0 + super().modelopt_post_restore(prefix=prefix) + # Remove self.weight after post_restore. + delattr(self, "weight") + + @staticmethod + def te_grouped_quantized_linear_fn(package, func_name, self, *args): + idx = 1 if func_name == "_forward" else 0 + inp = args[idx] + num_gemms = len(args[idx + 1]) + weights_and_biases = args[-2 * num_gemms :] + weights, biases = weights_and_biases[:num_gemms], weights_and_biases[num_gemms:] + quantized_inputs = self.input_quantizer(inp) + quantized_weights = [self.weight_quantizer(weight) for weight in weights] + + output = getattr(package, func_name)( + *( + args[0], + quantized_inputs, + ) + if func_name == "_forward" + else (quantized_inputs,), + *args[idx + 1 : -2 * num_gemms], + *quantized_weights, + *biases, + ) + return self.output_quantizer(output) + + # Override the quantized linear function + _quantized_linear_fn = te_grouped_quantized_linear_fn diff --git a/modelopt/torch/quantization/utils.py b/modelopt/torch/quantization/utils.py index 6167daf23..43e269fa1 100644 --- a/modelopt/torch/quantization/utils.py +++ b/modelopt/torch/quantization/utils.py @@ -251,8 +251,11 @@ def is_quantized_linear(module): isinstance(module, QuantModule) and isinstance(getattr(module, "input_quantizer", None), TensorQuantizer) and hasattr(module, "weight_quantizer") - and getattr(module, "weight", None) is not None - and module.weight.dim() == 2 + and ( + (getattr(module, "weight", None) is not None and module.weight.dim() == 2) + # module.weight0 check is required to support TEGroupedLinear + or (getattr(module, "weight0", None) is not None and module.weight0.dim() == 2) + ) ) diff --git a/modelopt/torch/utils/distributed.py b/modelopt/torch/utils/distributed.py index f11a736db..bcebd0492 100644 --- a/modelopt/torch/utils/distributed.py +++ b/modelopt/torch/utils/distributed.py @@ -241,16 +241,20 @@ def __init__( self, data_parallel_group: torch.distributed.ProcessGroup | int | None = None, tensor_parallel_group: torch.distributed.ProcessGroup | int | None = -1, + expert_model_parallel_group: torch.distributed.ProcessGroup | int | None = -1, ): """Initialize the parallel state.""" self.data_parallel_group = DistributedProcessGroup(data_parallel_group) self.tensor_parallel_group = DistributedProcessGroup(tensor_parallel_group) + self.expert_model_parallel_group = DistributedProcessGroup(expert_model_parallel_group) def __repr__(self) -> str: - return ( + parallel_groups = ( f"data_parallel_group: {self.data_parallel_group}, " f"tensor_parallel_group: {self.tensor_parallel_group}, " + f"expert_model_parallel_group: {self.expert_model_parallel_group}" ) + return parallel_groups def get_group(ranks: list[int]): diff --git a/tests/_test_utils/torch_dist/plugins/megatron_common.py b/tests/_test_utils/torch_dist/plugins/megatron_common.py index ca6b9bff7..c913bd7d2 100644 --- a/tests/_test_utils/torch_dist/plugins/megatron_common.py +++ b/tests/_test_utils/torch_dist/plugins/megatron_common.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import copy +import re +from collections import defaultdict from warnings import warn import torch @@ -38,6 +40,9 @@ ) from megatron.core.models.mamba import MambaModel from megatron.core.parallel_state import ( + get_expert_model_parallel_group, + get_expert_tensor_parallel_group, + get_expert_tensor_parallel_rank, initialize_model_parallel, is_pipeline_first_stage, is_pipeline_last_stage, @@ -49,6 +54,7 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.transformer_config import TransformerConfig +import modelopt.torch.quantization as mtq from modelopt.torch.export.unified_export_megatron import import_mcore_gpt_from_hf from modelopt.torch.opt.plugins.mcore_dist_checkpointing import ( restore_sharded_modelopt_state, @@ -143,6 +149,8 @@ def get_dummy_input(self, seed: int | None = None) -> torch.Tensor: def get_mcore_gpt_model( tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + expert_tensor_parallel_size: int | None = None, initialize_megatron: bool = False, *, num_layers: int = 2, @@ -158,7 +166,10 @@ def get_mcore_gpt_model( normalization: str = "LayerNorm", transformer_impl: str = "modelopt" if HAS_TE else "local", use_cpu_initialization: bool = False, + num_moe_experts: int | None = None, + moe_grouped_gemm: bool = False, bf16: bool = True, + use_te: bool = False, ) -> GPTModel: assert activation_func in ["swiglu", "squared_relu"] assert normalization in ["LayerNorm", "RMSNorm"] @@ -166,7 +177,12 @@ def get_mcore_gpt_model( print(f"Using `{transformer_impl=}` model spec for building GPT Model.") if initialize_megatron: - initialize_for_megatron(tensor_model_parallel_size, pipeline_model_parallel_size) + initialize_for_megatron( + tensor_model_parallel_size, + pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + ) def squared_relu(x): return torch.pow(F.relu(x), 2) @@ -174,7 +190,10 @@ def squared_relu(x): config = TransformerConfig( tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, sequence_parallel=False, + moe_grouped_gemm=moe_grouped_gemm, num_layers=num_layers, num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage, num_layers_in_last_pipeline_stage=num_layers_in_last_pipeline_stage, @@ -182,6 +201,7 @@ def squared_relu(x): num_attention_heads=num_attention_heads, num_query_groups=num_query_groups, ffn_hidden_size=ffn_hidden_size, + num_moe_experts=num_moe_experts, activation_func=squared_relu if activation_func == "squared_relu" else F.silu, normalization=normalization, gated_linear_unit=(activation_func == "swiglu"), @@ -193,11 +213,22 @@ def squared_relu(x): if transformer_impl == "local": assert HAS_APEX, "Apex not installed" - transformer_layer_spec = get_gpt_layer_local_spec(normalization=normalization) + transformer_layer_spec = get_gpt_layer_local_spec( + num_experts=num_moe_experts, + normalization=normalization, + moe_grouped_gemm=moe_grouped_gemm, + # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM + # use_te=use_te, + ) else: assert HAS_TE, "Transformer Engine not installed" transformer_layer_spec = ( - get_gpt_modelopt_spec(config, remap_te_layernorm=True) + get_gpt_modelopt_spec( + config, + remap_te_layernorm=True, + # TODO: uncomment this when TEGroupedMLP 
is enabled in Megatron-LM + # moe_grouped_gemm=moe_grouped_gemm + ) if transformer_impl == "modelopt" else get_gpt_layer_with_transformer_engine_spec() ) @@ -212,6 +243,7 @@ def squared_relu(x): share_embeddings_and_output_weights=False, position_embedding_type="rope", ) + if bf16: model = model.to(torch.bfloat16) @@ -403,6 +435,8 @@ def initialize_for_megatron( pipeline_model_parallel_size=1, seed=1234, context_parallel_size=1, + expert_model_parallel_size=1, + expert_tensor_parallel_size=None, ): """Initialize Megatron model parallelism. @@ -412,6 +446,8 @@ def initialize_for_megatron( tensor_model_parallel_size, pipeline_model_parallel_size, context_parallel_size=context_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, ) model_parallel_cuda_manual_seed(seed) @@ -478,3 +514,149 @@ def convert_maybe_fp8(v): assert torch.allclose(logits_ref, logits_test), ( f"diff: {logits_diff.max()} ref: {logits_ref}, test: {logits_test}" ) + + +def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model): + """Copy weights from TEGrouped MoE model to sequential MoE model.""" + te_grouped_state = te_grouped_moe_model.state_dict() + sequential_state = sequential_moe_model.state_dict() + + # Map grouped weights to sequential weights + weight_mapping = {} + sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}" + for key, value in te_grouped_state.items(): + if "experts.linear_fc" in key and any(param in key for param in ("weight", "bias")): + # Extract expert index from grouped weight name + # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ + parts = key.split(".") + layer_idx = parts[2] # X + fc_idx = parts[5] # Y (linear_fc1 or linear_fc2) + param_idx = parts[6] # weight0 / bias0 / etc. + match = re.search(r"\d+", param_idx) + expert_idx = match.group(0) if match else "0" # Z for expert index + # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ + sequential_key = sequential_key_template.format(layer_idx, expert_idx, fc_idx[-1]) + param_name = "weight" if "weight" in param_idx else "bias" + weight_mapping[f"{sequential_key}.{param_name}"] = value + elif isinstance(value, torch.Tensor): + weight_mapping[key] = value + + # Copy weights to sequential model + for sequential_key in sequential_state: + if sequential_key in weight_mapping: + sequential_state[sequential_key] = weight_mapping[sequential_key].clone() + + sequential_moe_model.load_state_dict(sequential_state) + + +def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True): + """ + Test if amax values are synchronized across expert parallel groups. + + Returns True if synchronized, False otherwise. 
+    """
+
+    ep_group = get_expert_model_parallel_group(check_initialized=False)
+    etp_group = get_expert_tensor_parallel_group(check_initialized=False)
+
+    # Check if we have either expert model parallel or expert tensor parallel
+    has_expert_parallel = (ep_group is not None and ep_group.size() > 1) or (
+        etp_group is not None and etp_group.size() > 1
+    )
+
+    assert has_expert_parallel, "No expert parallelism detected"
+    # Collect amax values from expert quantizers only
+    expert_amax_values = {}
+    for name, module in model.named_modules():
+        if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"):
+            # Check for both TEGrouped and sequential MoE patterns
+            if "local_experts" in name or ("experts" in name and "linear_fc" in name):
+                # Store a detached CPU copy of this quantizer's amax tensor
+                expert_amax_values[name] = module.amax.detach().clone().cpu()
+
+    # Fail fast if no expert quantizers were found
+    assert expert_amax_values, "No expert quantizers found"
+
+    # Gather amax values from all ranks
+    world_size = torch.distributed.get_world_size()
+    all_amax_values = [None] * world_size
+    torch.distributed.all_gather_object(all_amax_values, expert_amax_values)
+
+    # Group quantizers by type (ignoring specific expert indices) and check sync
+    expert_quantizers = defaultdict(dict)
+    for rank_idx, rank_amax in enumerate(all_amax_values):
+        for name, amax_val in rank_amax.items():
+            # Create quantizer type key by normalizing the name
+            quantizer_type = (
+                re.sub(r"local_experts\.\d+", "local_experts.*", name)
+                if "local_experts" in name
+                else name
+            )
+
+            if (
+                quantizer_type in expert_quantizers
+                and rank_idx in expert_quantizers[quantizer_type]
+            ):
+                if compare_across_experts:
+                    # Compare amax values across experts on the same rank (sequential MoE)
+                    prev_val = expert_quantizers[quantizer_type][rank_idx]
+                    # Handle both scalar and tensor comparisons
+                    if isinstance(amax_val, torch.Tensor) and isinstance(prev_val, torch.Tensor):
+                        are_equal = torch.allclose(prev_val, amax_val, rtol=1e-6, atol=1e-6)
+                    else:
+                        are_equal = prev_val == amax_val
+                    assert are_equal, (
+                        f"{rank_idx}, {quantizer_type}, expert_quantizers[quantizer_type][rank_idx]: "
+                        f"{expert_quantizers[quantizer_type][rank_idx]}, amax_val: {amax_val}"
+                    )
+            expert_quantizers[quantizer_type][rank_idx] = amax_val
+
+    rank_info = {
+        "global_rank": torch.distributed.get_rank(),
+        "etp_rank": get_expert_tensor_parallel_rank(),
+    }
+
+    all_rank_info = [None] * world_size
+    torch.distributed.all_gather_object(all_rank_info, rank_info)
+
+    # Group ranks by ETP rank for fc1 (ColumnParallel: same output channels should match)
+    etp_groups = defaultdict(list)
+    for info in all_rank_info:
+        etp_groups[info["etp_rank"] if info["etp_rank"] else 0].append(info["global_rank"])
+
+    for quantizer_type, rank_values in expert_quantizers.items():
+        # Determine which ranks are expected to share the same amax
+        # for this quantizer type:
+        #
+        # fc1: ColumnParallel: X @ [A_1, A_2] (weights split along Cout)
+        # so amax should be the same across ranks with the same ETP rank
+        # if EP is 2, ETP is 2, we have 4 ranks, EP1, ETP1: 0, EP1, ETP2: 1, EP2, ETP1: 2, EP2, ETP2: 3
+        # so we need to compare amax across same ETP rank [0, 2] [1, 3] for per-channel quantization
+        #
+        # fc2: RowParallel: [X_1, X_2] @ [A_1
+        #                                 A_2] (weights split along Cin)
+        # amax should be the same across all ranks
+        rank_groups = (
+            list(etp_groups.values())
+            if "linear_fc1" in quantizer_type and (next(iter(rank_values.values()))).ndim > 0
+            else [list(range(world_size))]
+        )
+
+        # Check each 
group independently + for group in rank_groups: + group_values = [rank_values[r] for r in group if r in rank_values] + if len(group_values) > 1: + # All values in this group should be identical + first_val = group_values[0] + for val in group_values[1:]: + if isinstance(first_val, torch.Tensor): + if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6): + group_rank_values = { + r: rank_values[r] for r in group if r in rank_values + } + return False, f"{quantizer_type} (group {group})", group_rank_values + elif abs(first_val - val) > 1e-6: + group_rank_values = {r: rank_values[r] for r in group if r in rank_values} + return False, f"{quantizer_type} (group {group})", group_rank_values + + return True, None, None diff --git a/tests/gpu/torch/conftest.py b/tests/gpu/torch/conftest.py index f32065bce..d1ba9dd47 100644 --- a/tests/gpu/torch/conftest.py +++ b/tests/gpu/torch/conftest.py @@ -40,6 +40,12 @@ def need_8_gpus(): pytest.skip("Need at least 8 GPUs to run this test") +@pytest.fixture +def need_4_gpus(): + if torch.cuda.device_count() < 4: + pytest.skip("Need at least 4 GPUs to run this test") + + @pytest.fixture(scope="module") def set_torch_dtype(request): orig_dtype = torch.get_default_dtype() diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index b63462ef3..d67ff44d0 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -21,6 +21,8 @@ from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job from _test_utils.torch_dist.plugins.megatron_common import ( MegatronModel, + compare_amax_sync_across_expert_parallel, + copy_weights_from_grouped_to_non_grouped, get_mcore_gpt_model, initialize_for_megatron, run_mcore_inference, @@ -42,6 +44,8 @@ get_tensor_model_parallel_group, ) from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.moe.experts import SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.router import TopKRouter import modelopt import modelopt.torch.opt as mto @@ -227,38 +231,59 @@ def test_data_tensor_context_parallel(need_8_gpus, config): ) -def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False): +def _gpt_model_provider( + tp_size: int, + hidden_size=256, + vocab_size=64, + num_moe_experts=None, + moe_grouped_gemm=False, + meta_device=False, + ep_size=1, + etp_size=None, + use_te=False, + transformer_impl="local", +): """Build the model.""" if meta_device: with torch.device("meta"): gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=8, activation_func="squared_relu", - transformer_impl="local", + transformer_impl=transformer_impl, hidden_size=hidden_size, vocab_size=vocab_size, use_cpu_initialization=meta_device, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ) else: gpt_model = get_mcore_gpt_model( tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, num_layers=4, ffn_hidden_size=None, num_attention_heads=8, activation_func="squared_relu", - transformer_impl="local", + transformer_impl=transformer_impl, hidden_size=hidden_size, vocab_size=vocab_size, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, ).cuda() 
return gpt_model.eval() def _test_sharded_state_dict( - tmp_path, config, hidden_size, modelopt_version, compress, meta_device, rank, size + tmp_path, config, hidden_size, modelopt_version, compress, meta_device, moe_config, rank, size ): # Must disable output_layer quantization since output_layer amax cannot be restore via # sharded_state_dict. All output_layer quantizers state are removed. @@ -268,10 +293,44 @@ def _test_sharded_state_dict( mto.conversion.__version__ = modelopt_version mtq.plugins.megatron.__version__ = modelopt_version - initialize_for_megatron(tensor_model_parallel_size=size, seed=SEED) + tp_size = moe_config.get("tp_size", size) + ep_size = moe_config.get("ep_size", 1) + etp_size = moe_config.get("etp_size", None) + num_moe_experts = moe_config.get("num_moe_experts", None) + moe_grouped_gemm = moe_config.get("moe_grouped_gemm", False) + use_te = moe_config.get("use_te", False) + transformer_impl = moe_config.get("transformer_impl", "local") - model_ref = _gpt_model_provider(size, hidden_size, vocab_size=256) - model_test = _gpt_model_provider(size, hidden_size, vocab_size=256, meta_device=meta_device) + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + seed=SEED, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + ) + + model_ref = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + ep_size=ep_size, + etp_size=etp_size, + transformer_impl=transformer_impl, + ) + model_test = _gpt_model_provider( + tp_size, + hidden_size, + vocab_size=256, + num_moe_experts=num_moe_experts, + moe_grouped_gemm=moe_grouped_gemm, + use_te=use_te, + meta_device=meta_device, + ep_size=ep_size, + etp_size=etp_size, + transformer_impl=transformer_impl, + ) prompt_tokens = torch.randint( 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) @@ -352,7 +411,9 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) spawn_multiprocess_job( size=size, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device), + job=partial( + _test_sharded_state_dict, tmp_path, config, 256, None, compress, meta_device, {} + ), backend="nccl", ) @@ -367,7 +428,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device) def test_heterogenous_sharded_state_dict(need_2_gpus, tmp_path, config): spawn_multiprocess_job( size=2, - job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False), + job=partial(_test_sharded_state_dict, tmp_path, config, 256, None, False, False, {}), backend="nccl", ) @@ -388,7 +449,7 @@ def test_sharded_state_dict_old_checkpoints(need_2_gpus, tmp_path, config, model spawn_multiprocess_job( size=2, job=partial( - _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False + _test_sharded_state_dict, tmp_path, config, 256, modelopt_version, False, False, {} ), backend="nccl", ) @@ -471,3 +532,206 @@ def forward_fn(model): def test_fp8_real_quantize(): size = torch.cuda.device_count() spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl") + + +@pytest.mark.parametrize( + "config", + [mtq.FP8_DEFAULT_CFG, mtq.NVFP4_DEFAULT_CFG, mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG], +) +@pytest.mark.parametrize("moe_grouped_gemm", [True, False]) +def test_moe_sharded_state_dict(need_4_gpus, tmp_path, config, moe_grouped_gemm): + if moe_grouped_gemm: + pytest.skip("TEGroupedMLP is not enabled in Megatron-LM 
currently") + size = torch.cuda.device_count() + # TODO: Add support for compress=True for TEGroupedMLP + moe_config = { + "tp_size": 2, + "ep_size": 2, + "etp_size": 2, + "num_moe_experts": 4, + "moe_grouped_gemm": moe_grouped_gemm, + "use_te": moe_grouped_gemm, + "transformer_impl": "modelopt", + } + spawn_multiprocess_job( + size=size, + job=partial( + _test_sharded_state_dict, + tmp_path, + config, + 256, + None, + False, + False, + moe_config, + ), + backend="nccl", + ) + + +def _test_te_grouped_vs_sequential_quantize_helper(tp_size, ep_size, etp_size, rank, size): + """Test that TEGrouped and sequential MoE models produce similar amax values.""" + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, + seed=SEED, + ) + + # Create input + prompt_tokens = torch.randint(0, 64, (2, 16)).cuda() + + def forward_fn(model): + return megatron_prefill(model, prompt_tokens) + + # Create TEGrouped MoE model + te_grouped_moe_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=True, + use_te=True, + num_moe_experts=4, + ) + num_te_grouped_mlp = sum( + isinstance(module, TEGroupedMLP) for module in te_grouped_moe_model.modules() + ) + assert num_te_grouped_mlp == 4, ( + f"TEGrupedMoEModel has {num_te_grouped_mlp} TEGroupedMLP modules, it should have 4" + ) + + # Create sequential MoE model + sequential_moe_model = _gpt_model_provider( + tp_size=tp_size, + ep_size=ep_size, + etp_size=etp_size, + hidden_size=32, + moe_grouped_gemm=False, + num_moe_experts=4, + transformer_impl="modelopt", + ) + num_sequential_mlp = sum( + isinstance(module, SequentialMLP) for module in sequential_moe_model.modules() + ) + assert num_sequential_mlp == 4, ( + f"SequentialMoEModel has {num_sequential_mlp} SequentialMLP modules, it should have 4" + ) + # Copy weights from grouped to non-grouped model + copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_moe_model) + + # Compare model outputs before quantization + te_grouped_moe_output = forward_fn(te_grouped_moe_model) + sequential_moe_output = forward_fn(sequential_moe_model) + assert torch.allclose(te_grouped_moe_output, sequential_moe_output, atol=1e-6, rtol=1e-6) + + # Quantize grouped model + mtq.quantize(te_grouped_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # Quantize non-grouped model + mtq.quantize(sequential_moe_model, mtq.FP8_DEFAULT_CFG, forward_fn) + + # Compare model outputs after quantization + te_grouped_moe_quant_output = forward_fn(te_grouped_moe_model) + sequential_moe_quant_output = forward_fn(sequential_moe_model) + assert torch.allclose( + te_grouped_moe_quant_output, sequential_moe_quant_output, atol=1e-6, rtol=1e-6 + ) + + +def test_te_grouped_vs_sequential_quantize(need_4_gpus): + """Test that TEGrouped and sequential MoE models produce similar quantized models.""" + pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently") + size = torch.cuda.device_count() + spawn_multiprocess_job( + size=size, + job=partial(_test_te_grouped_vs_sequential_quantize_helper, 1, 2, 2), + backend="nccl", + ) + + +def _test_expert_model_parallel_amax_sync( + tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size +): + """Test expert parallel synchronization with different configurations.""" + initialize_for_megatron( + tensor_model_parallel_size=tp_size, + pipeline_model_parallel_size=1, + expert_model_parallel_size=ep_size, + expert_tensor_parallel_size=etp_size, 
+        seed=SEED,
+    )
+
+    # Create model with expert parallelism
+    model = _gpt_model_provider(
+        tp_size=tp_size,
+        ep_size=ep_size,
+        etp_size=etp_size,
+        hidden_size=256,
+        moe_grouped_gemm=moe_grouped_gemm,
+        use_te=moe_grouped_gemm,
+        num_moe_experts=8,
+        transformer_impl="modelopt",
+    )
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    # Force the router to dispatch tokens to every expert
+    for module in model.modules():
+        if isinstance(module, TopKRouter):
+            module.topk = module.num_experts
+
+    def forward_fn(model):
+        return megatron_prefill(model, prompt_tokens)
+
+    # Quantize the model
+    model = mtq.quantize(model, config, forward_fn)
+    # Check initial sync status
+    initial_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
+    assert initial_sync, (
+        f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"
+    )
+
+    # Amax values should become inconsistent when calibrating without distributed sync
+    mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=False)
+    is_synced, _, _ = compare_amax_sync_across_expert_parallel(
+        model, compare_across_experts=False
+    )
+
+    assert not is_synced, (
+        "Amax is unexpectedly consistent across expert parallel ranks; "
+        "it should not be synchronized when calibration runs with distributed_sync=False"
+    )
+    # Re-calibrate with distributed sync enabled and verify synchronization
+    mtq.model_calib.max_calibrate(model, forward_fn, distributed_sync=True)
+    for module in model.modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
+    final_sync, quantizer_type, rank_values = compare_amax_sync_across_expert_parallel(model)
+    assert final_sync, f"Inconsistent amax for expert {quantizer_type} across ranks: {rank_values}"
+
+
+@pytest.mark.parametrize("config", [mtq.FP8_DEFAULT_CFG, mtq.INT8_DEFAULT_CFG])
+@pytest.mark.parametrize(("ep_size", "etp_size"), [(1, 2), (2, 1), (2, 2)])
+@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
+def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
+    """Test expert model parallel synchronization."""
+    size = torch.cuda.device_count()
+    if size < ep_size * etp_size:
+        pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
+
+    if moe_grouped_gemm:
+        pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
+
+    spawn_multiprocess_job(
+        size=size,
+        job=partial(
+            _test_expert_model_parallel_amax_sync,
+            etp_size,  # tp_size
+            ep_size,
+            etp_size,
+            moe_grouped_gemm,
+            config,
+        ),
+        backend="nccl",
+    )
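Usage sketch (illustrative only, not part of the patch): how the pieces added above fit together for an MoE model. It assumes torch.distributed and Megatron expert parallelism are already initialized, and `model` / `forward_loop` stand in for a Megatron MoE GPT model and its calibration loop.

    import megatron.core.parallel_state as mcore_parallel
    import modelopt.torch.quantization as mtq
    from modelopt.torch.utils.distributed import ParallelState

    # ParallelState now carries the expert-model-parallel group, so max calibration
    # can all-reduce amax across the DP and EP groups (see model_calib.py above).
    parallel_state = ParallelState(
        mcore_parallel.get_expert_data_parallel_group(),
        tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
        expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
    )

    # mtq.quantize() runs max_calibrate(), which syncs amax across DP/EP and then calls
    # sync_moe_local_experts_amax() on every SequentialMLP so all local experts share one amax.
    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)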