From 93bab27a5306b5a1f163893c22b71e884d31d403 Mon Sep 17 00:00:00 2001 From: Kai Xu Date: Thu, 25 Sep 2025 12:11:53 -0700 Subject: [PATCH 1/2] add kv cache quantization for mcore using bmm_quantizers Signed-off-by: Kai Xu --- .../torch/quantization/plugins/__init__.py | 2 +- .../torch/quantization/plugins/megatron.py | 43 ++++--- .../quantization/plugins/test_megatron.py | 107 ++++++++++-------- 3 files changed, 85 insertions(+), 67 deletions(-) diff --git a/modelopt/torch/quantization/plugins/__init__.py b/modelopt/torch/quantization/plugins/__init__.py index d1451e37d..ac63582bb 100644 --- a/modelopt/torch/quantization/plugins/__init__.py +++ b/modelopt/torch/quantization/plugins/__init__.py @@ -73,4 +73,4 @@ from .trl import * with import_plugin("mcore"): - from .mcore import * \ No newline at end of file + from .mcore import * diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 3452f27bc..0d0c1f7ea 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -22,11 +22,11 @@ import megatron.core.tensor_parallel.layers as megatron_parallel import megatron.core.transformer.mlp as megatron_mlp import torch +from megatron.core.extensions.transformer_engine import TEDotProductAttention from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region from megatron.core.transformer import MegatronModule from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint from megatron.core.utils import get_tensor_model_parallel_group_if_none -from megatron.core.extensions.transformer_engine import TEDotProductAttention from modelopt.torch.opt.plugins.megatron import ( _MegatronMLP, @@ -34,11 +34,11 @@ ) from modelopt.torch.utils.distributed import ParallelState +from ..model_calib import max_calibrate from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer from ..nn.modules.quant_linear import RealQuantLinear from ..qtensor import QTensorWrapper from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear -from ..model_calib import max_calibrate __all__ = [] @@ -477,24 +477,32 @@ def _setup(self): def _calibrate_quantizers(self): """Calibrate quantizers with minimal dummy tensors.""" # Get device from parent module parameters - device = next(self.parameters()).device if self.parameters() else torch.device('cuda') - + device = next(self.parameters()).device if self.parameters() else torch.device("cuda") + # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion batch_size = 1 seq_len = 1 - + # Get dimensions from config num_heads = self.config.num_attention_heads - head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads - + head_dim = ( + self.config.kv_channels + if hasattr(self.config, "kv_channels") + else self.config.hidden_size // num_heads + ) + # Determine tensor format (default to sbhd if not specified) - apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False) + apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False) qkv_format = "bshd" if apply_rope_fusion else "sbhd" if qkv_format == "sbhd": - dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16) + dummy_tensor = torch.randn( + seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16 + ) else: - dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, 
dtype=torch.float16) + dummy_tensor = torch.randn( + batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16 + ) # Calibrate each quantizer quantizers = [ @@ -511,7 +519,7 @@ def _calibrate_quantizers(self): def forward(self, query, key, value, *args, **kwargs): """Apply post-RoPE quantization to KV cache. - + TEDotProductAttention receives Q, K, V after RoPE is applied, so we quantize them directly for KV cache quantization. """ @@ -525,7 +533,7 @@ def forward(self, query, key, value, *args, **kwargs): def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): """Create a sharded state dictionary for distributed checkpointing.""" sharded_state_dict = {} - + # First add non-quantizer parameters for k, v in self.state_dict(prefix="", keep_vars=True).items(): if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k: @@ -542,10 +550,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): sharded_state_dict[amax_key] = quantizer._amax # Process other quantizer parameters in bmm_quantizers - quantizer_state_dict = {} - for k, v in self.state_dict(prefix="", keep_vars=True).items(): - if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k: - quantizer_state_dict[k] = v + quantizer_state_dict = { + k: v + for k, v in self.state_dict(prefix="", keep_vars=True).items() + if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k + } if quantizer_state_dict: sharded_state_dict.update( @@ -584,7 +593,7 @@ def _check_unsupported_states(quantizer): if not hasattr(quantizer, "state_dict"): return - for k in quantizer.state_dict().keys(): + for k in quantizer.state_dict(): if k not in ["_amax", "_pre_quant_scale"]: warnings.warn( f"Restore of {k} for {name} is not supported. 
The restore of this layer might be " diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index bac699481..ffc778cb9 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -17,7 +17,6 @@ import pytest import torch -import torch.nn as nn from _test_utils.import_helper import skip_if_no_megatron from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job from _test_utils.torch_dist.plugins.megatron_common import ( @@ -374,8 +373,10 @@ def test_fp8_real_quantize(): def _test_kv_cache_quant_helper(config, rank, size): """Helper function for testing KV cache quantization with TEDotProductAttention.""" - initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED) - + initialize_for_megatron( + tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED + ) + # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention model = get_mcore_gpt_model( @@ -386,43 +387,45 @@ def _test_kv_cache_quant_helper(config, rank, size): vocab_size=32, transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec ).cuda() - + # Create dummy input for calibration prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() - + def forward_fn(model): return megatron_prefill(model, prompt_tokens) - + # Test KV cache quantization with the given config quantized_model = mtq.quantize(model, config, forward_fn) - + # Find TEDotProductAttention modules and verify they have KV cache quantizers te_attention_found = False for name, module in quantized_model.named_modules(): # Check if this is a quantized TEDotProductAttention - if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'): + if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"): te_attention_found = True # Verify all expected quantizers exist - assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}" - + assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}" + # Verify K and V quantizers are enabled (main purpose of KV cache configs) assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" - + assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model" - + # Quick smoke test that forward still works output = forward_fn(quantized_model) assert output is not None, "Forward pass failed" - + def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): """Helper for testing KV cache quantization with sharded state dict save/load.""" # Disable output_layer quantization (same as other sharded state dict tests) config["quant_cfg"]["*output_layer*"] = {"enable": False} - - initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED) - + + initialize_for_megatron( + tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED + ) + # Create GPT models with TEDotProductAttention (transformer_impl="modelopt") model_ref = get_mcore_gpt_model( tensor_model_parallel_size=size, @@ -432,7 +435,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): vocab_size=64, 
transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention ).cuda() - + model_test = get_mcore_gpt_model( tensor_model_parallel_size=size, num_layers=2, @@ -441,29 +444,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size): vocab_size=64, transformer_impl="modelopt", ).cuda() - - prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda() - + + prompt_tokens = torch.randint( + 0, model_ref.vocab_size, (2, model_ref.max_sequence_length) + ).cuda() + def forward_fn(model): return megatron_prefill(model, prompt_tokens) - + # Quantize the reference model model_ref = mtq.quantize(model_ref, config, forward_fn) - + # CRITICAL: model_test must also be quantized with the same config # Otherwise it won't have the KV cache quantizer keys when loading state dict model_test = mtq.quantize(model_test, config, forward_fn) - + # Verify KV cache quantizers were created kv_quantizers_found = False for name, module in model_ref.named_modules(): - if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'): + if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"): kv_quantizers_found = True assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}" assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}" - + assert kv_quantizers_found, "No KV cache quantizers found in quantized model" - + # Test sharded state dict save/load sharded_state_dict_test_helper( tmp_path, @@ -473,32 +478,38 @@ def forward_fn(model): meta_device=False, version=None, ) - + # Verify KV cache quantizers are restored correctly in model_test for (name_ref, module_ref), (name_test, module_test) in zip( model_ref.named_modules(), model_test.named_modules() ): - if hasattr(module_ref, 'k_bmm_quantizer'): - assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}" - assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}" - + if hasattr(module_ref, "k_bmm_quantizer"): + assert hasattr(module_test, "k_bmm_quantizer"), ( + f"K quantizer missing after restore in {name_test}" + ) + assert hasattr(module_test, "v_bmm_quantizer"), ( + f"V quantizer missing after restore in {name_test}" + ) + # Check that quantizer states match - if hasattr(module_ref.k_bmm_quantizer, '_amax'): - assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}" + if hasattr(module_ref.k_bmm_quantizer, "_amax"): + assert hasattr(module_test.k_bmm_quantizer, "_amax"), ( + f"K quantizer _amax missing in {name_test}" + ) if module_ref.k_bmm_quantizer._amax is not None: assert torch.allclose( - module_ref.k_bmm_quantizer._amax, - module_test.k_bmm_quantizer._amax + module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax ), f"K quantizer _amax mismatch in {name_test}" - - if hasattr(module_ref.v_bmm_quantizer, '_amax'): - assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}" + + if hasattr(module_ref.v_bmm_quantizer, "_amax"): + assert hasattr(module_test.v_bmm_quantizer, "_amax"), ( + f"V quantizer _amax missing in {name_test}" + ) if module_ref.v_bmm_quantizer._amax is not None: assert torch.allclose( - module_ref.v_bmm_quantizer._amax, - module_test.v_bmm_quantizer._amax + module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax ), f"V quantizer _amax mismatch in {name_test}" - + @pytest.mark.parametrize( "config", @@ -509,16 +520,14 @@ 
def forward_fn(model): ) def test_kv_cache_quant(config): """Verify KV cache quantization works correctly with TEDotProductAttention. - - This test ensures TEDotProductAttention is properly registered and gets the + + This test ensures TEDotProductAttention is properly registered and gets the expected q/k/v_bmm_quantizers when using KV cache configs. - + Note: This test requires Transformer Engine to be installed since TEDotProductAttention is only available with transformer_impl="modelopt" or "transformer_engine" (not "local"). """ - spawn_multiprocess_job( - size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl" - ) + spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl") @pytest.mark.parametrize( @@ -530,7 +539,7 @@ def test_kv_cache_quant(config): ) def test_kv_cache_sharded_state_dict(tmp_path, config): """Test KV cache quantization with sharded state dict save/load. - + This test verifies the complete workflow of saving and loading KV cache quantized models with distributed checkpointing, ensuring quantizer states are properly preserved across the save/load cycle. @@ -539,5 +548,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config): spawn_multiprocess_job( size=size, job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config), - backend="nccl" + backend="nccl", ) From 701ad04954b3fe15f0a7b0bb3fcc19fd67518691 Mon Sep 17 00:00:00 2001 From: "coderabbitai[bot]" <136622811+coderabbitai[bot]@users.noreply.github.com> Date: Thu, 25 Sep 2025 21:55:54 +0000 Subject: [PATCH 2/2] CodeRabbit Generated Unit Tests: Add GPU tests for Megatron quantization; review/refactor unit tests --- .../quantization/plugins/test_megatron.py | 155 ++++++++++++++++++ 1 file changed, 155 insertions(+) diff --git a/tests/gpu/torch/quantization/plugins/test_megatron.py b/tests/gpu/torch/quantization/plugins/test_megatron.py index ffc778cb9..646c2ebf9 100644 --- a/tests/gpu/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu/torch/quantization/plugins/test_megatron.py @@ -550,3 +550,158 @@ def test_kv_cache_sharded_state_dict(tmp_path, config): job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config), backend="nccl", ) + +# --------------------------------------------------------------------------- +# Additional tests appended by CodeRabbit Inc.: +# Note: This project uses pytest as the testing framework and PyTorch (with Megatron-Core) +# along with repository utilities in _test_utils for distributed setup/spawning. +# These tests focus on quantizer attribute propagation, KV-cache behavior with +# local transformer_impl, compress idempotency, and block-size application. 
+# --------------------------------------------------------------------------- + +def test_set_quantizer_attribute_propagates(distributed_setup_size_1): + initialize_for_megatron(seed=SEED) + set_seed(SEED) + model = MegatronModel().cuda() + + # Convert modules to quantized equivalents + mtq.replace_quant_module(model) + + # Disable all quantizers first, then re-enable specific ones and set attributes + mtq.set_quantizer_attribute(model, "*", {"enable": False}) + mtq.set_quantizer_attribute(model, "*weight_quantizer", {"enable": True, "num_bits": 8}) + mtq.set_quantizer_attribute(model, "*input_quantizer", {"enable": True}) + + saw_weight_bits = False + saw_input_enabled = False + for module in model.modules(): + if isinstance(module, (ColumnParallelLinear, RowParallelLinear)): + assert hasattr(module, "input_quantizer") + assert hasattr(module, "weight_quantizer") + + # Validate 'enable' took effect on input quantizers + iq = module.input_quantizer + if hasattr(iq, "is_enabled"): + assert iq.is_enabled + saw_input_enabled = True + elif hasattr(iq, "enable"): + assert iq.enable + saw_input_enabled = True + + # Validate num_bits propagation on weight quantizers when supported + wq = module.weight_quantizer + if hasattr(wq, "num_bits"): + assert wq.num_bits == 8 + saw_weight_bits = True + + assert saw_input_enabled, "Input quantizers did not report enabled state" + assert saw_weight_bits, "No weight quantizer reported num_bits attribute == 8" + + # Clean up since this is not a spawned process + destroy_model_parallel() + + +@pytest.mark.parametrize( + "config", + [ + mtq.FP8_KV_CFG, + mtq.NVFP4_KV_CFG, + ], +) +def test_kv_cache_quant_local_impl(distributed_setup_size_1, config): + """Negative scenario: using a 'local' transformer_impl should not create KV-cache + bmm quantizers (TEDotProductAttention not present).""" + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=SEED) + + model = get_mcore_gpt_model( + tensor_model_parallel_size=1, + num_layers=1, + hidden_size=64, + num_attention_heads=4, + vocab_size=32, + transformer_impl="local", + ).cuda() + + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(m): + return megatron_prefill(m, prompt_tokens) + + quantized_model = mtq.quantize(model, config, forward_fn) + + # Ensure no k/v bmm quantizers are present under 'local' impl + for name, module in quantized_model.named_modules(): + assert not ( + hasattr(module, "k_bmm_quantizer") or hasattr(module, "v_bmm_quantizer") + ), f"Unexpected KV cache quantizers found on local impl module: {name}" + + # Smoke test + out = forward_fn(quantized_model) + assert out is not None + + destroy_model_parallel() + + +def test_compress_idempotent(distributed_setup_size_1): + """Compressing an already-compressed model should be effectively idempotent.""" + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=SEED) + + hidden_size = 256 + config = mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG + + model = _gpt_model_provider(tp_size=1, hidden_size=hidden_size) + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(m): + return megatron_prefill(m, prompt_tokens) + + model = mtq.quantize(model, config, forward_fn) + + mem_before = get_model_size(model) + mtq.compress(model) + mem_after_first = get_model_size(model) + # Second compress should not increase memory and should change very little (if at all) + mtq.compress(model) + 
mem_after_second = get_model_size(model) + + assert mem_after_first <= mem_before, "compress() increased model size" + assert mem_after_second <= mem_after_first * 1.05, "Second compress() is not idempotent enough" + + # Forward still works + out = forward_fn(model) + assert out is not None + + destroy_model_parallel() + + +def test_mixed_block_sizes_attribute_application(distributed_setup_size_1): + """Verify mixed_block_size_config applies 'block_sizes' attributes to weight quantizers.""" + initialize_for_megatron(tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=SEED) + + model = _gpt_model_provider(tp_size=1, hidden_size=128) + prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda() + + def forward_fn(m): + return megatron_prefill(m, prompt_tokens) + + quant_model = mtq.quantize(model, mixed_block_size_config, forward_fn) + + found_expected_block_sizes = False + for name, module in quant_model.named_modules(): + wq = getattr(module, "weight_quantizer", None) + if wq is not None: + bs = getattr(wq, "block_sizes", None) + if isinstance(bs, dict) and (-1 in bs or -2 in bs): + # Expect values as configured in mixed_block_size_config + if (-1 in bs and bs[-1] in (64, 128)) or (-2 in bs and bs[-2] == 64): + found_expected_block_sizes = True + break + + assert found_expected_block_sizes, "No weight quantizer had expected block_sizes per mixed_block_size_config" + + # Smoke test + out = forward_fn(quant_model) + assert out is not None + + destroy_model_parallel() +
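
For readers wiring this up outside the test harness, a minimal sketch of the flow the two patches add and exercise: build an mcore GPT model whose attention goes through TEDotProductAttention, run one calibration forward pass, and quantize with a KV-cache config. It assumes the same helpers the test file above imports from _test_utils (get_mcore_gpt_model, initialize_for_megatron, megatron_prefill) and an already-initialized single-GPU process group, e.g. inside spawn_multiprocess_job(..., backend="nccl") as in the tests; treat it as an illustration, not part of either patch.

    # Minimal single-process sketch; in the tests this body runs inside a process
    # spawned by spawn_multiprocess_job, which sets up torch.distributed first.
    import torch

    import modelopt.torch.quantization as mtq
    from _test_utils.torch_dist.plugins.megatron_common import (
        get_mcore_gpt_model,
        initialize_for_megatron,
        megatron_prefill,
    )

    initialize_for_megatron(
        tensor_model_parallel_size=1, pipeline_model_parallel_size=1, seed=1234
    )

    # transformer_impl="modelopt" (or "transformer_engine") routes attention through
    # TEDotProductAttention; the "local" impl gets no KV-cache bmm quantizers.
    model = get_mcore_gpt_model(
        tensor_model_parallel_size=1,
        num_layers=2,
        hidden_size=256,
        num_attention_heads=8,
        vocab_size=64,
        transformer_impl="modelopt",
    ).cuda()

    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

    def forward_fn(m):
        # Calibration pass: the q/k/v bmm_quantizers observe the post-RoPE tensors here.
        return megatron_prefill(m, prompt_tokens)

    # Attach and calibrate the KV-cache quantizers; afterwards each TEDotProductAttention
    # carries enabled k_bmm_quantizer / v_bmm_quantizer, per the assertions in the tests.
    model = mtq.quantize(model, mtq.FP8_KV_CFG, forward_fn)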