diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py
index d1451e37d..ac63582bb 100644
--- a/modelopt/torch/quantization/plugins/__init__.py
+++ b/modelopt/torch/quantization/plugins/__init__.py
@@ -73,4 +73,4 @@ from .trl import *
 
 with import_plugin("mcore"):
-    from .mcore import *
\ No newline at end of file
+    from .mcore import *
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 3452f27bc..474fa0872 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -22,11 +22,11 @@ import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+from megatron.core.extensions.transformer_engine import TEDotProductAttention
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
-from megatron.core.extensions.transformer_engine import TEDotProductAttention
 
 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
@@ -34,11 +34,11 @@
 )
 from modelopt.torch.utils.distributed import ParallelState
 
+from ..model_calib import max_calibrate
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
-from ..model_calib import max_calibrate
 
 __all__ = []
 
@@ -461,6 +461,14 @@ class _RealQuantMegatronRowParallelLinear(
     _scale_tensor_shard_axis = 1
 
     def forward(self, input, *args, **kwargs):
+        """Forward pass that delegates to the row-parallel linear parent implementation.
+
+        All positional and keyword arguments are forwarded unchanged and the parent's
+        return value is passed through as-is.
+        """
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
 
 
@@ -469,32 +477,46 @@ class _QuantTEDotProductAttention(QuantModule):
     """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""
 
     def _setup(self):
-        """Initialize quantizers for Q, K, V tensors."""
+        """Create q/k/v_bmm_quantizer TensorQuantizer instances for the Q, K, and V tensors."""
        self.q_bmm_quantizer = TensorQuantizer()
         self.k_bmm_quantizer = TensorQuantizer()
         self.v_bmm_quantizer = TensorQuantizer()
 
     def _calibrate_quantizers(self):
-        """Calibrate quantizers with minimal dummy tensors."""
-        # Get device from parent module parameters
-        device = next(self.parameters()).device if self.parameters() else torch.device('cuda')
+        """Calibrate the Q/K/V quantizers with a minimal dummy tensor.
+
+        A tiny float16 tensor shaped for the attention QKV layout ("bshd" or "sbhd",
+        chosen from ``config.apply_rope_fusion``) is run through each enabled quantizer
+        that does not yet have an ``_amax`` to compute and store one; quantizers that
+        are disabled or already calibrated are left untouched.
+        """
+        # Get device from parent module parameters (fall back to CUDA if there are none)
+        first_param = next(self.parameters(), None)
+        device = first_param.device if first_param is not None else torch.device("cuda")
+
         # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
         batch_size = 1
         seq_len = 1
-        
+
         # Get dimensions from config
         num_heads = self.config.num_attention_heads
-        head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads
-        
+        head_dim = (
+            self.config.kv_channels
+            if hasattr(self.config, "kv_channels")
+            else self.config.hidden_size // num_heads
+        )
+
         # Determine tensor format (default to sbhd if not specified)
-        apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
+        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
         qkv_format = "bshd" if apply_rope_fusion else "sbhd"
 
         if qkv_format == "sbhd":
-            dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
+            )
         else:
-            dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
+            )
 
         # Calibrate each quantizer
         quantizers = [
@@ -510,10 +532,16 @@ def _calibrate_quantizers(self):
             max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
 
     def forward(self, query, key, value, *args, **kwargs):
-        """Apply post-RoPE quantization to KV cache.
+        """Quantize Q, K, and V for the KV cache, then run the base attention forward.
 
-        TEDotProductAttention receives Q, K, V after RoPE is applied,
-        so we quantize them directly for KV cache quantization.
+        TEDotProductAttention receives the query, key, and value tensors after RoPE has
+        already been applied, so they are quantized here directly and passed to the
+        parent ``forward``.
         """
         # Quantize Q, K, V
         query = self.q_bmm_quantizer(query)
@@ -523,9 +551,22 @@ def forward(self, query, key, value, *args, **kwargs):
         return super().forward(query, key, value, *args, **kwargs)
 
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
-        """Create a sharded state dictionary for distributed checkpointing."""
-        sharded_state_dict = {}
+        """Build a sharded state dictionary for distributed checkpointing.
+
+        Parameters:
+            prefix (str): Key prefix to prepend to returned state keys.
+            sharded_offsets (tuple): Shard offsets passed to
+                ``make_sharded_tensors_for_checkpoint``.
+            metadata: Unused by this implementation; kept for API compatibility.
+
+        Returns:
+            dict: Mapping from checkpoint keys to tensors, containing the non-quantizer
+            module tensors, the per-quantizer ``_amax`` entries for the Q/K/V bmm
+            quantizers when present, and the remaining quantizer tensors wrapped as
+            sharded tensors via the checkpoint helper.
+        """
+        sharded_state_dict = {}
+
         # First add non-quantizer parameters
         for k, v in self.state_dict(prefix="", keep_vars=True).items():
             if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
@@ -542,10 +583,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
                 sharded_state_dict[amax_key] = quantizer._amax
 
         # Process other quantizer parameters in bmm_quantizers
-        quantizer_state_dict = {}
-        for k, v in self.state_dict(prefix="", keep_vars=True).items():
-            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
-                quantizer_state_dict[k] = v
+        quantizer_state_dict = {
+            k: v
+            for k, v in self.state_dict(prefix="", keep_vars=True).items()
+            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
+        }
 
         if quantizer_state_dict:
             sharded_state_dict.update(
@@ -557,7 +599,18 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         return sharded_state_dict
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        """Handle loading state dict for quantizers."""
+        """Align quantizer entries in the state dict, then delegate to the parent loader.
+
+        This method:
+        - Remaps each ``{prefix}{quantizer_name}._amax`` entry to the key layout expected
+          by the local TensorQuantizer instances.
+        - Reshapes any remaining quantizer tensors (keys containing ``_quantizer`` but not
+          ``_amax``) to match the shapes in this module's ``state_dict``.
+        - Calls the superclass ``_load_from_state_dict`` with the adjusted ``state_dict``.
+
+        Parameters:
+            state_dict (dict): Incoming state dict; modified in place to align quantizer
+                keys and shapes.
+            prefix (str): Prefix applied to this module's keys in ``state_dict``.
+        """
         for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
             full_prefix = f"{prefix}{quantizer_name}."
             amax_key = f"{prefix}{quantizer_name}._amax"
@@ -577,14 +630,29 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
 
     def modelopt_post_restore(self, name=""):
-        """Restore quantizer states after model loading."""
+        """Validate restored attention quantizer states and calibrate them if needed.
+
+        Each Q/K/V bmm quantizer is checked for unsupported saved state keys, and a
+        warning naming ``name`` is emitted when such keys are found. If any enabled
+        quantizer lacks a stored ``_amax``, ``self._calibrate_quantizers()`` is run.
+
+        Parameters:
+            name (str): Identifier for the module being restored; used in warnings to
+                help locate the layer.
+        """
         super().modelopt_post_restore(name)
 
         def _check_unsupported_states(quantizer):
+            """Warn about saved quantizer state keys that cannot be restored.
+
+            Any key in ``quantizer.state_dict()`` other than ``_amax`` and
+            ``_pre_quant_scale`` triggers a warning.
+
+            Parameters:
+                quantizer: Object with a ``state_dict()`` method (typically a
+                    TensorQuantizer) whose saved state keys are validated.
+            """
             if not hasattr(quantizer, "state_dict"):
                 return
-            for k in quantizer.state_dict().keys():
+            for k in quantizer.state_dict():
                 if k not in ["_amax", "_pre_quant_scale"]:
                     warnings.warn(
                         f"Restore of {k} for {name} is not supported. The restore of this layer might be "
diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py
index bac699481..1f6299c0c 100644
--- a/tests/gpu/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu/torch/quantization/plugins/test_megatron.py
@@ -17,7 +17,6 @@
 
 import pytest
 import torch
-import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -368,14 +367,30 @@ def forward_fn(model):
 
 
 def test_fp8_real_quantize():
+    """Run the FP8 real-quantization memory-reduction test on all available CUDA devices.
+
+    Spawns a multiprocess job (NCCL backend) that executes
+    ``_test_fp8_real_quantize_helper`` on each detected GPU.
+    """
     size = torch.cuda.device_count()
     spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl")
 
 
 def _test_kv_cache_quant_helper(config, rank, size):
-    """Helper function for testing KV cache quantization with TEDotProductAttention."""
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
+    """Verify TEDotProductAttention modules receive KV-cache quantizers and stay functional.
+
+    Quantizes a minimal GPT model (built with ``transformer_impl="modelopt"``) using the
+    given ``config``, checks that every TEDotProductAttention-like module exposes enabled
+    ``k_bmm_quantizer`` and ``v_bmm_quantizer`` attributes, and runs a smoke forward pass.
+
+    Parameters:
+        config: Quantization configuration (e.g., FP8_KV_CFG or NVFP4_KV_CFG).
+        rank (int): Process rank in the distributed test invocation.
+        size (int): Tensor model parallel size used to construct the model.
+    """
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
     # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
     model = get_mcore_gpt_model(
@@ -386,43 +401,64 @@ def _test_kv_cache_quant_helper(config, rank, size):
         vocab_size=32,
         transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
     ).cuda()
-    
+
     # Create dummy input for calibration
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-    
+
     def forward_fn(model):
+        """Run ``megatron_prefill`` on ``model`` with the prompt tokens defined above."""
         return megatron_prefill(model, prompt_tokens)
-    
+
     # Test KV cache quantization with the given config
     quantized_model = mtq.quantize(model, config, forward_fn)
-    
+
     # Find TEDotProductAttention modules and verify they have KV cache quantizers
     te_attention_found = False
     for name, module in quantized_model.named_modules():
         # Check if this is a quantized TEDotProductAttention
-        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
             te_attention_found = True
             # Verify all expected quantizers exist
-            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
-            
+            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
+
             # Verify K and V quantizers are enabled (main purpose of KV cache configs)
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-    
+
     assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
-    
+
     # Quick smoke test that forward still works
     output = forward_fn(quantized_model)
    assert output is not None, "Forward pass failed"
-    
+
 
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
-    """Helper for testing KV cache quantization with sharded state dict save/load."""
-    # Disable output_layer quantization (same as other sharded state dict tests)
-    config["quant_cfg"]["*output_layer*"] = {"enable": False}
+    """Validate that KV-cache quantizers survive a sharded state dict save/load roundtrip.
 
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
+    Initializes Megatron, builds two small GPT models with ``transformer_impl="modelopt"``
+    (TEDotProductAttention), quantizes both with the given config so the KV quantizer keys
+    exist on both before checkpointing, runs a sharded save/load roundtrip, and asserts
+    that the restored ``k_bmm_quantizer``/``v_bmm_quantizer`` instances and their ``_amax``
+    tensors (when present) match the reference model.
+
+    Parameters:
+        tmp_path (pathlib.Path): Temporary directory for the sharded checkpoint.
+        config (dict): Quantization configuration passed to ``mtq.quantize``.
+        rank (int): Distributed process rank for this helper.
+        size (int): Tensor model parallel size / world size used to initialize Megatron.
+    """
+    # Disable output_layer quantization (same as other sharded state dict tests)
+    config["quant_cfg"]["*output_layer*"] = {"enable": False}
+
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
     model_ref = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -432,7 +468,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
     ).cuda()
-    
+
     model_test = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -441,29 +477,40 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",
     ).cuda()
-    
-    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
-    
+
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+
     def forward_fn(model):
+        """Run ``megatron_prefill`` on ``model`` with the prompt tokens defined above."""
         return megatron_prefill(model, prompt_tokens)
-    
+
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-    
+
     # CRITICAL: model_test must also be quantized with the same config
     # Otherwise it won't have the KV cache quantizer keys when loading state dict
     model_test = mtq.quantize(model_test, config, forward_fn)
-    
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
-        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             kv_quantizers_found = True
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-    
+
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
-    
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
@@ -473,32 +520,38 @@ def forward_fn(model):
         meta_device=False,
         version=None,
     )
-    
+
     # Verify KV cache quantizers are restored correctly in model_test
     for (name_ref, module_ref), (name_test, module_test) in zip(
         model_ref.named_modules(), model_test.named_modules()
     ):
-        if hasattr(module_ref, 'k_bmm_quantizer'):
-            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
-            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
-    
+        if hasattr(module_ref, "k_bmm_quantizer"):
+            assert hasattr(module_test, "k_bmm_quantizer"), (
+                f"K quantizer missing after restore in {name_test}"
+            )
+            assert hasattr(module_test, "v_bmm_quantizer"), (
+                f"V quantizer missing after restore in {name_test}"
+            )
+
             # Check that quantizer states match
-            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
+                    f"K quantizer _amax missing in {name_test}"
+                )
                 if module_ref.k_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.k_bmm_quantizer._amax,
-                        module_test.k_bmm_quantizer._amax
+                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                     ), f"K quantizer _amax mismatch in {name_test}"
-    
-            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
+                    f"V quantizer _amax missing in {name_test}"
+                )
                 if module_ref.v_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.v_bmm_quantizer._amax,
-                        module_test.v_bmm_quantizer._amax
+                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                     ), f"V quantizer _amax mismatch in {name_test}"
-    
+
 
 @pytest.mark.parametrize(
     "config",
@@ -508,17 +561,12 @@ def forward_fn(model):
     ],
 )
 def test_kv_cache_quant(config):
-    """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
-    expected q/k/v_bmm_quantizers when using KV cache configs.
+    """Verify that TEDotProductAttention modules gain the expected KV-cache quantizers.
 
-    Note: This test requires Transformer Engine to be installed since TEDotProductAttention
-    is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
+    Runs the KV-cache quantization smoke test in a single-process spawn. Requires
+    Transformer Engine, since TEDotProductAttention is not available with the "local"
+    transformer implementation.
     """
-    spawn_multiprocess_job(
-        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")
 
 
 @pytest.mark.parametrize(
     "config",
@@ -529,15 +577,14 @@ def test_kv_cache_quant(config):
     ],
 )
 def test_kv_cache_sharded_state_dict(tmp_path, config):
-    """Test KV cache quantization with sharded state dict save/load.
+    """Run a sharded state dict save/load test for KV-cache quantized models.
 
-    This test verifies the complete workflow of saving and loading KV cache quantized
-    models with distributed checkpointing, ensuring quantizer states are properly
-    preserved across the save/load cycle.
+    Spawns up to two NCCL processes that run the sharded-state-dict helper with the given
+    temporary path and quantization config, checking that quantizer states are preserved
+    across the save/load cycle.
     """
     size = min(2, torch.cuda.device_count())  # Use 2 GPUs if available, else 1
     spawn_multiprocess_job(
         size=size,
         job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
-        backend="nccl"
+        backend="nccl",
    )
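
For reviewers who want to exercise the new KV-cache path by hand, below is a minimal sketch (not part of the diff) that condenses `_test_kv_cache_quant_helper` into a single process. It assumes an environment where the Megatron test utilities used above (`initialize_for_megatron`, `get_mcore_gpt_model`, `megatron_prefill`) are importable, a CUDA device, and that `mtq.FP8_KV_CFG` is one of the KV-cache configs the tests parametrize over; the seed and any arguments not visible in the diff are placeholders.

```python
# Hypothetical, condensed single-process version of _test_kv_cache_quant_helper.
import torch

import modelopt.torch.quantization as mtq
from _test_utils.torch_dist.plugins.megatron_common import (  # test utilities used above
    get_mcore_gpt_model,
    initialize_for_megatron,
    megatron_prefill,
)

initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234)

model = get_mcore_gpt_model(
    tensor_model_parallel_size=1,
    num_layers=2,
    vocab_size=32,
    transformer_impl="modelopt",  # needed so attention layers are TEDotProductAttention
).cuda()

prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()


def forward_fn(model):
    # One prefill pass is enough for max calibration of the KV-cache quantizers.
    return megatron_prefill(model, prompt_tokens)


model = mtq.quantize(model, mtq.FP8_KV_CFG, forward_fn)

# Every quantized TEDotProductAttention should now expose enabled K/V quantizers.
for name, module in model.named_modules():
    if hasattr(module, "k_bmm_quantizer"):
        assert module.k_bmm_quantizer.is_enabled and module.v_bmm_quantizer.is_enabled, name
```

The sharded-checkpoint roundtrip exercised by `_test_kv_cache_sharded_state_dict_helper` follows the same pattern, with `sharded_state_dict_test_helper` handling the save/load step.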
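The dummy-tensor calibration in `_calibrate_quantizers` can also be illustrated in isolation. The sketch below is an assumption-laden illustration rather than code from the diff: it applies `max_calibrate` to a bare `TensorQuantizer` the same way the plugin does, with import paths following the relative imports at the top of `megatron.py` and a `(seq, batch, heads, head_dim)` shape matching the "sbhd" layout used when rope fusion is disabled.

```python
# Standalone illustration of the amax seeding performed by _calibrate_quantizers.
import torch

from modelopt.torch.quantization.model_calib import max_calibrate  # "..model_calib" above
from modelopt.torch.quantization.nn import TensorQuantizer  # "..nn" above

quantizer = TensorQuantizer()  # same default construction as in _setup
dummy_tensor = torch.randn(1, 1, 8, 64, dtype=torch.float16)  # sbhd: (seq, batch, heads, dim)

# distributed_sync=False mirrors the plugin call; no process group is assumed here.
max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

print(getattr(quantizer, "_amax", None))  # now populated, as modelopt_post_restore expects
```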