Support kv cache quantization for mcore using bmm_quantizers #375
base: main
Conversation
          
Walkthrough

Adds quantization support for Megatron TEDotProductAttention with KV-cache (FP8/NVFP4), including a registered QuantModule, quantizer calibration, forward-path Q/K/V quantization, sharded state dict save/load, and post-restore handling. Introduces focused GPU tests for KV-cache quantization and state dict handling, and updates the changelog.
Sequence Diagram(s)

```mermaid
sequenceDiagram
  autonumber
  participant C as Caller
  participant QTA as _QuantTEDotProductAttention
  participant QQ as Q quantizer
  participant KQ as K quantizer
  participant VQ as V quantizer
  participant TE as TEDotProductAttention (parent)
  C->>QTA: forward(query, key, value, ...)
  rect rgba(230,245,255,0.6)
    note right of QTA: Apply quantization if enabled
    QTA->>QQ: quantize(query)
    QQ-->>QTA: q_q
    QTA->>KQ: quantize(key)
    KQ-->>QTA: k_q
    QTA->>VQ: quantize(value)
    VQ-->>QTA: v_q
  end
  QTA->>TE: forward(q_q, k_q, v_q, ...)
  TE-->>QTA: attn_output
  QTA-->>C: attn_output
```
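To make the forward path above concrete, here is a minimal sketch of a quantized attention wrapper with Q/K/V bmm quantizers. It mirrors what the review comments describe (per-tensor quantizers created in `_setup`, post-RoPE quantization in `forward`), but it is an illustration rather than the PR's verbatim implementation; the registration step is omitted, and the import paths are taken from the files referenced in this review.

```python
# Minimal sketch (not the PR's exact code) of the forward path shown in the diagram.
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from modelopt.torch.quantization.nn.modules.tensor_quantizer import TensorQuantizer


class _QuantTEDotProductAttention(TEDotProductAttention):
    """Wraps TEDotProductAttention with per-input bmm quantizers for Q/K/V."""

    def _setup(self):
        # One TensorQuantizer per bmm input; disabled quantizers pass tensors through.
        self.q_bmm_quantizer = TensorQuantizer()
        self.k_bmm_quantizer = TensorQuantizer()
        self.v_bmm_quantizer = TensorQuantizer()

    def forward(self, query, key, value, *args, **kwargs):
        # Quantize post-RoPE Q/K/V, then delegate to the parent implementation.
        query = self.q_bmm_quantizer(query)
        key = self.k_bmm_quantizer(key)
        value = self.v_bmm_quantizer(value)
        return super().forward(query, key, value, *args, **kwargs)
```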
```mermaid
sequenceDiagram
  autonumber
  participant M as Model with QTA
  participant SD as Sharded Checkpoint Utils
  participant S as Storage
  rect rgba(245,235,255,0.5)
    note over M: Save
    M->>SD: sharded_state_dict(prefix, offsets, metadata)
    SD-->>M: dict incl. non-quant params + quantizer state (amax, config)
    M->>S: write shards
  end
  rect rgba(235,255,235,0.5)
    note over M: Load/Restore
    S-->>M: read shards
    M->>SD: _load_from_state_dict(state_dict, prefix, ...)
    SD-->>M: remap _amax keys, reshape local quantizer states
    M->>M: modelopt_post_restore() -> validate/calibrate if needed
  end
```
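For the save side of this diagram, a rough sketch of how quantizer `_amax` values can be added alongside the regular sharded tensors. The helper name `add_bmm_quantizer_amax` is hypothetical (it is not a function in this PR); the code only illustrates the direct insertion of scalar amax state that the review discusses.

```python
# Illustrative save-side helper (assumption-based, not the PR's exact code):
# given a module with q/k/v bmm quantizers and the state dict produced by the
# parent's sharded_state_dict, insert each quantizer's scalar _amax directly.
import torch


def add_bmm_quantizer_amax(module: torch.nn.Module, state: dict, prefix: str = "") -> dict:
    for name in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"):
        quantizer = getattr(module, name, None)
        amax = getattr(quantizer, "_amax", None) if quantizer is not None else None
        if amax is not None:
            # Scalar amax values are not tensor-parallel sharded, so they can be
            # stored under their own key and restored verbatim.
            state[f"{prefix}{name}._amax"] = amax
    return state
```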
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
    
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)
588-620: Fix is_enabled checks in post‑restore to avoid unconditional calibration.

Call the method; otherwise it always evaluates truthy and may trigger unnecessary calibration.

Apply:

```diff
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
                 continue
```
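The underlying failure mode is plain Python truthiness; a standalone illustration (not code from this repository):

```python
class Quantizer:
    def is_enabled(self):
        return False


q = Quantizer()
assert q.is_enabled        # passes: a bound method object is always truthy
assert not q.is_enabled()  # the actual check requires calling the method
```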
🧹 Nitpick comments (3)
modelopt/torch/quantization/plugins/__init__.py (1)
75-76: Gate the new mcore plugin on an importable package (likely megatron.core) and verify plugin file exists.
`import_plugin("mcore")` will only import `.mcore` if a top-level module named `mcore` exists. That's probably not what you want. If the plugin depends on Megatron-Core, gate on `"megatron.core"` (or whichever module actually guarantees availability) and ensure `modelopt/torch/quantization/plugins/mcore.py` exists.

Proposed change:

```diff
-with import_plugin("mcore"):
+with import_plugin("megatron.core"):
     from .mcore import *
```

Also consider adding ``:meth:`mcore <modelopt.torch.quantization.plugins.mcore>` `` to the plugin list in the module docstring for discoverability.

To verify:
- Check that `plugins/mcore.py` exists.
- Confirm which import string (`"megatron.core"` vs `"mcore"`) should gate plugin loading in your environments.

modelopt/torch/quantization/plugins/megatron.py (2)
533-567: Sharded state dict: consider consistency for amax handling.

You special‑case `_amax` by inserting directly, and use `make_sharded_tensors_for_checkpoint` for other quantizer tensors with empty shard axes. If future bmm quantizers introduce sharded tensors (e.g., per‑channel), add shard axis mapping here to avoid shape mismatches after TP changes. No action required now, but keep in mind for NVFP4 evolution.

568-587: Redundant amax key remapping.

`amax_key` and `expected_amax_key` are identical: `f"{prefix}{quantizer_name}._amax"`. The rename is a no‑op.

Apply:

```diff
-        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
-            full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
-
-            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
-            if amax_key in state_dict:
-                expected_amax_key = f"{full_prefix}_amax"
-                state_dict[expected_amax_key] = state_dict.pop(amax_key)
+        # amax keys already match TensorQuantizer state naming; no remap needed
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/plugins/__init__.py (1 hunks)
- modelopt/torch/quantization/plugins/megatron.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (3): initialize_for_megatron (385-393), get_mcore_gpt_model (133-208), sharded_state_dict_test_helper (410-457)
- modelopt/torch/utils/plugins/megatron_generate.py (1): megatron_prefill (41-130)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): is_enabled (389-391)

modelopt/torch/quantization/plugins/megatron.py (5)
- modelopt/torch/quantization/model_calib.py (1): max_calibrate (61-173)
- modelopt/torch/quantization/nn/modules/quant_module.py (4): QuantModule (37-96), _setup (118-126), _setup (163-169), modelopt_post_restore (40-69)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): TensorQuantizer (60-1143), is_enabled (389-391), reset_amax (252-256), forward (872-977)
- modelopt/torch/quantization/plugins/huggingface.py (17): _setup (55-58, 161-164, 239-244, 349-350, 365-369, 388-390, 427-473, 601-612), forward (71-119, 170-174, 255-256, 337-338, 352-361, 371-383, 393-423, 475-480, 643-649)
- modelopt/torch/quantization/plugins/custom.py (2): modelopt_post_restore (117-174), _check_unsupported_states (127-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
 - GitHub Check: code-quality
 - GitHub Check: build-docs
 
🔇 Additional comments (6)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
235-237: Note acknowledged — separating KV-cache configs is reasonable.
521-531: LGTM — targeted KV‑cache test entrypoint.
540-553: LGTM — sharded state dict test for KV‑cache quantization.

modelopt/torch/quantization/plugins/megatron.py (3)

25-25: Confirm TEDotProductAttention import path across supported Megatron‑Core/TE versions.

`from megatron.core.extensions.transformer_engine import TEDotProductAttention` may vary by version. If older MC/TE versions are supported, consider guarding this import (try/except) and registering conditionally.
467-476: LGTM — consistent quantizer setup for Q/K/V bmm.
520-532: LGTM — quantize Q/K/V post‑RoPE before delegating.
```python
def _test_kv_cache_quant_helper(config, rank, size):
    """Helper function for testing KV cache quantization with TEDotProductAttention."""
    initialize_for_megatron(
        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
    )

    # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
    # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
    model = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=1,
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=32,
        transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
    ).cuda()

    # Create dummy input for calibration
    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

    def forward_fn(model):
        return megatron_prefill(model, prompt_tokens)

    # Test KV cache quantization with the given config
    quantized_model = mtq.quantize(model, config, forward_fn)

    # Find TEDotProductAttention modules and verify they have KV cache quantizers
    te_attention_found = False
    for name, module in quantized_model.named_modules():
        # Check if this is a quantized TEDotProductAttention
        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
            te_attention_found = True
            # Verify all expected quantizers exist
            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"

            # Verify K and V quantizers are enabled (main purpose of KV cache configs)
            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

    assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"

    # Quick smoke test that forward still works
    output = forward_fn(quantized_model)
    assert output is not None, "Forward pass failed"
```
Fix assertions: call TensorQuantizer.is_enabled as a method.
is_enabled is a method, not a property. As written, the assertions will always pass because a bound method is truthy.
Apply:
```diff
-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
```
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 374 to
418, the assertions check TensorQuantizer.is_enabled as an attribute
(module.k_bmm_quantizer.is_enabled and module.v_bmm_quantizer.is_enabled) but
is_enabled is a method; change those assertions to call the method
(module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled()) so
they evaluate correctly.
```python
def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
    """Helper for testing KV cache quantization with sharded state dict save/load."""
    # Disable output_layer quantization (same as other sharded state dict tests)
    config["quant_cfg"]["*output_layer*"] = {"enable": False}

    initialize_for_megatron(
        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
    )

    # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
    model_ref = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=2,  # At least 2 layers to test multiple attention modules
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=64,
        transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
    ).cuda()

    model_test = get_mcore_gpt_model(
        tensor_model_parallel_size=size,
        num_layers=2,
        hidden_size=64,
        num_attention_heads=4,
        vocab_size=64,
        transformer_impl="modelopt",
    ).cuda()

    prompt_tokens = torch.randint(
        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
    ).cuda()

    def forward_fn(model):
        return megatron_prefill(model, prompt_tokens)

    # Quantize the reference model
    model_ref = mtq.quantize(model_ref, config, forward_fn)

    # CRITICAL: model_test must also be quantized with the same config
    # Otherwise it won't have the KV cache quantizer keys when loading state dict
    model_test = mtq.quantize(model_test, config, forward_fn)

    # Verify KV cache quantizers were created
    kv_quantizers_found = False
    for name, module in model_ref.named_modules():
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            kv_quantizers_found = True
            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

    assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

    # Test sharded state dict save/load
    sharded_state_dict_test_helper(
        tmp_path,
        model_ref,
        model_test,
        forward_fn,
        meta_device=False,
        version=None,
    )

    # Verify KV cache quantizers are restored correctly in model_test
    for (name_ref, module_ref), (name_test, module_test) in zip(
        model_ref.named_modules(), model_test.named_modules()
    ):
        if hasattr(module_ref, "k_bmm_quantizer"):
            assert hasattr(module_test, "k_bmm_quantizer"), (
                f"K quantizer missing after restore in {name_test}"
            )
            assert hasattr(module_test, "v_bmm_quantizer"), (
                f"V quantizer missing after restore in {name_test}"
            )

            # Check that quantizer states match
            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
                    f"K quantizer _amax missing in {name_test}"
                )
                if module_ref.k_bmm_quantizer._amax is not None:
                    assert torch.allclose(
                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                    ), f"K quantizer _amax mismatch in {name_test}"

            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
                    f"V quantizer _amax missing in {name_test}"
                )
                if module_ref.v_bmm_quantizer._amax is not None:
                    assert torch.allclose(
                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                    ), f"V quantizer _amax mismatch in {name_test}"
```
Fix is_enabled checks in sharded state dict helper.
Same issue: call is_enabled().
Apply:
```diff
-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
```
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 420-513,
the assertions that verify KV cache quantizers use the attribute access
.is_enabled instead of calling the method; update those assertions to call
module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled() so
the checks invoke the method rather than reference a non-callable attribute.
Docstrings generation was requested by @kaix-nv (#375 (comment)). The following files were modified:
- `modelopt/torch/quantization/plugins/megatron.py`
- `tests/gpu/torch/quantization/plugins/test_megatron.py`

Note: Generated docstrings for this pull request at #376.

Note: Unit test generation is an Early Access feature. Expect some limitations and changes as we gather feedback and continue to improve it. Generating unit tests... This may take up to 20 minutes.

✅ UTG Post-Process Complete: No new issues were detected in the generated code and all check runs have completed. The unit test generation process has completed successfully.

Creating a PR to put the unit tests in... The changes have been created in this pull request: View PR
93bab27 to 5fc56fe (compare)
Actionable comments posted: 2
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- modelopt/torch/quantization/plugins/__init__.py (1 hunks)
- modelopt/torch/quantization/plugins/megatron.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/plugins/__init__.py
 
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.
Applied to files:
tests/gpu/torch/quantization/plugins/test_megatron.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
 - GitHub Check: code-quality
 - GitHub Check: build-docs
 
🔇 Additional comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)
409-411: Fix assertions: call TensorQuantizer.is_enabled()

Line 410 still treats `is_enabled` as an attribute, so the assertion always passes because a bound method is truthy. Please invoke it.

```diff
-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
```

467-468: Fix assertions: call TensorQuantizer.is_enabled()

Line 467 repeats the same issue: `is_enabled` must be invoked. Without the call these checks never fail even when the quantizers are disabled.

```diff
-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
```

modelopt/torch/quantization/plugins/megatron.py (1)

483-488: Calibrator device/dtype inference must handle param/buffer-less modules

In `_calibrate_quantizers` (lines 486-488), `param = next(iter(self.parameters()), None)` falls back to `torch.device("cuda")` and `torch.float16` whenever the module has no parameters. Many wrapper modules expose only buffers (or can be instantiated on CPU/meta first), so this default hits real code paths and will crash on hosts without CUDA. Please probe both parameters and buffers, and only prefer CUDA if it is actually available.

```diff
-        param = next(iter(self.parameters()), None)
-        device = param.device if param is not None else torch.device("cuda")
-        dtype = param.dtype if param is not None else torch.float16
+        param = next(self.parameters(recurse=False), None)
+        buf = next(self.buffers(recurse=False), None)
+        if param is not None:
+            device, dtype = param.device, param.dtype
+        elif buf is not None:
+            device, dtype = buf.device, buf.dtype
+        else:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            dtype = torch.float16 if device.type == "cuda" else torch.float32
```

This matches the safe pattern already used elsewhere in the plugin and avoids hard-crashing on CPU/meta setups.
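To illustrate the failure mode with a toy module (an assumption-based example, not code from this PR): a wrapper that only registers buffers yields no parameters, so the original fallback would pick CUDA/float16 even on a CPU-only host, while probing buffers recovers a sensible device and dtype.

```python
import torch
import torch.nn as nn


class BufferOnlyWrapper(nn.Module):
    """A toy module with buffers but no parameters."""

    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(1))


m = BufferOnlyWrapper()
print(next(iter(m.parameters()), None))  # None -> old code falls back to cuda/float16
buf = next(m.buffers(recurse=False), None)
print(buf.device, buf.dtype)             # cpu torch.float32 -> safer fallback source
```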
| import megatron.core.tensor_parallel.layers as megatron_parallel | ||
| import megatron.core.transformer.mlp as megatron_mlp | ||
| import torch | ||
| from megatron.core.extensions.transformer_engine import TEDotProductAttention | 
Guard TEDotProductAttention import when TE is unavailable
Line 25 introduces an unconditional import of TEDotProductAttention. In environments where Transformer Engine is not installed (which is a supported configuration for the rest of this plugin), this raises ModuleNotFoundError during import and breaks all Megatron quantization paths that previously worked without TE. Please wrap the import in a try/except and only register _QuantTEDotProductAttention when the symbol is available.
```diff
-from megatron.core.extensions.transformer_engine import TEDotProductAttention
+try:
+    from megatron.core.extensions.transformer_engine import TEDotProductAttention
+except ImportError:
+    TEDotProductAttention = None
```

And guard the registration/class definition behind `if TEDotProductAttention is not None:` so the module remains usable without TE.
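A sketch of the guarded definition/registration the comment asks for. The shape of the `QuantModuleRegistry.register` call is an assumption modeled on how other plugins referenced in this review register quant modules; treat the snippet as illustrative rather than the PR's actual code.

```python
# Only define and register the quantized wrapper when Transformer Engine is importable.
if TEDotProductAttention is not None:

    @QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
    class _QuantTEDotProductAttention(TEDotProductAttention, QuantModule):
        # quantizer setup and forward as sketched earlier in this review
        ...
```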
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around line 25, the
unconditional import of TEDotProductAttention will raise ModuleNotFoundError
when Transformer Engine (TE) is not installed; wrap the import in a try/except
that sets TEDotProductAttention = None on ImportError, then only define and
register the _QuantTEDotProductAttention class (and any related registration
calls) inside an if TEDotProductAttention is not None: block so the module
remains importable and Megatron quantization works without TE.
```python
    config["quant_cfg"]["*output_layer*"] = {"enable": False}
```
Do not mutate shared config constants in-place
Line 423 mutates the shared config object (mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG) by updating quant_cfg. Because those constants are reused across parametrized tests (and potentially by library consumers), this causes cross-test contamination and unexpected behavior later in the suite. Make a copy before modifying.
```diff
-def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+    config = copy.deepcopy(config)
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
-    config["quant_cfg"]["*output_layer*"] = {"enable": False}
+    config["quant_cfg"]["*output_layer*"] = {"enable": False}
```

🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 423-424,
the test mutates the shared config constant by assigning
config["quant_cfg"]["*output_layer*"] = {"enable": False}; instead, avoid
in-place mutation by making a copy (use copy.deepcopy(config) or dict deepcopy)
into a local variable and mutate that copy (or create a shallow copy of
quant_cfg and update it) so the original mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG
remain unchanged; replace subsequent uses in the test with the copied config.
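A standalone illustration of the contamination the comment describes, using a hypothetical config dict rather than the real `mtq.FP8_KV_CFG` constant:

```python
import copy

SHARED_CFG = {"quant_cfg": {"*weight_quantizer*": {"num_bits": 8}}}


def helper_copying(config):
    config = copy.deepcopy(config)  # local copy; the shared constant stays pristine
    config["quant_cfg"]["*output_layer*"] = {"enable": False}
    return config


def helper_mutating(config):
    # In-place mutation leaks into every later test that reuses SHARED_CFG.
    config["quant_cfg"]["*output_layer*"] = {"enable": False}


helper_copying(SHARED_CFG)
print("*output_layer*" in SHARED_CFG["quant_cfg"])  # False: original untouched

helper_mutating(SHARED_CFG)
print("*output_layer*" in SHARED_CFG["quant_cfg"])  # True: contamination
```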
LGTM
Signed-off-by: Kai Xu <[email protected]>
Signed-off-by: Kai Xu <[email protected]>
5fc56fe to 0ecf711 (compare)
  
Actionable comments posted: 0
♻️ Duplicate comments (3)
modelopt/torch/quantization/plugins/megatron.py (1)
25-25: Guard TEDotProductAttention import when TE is unavailable

Line 25 introduces an unconditional import of `TEDotProductAttention`. In environments where Transformer Engine is not installed (which is a supported configuration for the rest of this plugin), this raises `ModuleNotFoundError` during import and breaks all Megatron quantization paths that previously worked without TE. Please wrap the import in a `try/except` and only register `_QuantTEDotProductAttention` when the symbol is available.

```diff
-from megatron.core.extensions.transformer_engine import TEDotProductAttention
+try:
+    from megatron.core.extensions.transformer_engine import TEDotProductAttention
+except ImportError:
+    TEDotProductAttention = None
```

And guard the registration/class definition behind `if TEDotProductAttention is not None:` so the module remains usable without TE.

tests/gpu/torch/quantization/plugins/test_megatron.py (2)

368-412: Fix assertions: call TensorQuantizer.is_enabled as a method.

`is_enabled` is a method, not a property. As written, the assertions at lines 404-405 will always pass because a bound method is truthy. Apply:

```diff
-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
```

414-506: Fix is_enabled checks and config mutation

Two issues:
- Lines 461-462: `is_enabled` must be called as a method, not accessed as a property.
- Line 417: Mutating the shared `config` object causes cross-test contamination.

Apply:

```diff
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
+    config = copy.deepcopy(config)
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
```

And:

```diff
-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
```

🧹 Nitpick comments (1)

modelopt/torch/quantization/plugins/megatron.py (1)

485-528: Consider using `recurse=False` for device detection

The device detection logic is safe with the current implementation using `next(iter(self.parameters()), None)`. However, for clarity and to avoid unintended recursion into child modules (like the quantizers themselves), consider using `recurse=False`:

```diff
-        param = next(iter(self.parameters()), None)
+        param = next(self.parameters(recurse=False), None)
+        buf = next(self.buffers(recurse=False), None)
```

Note: The `is_enabled()` method call at line 525 is correct (past issue resolved).
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- CHANGELOG.rst (1 hunks)
- modelopt/torch/quantization/plugins/__init__.py (1 hunks)
- modelopt/torch/quantization/plugins/megatron.py (3 hunks)
- tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
- CHANGELOG.rst
 
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/quantization/plugins/__init__.py
 
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/plugins/megatron.py (4)
- modelopt/torch/quantization/model_calib.py (1): max_calibrate (62-174)
- modelopt/torch/quantization/nn/modules/quant_module.py (2): QuantModule (37-96), modelopt_post_restore (40-69)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4): TensorQuantizer (65-1111), is_enabled (395-397), reset_amax (257-262), forward (847-946)
- modelopt/torch/quantization/plugins/custom.py (2): modelopt_post_restore (117-174), _check_unsupported_states (127-133)

tests/gpu/torch/quantization/plugins/test_megatron.py (5)
- tests/_test_utils/torch_dist/plugins/megatron_common.py (3): initialize_for_megatron (392-400), get_mcore_gpt_model (134-209), sharded_state_dict_test_helper (417-464)
- modelopt/torch/utils/plugins/megatron_generate.py (1): megatron_prefill (41-130)
- modelopt/torch/quantization/model_quant.py (1): quantize (136-231)
- modelopt/torch/trace/symbols.py (1): named_modules (444-447)
- modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1): is_enabled (395-397)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
 - GitHub Check: wait-checks / wait
 - GitHub Check: wait-checks / wait
 - GitHub Check: code-quality
 - GitHub Check: build-docs
 
🔇 Additional comments (9)
modelopt/torch/quantization/plugins/megatron.py (6)
37-38: LGTM: Public quantization imports

The new imports for `max_calibrate`, `QuantModule`, `QuantModuleRegistry`, and `TensorQuantizer` are appropriate for implementing KV cache quantization support.

479-483: LGTM: Quantizer initialization

The `_setup` method correctly initializes three `TensorQuantizer` instances for Q, K, and V tensors.

530-541: LGTM: Forward pass quantization

The forward method correctly applies post-RoPE quantization to Q, K, and V tensors before invoking the parent `TEDotProductAttention.forward()`. This aligns with the KV cache quantization design.

543-576: LGTM: Sharded state dict implementation

The `sharded_state_dict` method correctly handles:
- Non-quantizer parameters (lines 548-550)
- Quantizer `_amax` values (lines 553-560)
- Other quantizer states via `make_sharded_tensors_for_checkpoint` (lines 562-574)

This follows the established pattern for distributed checkpointing in Megatron models.

578-596: LGTM: State dict loading logic

The `_load_from_state_dict` method correctly:
- Remaps `_amax` keys to match TensorQuantizer expectations (lines 580-587)
- Reshapes quantizer states to fit local shapes (lines 589-594)

This handles distributed checkpoint restore scenarios properly.

598-630: LGTM: Post-restore calibration

The `modelopt_post_restore` method appropriately:
- Validates quantizer states and warns about unsupported entries (lines 602-612)
- Triggers calibration when `_amax` is missing (lines 614-630)

This ensures quantizers are properly configured after restoration, especially when tensor parallelism changes between save and load (a condensed sketch of this restore path follows this list).

tests/gpu/torch/quantization/plugins/test_megatron.py (3)

233-234: LGTM: Test separation rationale

The comment appropriately explains why KV cache configs are tested separately: they require `transformer_impl="modelopt"` for `TEDotProductAttention`, not the `"local"` implementation used in other tests.

508-524: LGTM: KV cache quantization test

The test correctly spawns a process to verify KV cache quantization with `TEDotProductAttention`. The test coverage includes:
- Model creation with `transformer_impl="modelopt"`
- Quantizer presence verification
- Forward pass smoke test
0ecf711 to 2fed9db (compare)
2fed9db to 0482373 (compare)
    Signed-off-by: Kai Xu <[email protected]>
0482373 to d2c05f2 (compare)
    
What does this PR do?
Type of change: New feature
Overview: ?
Usage
Testing
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config1]
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config0]
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config0] PASSED
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config1] PASSED
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Tests
Documentation