Skip to content

Conversation

@kaix-nv
Copy link
Contributor

@kaix-nv kaix-nv commented Sep 25, 2025

What does this PR do?

Type of change: ?
New feature

Overview: ?

Usage

pytest tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant -v
pytest tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict -v

Testing

tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config1]
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_quant[config0]

tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config0] PASSED
tests/gpu/torch/quantization/plugins/test_megatron.py::test_kv_cache_sharded_state_dict[config1] PASSED

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

Summary by CodeRabbit

  • New Features

    • Added FP8/NVFP4 KV-cache quantization support for Megatron Core attention, enabling quantized K/V during inference with KV caching.
    • Enhanced quantized attention to expose and validate K/V quantizers and support save/load of their state in sharded checkpoints.
  • Tests

    • Added GPU tests covering KV-cache quantization and sharded state restoration for quantized attention modules.
  • Documentation

    • Updated changelog with a new entry announcing KV-cache quantization support for Megatron Core models.

@kaix-nv kaix-nv requested a review from a team as a code owner September 25, 2025 19:29
@kaix-nv kaix-nv requested a review from cjluo-nv September 25, 2025 19:29
@copy-pr-bot
Copy link

copy-pr-bot bot commented Sep 25, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 25, 2025

Walkthrough

Adds quantization support for Megatron TEDotProductAttention with KV-cache (FP8/NVFP4), including a registered QuantModule, quantizer calibration, forward-path Q/K/V quantization, sharded state dict save/load, and post-restore handling. Introduces focused GPU tests for KV-cache quantization and state dict, and updates the changelog.

Changes

Cohort / File(s) Summary
Megatron TE Attention Quantization
modelopt/torch/quantization/plugins/megatron.py
Adds _QuantTEDotProductAttention registered for TEDotProductAttention; implements setup, calibration, quantized forward for Q/K/V (incl. KV-cache), sharded state dict (with amax handling), load-time remapping/reshape, and post-restore calibration checks. Public imports added. Note: class block appears duplicated.
GPU Tests for KV Cache Quantization
tests/gpu/torch/quantization/plugins/test_megatron.py
Adds helper routines and tests to validate KV-cache quantization (FP8/NVFP4), presence/enabled state of K/V quantizers, forward smoke test, and sharded state dict round-trip consistency for quantizer state.
Changelog
CHANGELOG.rst
Adds v0.41 entry noting new FP8/NVFP4 KV cache quantization support for Megatron Core models.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant C as Caller
  participant QTA as _QuantTEDotProductAttention
  participant QQ as Q quantizer
  participant KQ as K quantizer
  participant VQ as V quantizer
  participant TE as TEDotProductAttention (parent)

  C->>QTA: forward(query, key, value, ...)
  rect rgba(230,245,255,0.6)
    note right of QTA: Apply quantization if enabled
    QTA->>QQ: quantize(query)
    QQ-->>QTA: q_q
    QTA->>KQ: quantize(key)
    KQ-->>QTA: k_q
    QTA->>VQ: quantize(value)
    VQ-->>QTA: v_q
  end
  QTA->>TE: forward(q_q, k_q, v_q, ...)
  TE-->>QTA: attn_output
  QTA-->>C: attn_output
Loading
sequenceDiagram
  autonumber
  participant M as Model with QTA
  participant SD as Sharded Checkpoint Utils
  participant S as Storage

  rect rgba(245,235,255,0.5)
    note over M: Save
    M->>SD: sharded_state_dict(prefix, offsets, metadata)
    SD-->>M: dict incl. non-quant params + quantizer state (amax, config)
    M->>S: write shards
  end

  rect rgba(235,255,235,0.5)
    note over M: Load/Restore
    S-->>M: read shards
    M->>SD: _load_from_state_dict(state_dict, prefix, ...)
    SD-->>M: remap _amax keys, reshape local quantizer states
    M->>M: modelopt_post_restore() -> validate/calibrate if needed
  end
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

A rabbit taps keys by moonlit cache,
Byte-carrots line the quantized stash.
Q, K, V hop through FP8 snow,
NVFP4 footprints neatly in a row.
Shards nest snug where amax gleams—
Megatron dreams in tensor streams. 🐇✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title succinctly describes the primary change of adding key-value cache quantization support for mcore using bmm quantizers and aligns precisely with the objectives and code modifications in the pull request, making its intent clear to reviewers.
Docstring Coverage ✅ Passed No functions found in the changes. Docstring coverage check skipped.
✨ Finishing touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch kaix/kvcache_mcore

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 2fed9db and d2c05f2.

📒 Files selected for processing (1)
  • CHANGELOG.rst (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • CHANGELOG.rst

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)

588-620: Fix is_enabled checks in post‑restore to avoid unconditional calibration.

Call the method; otherwise it always evaluates truthy and may trigger unnecessary calibration.

Apply:

-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
                 continue
🧹 Nitpick comments (3)
modelopt/torch/quantization/plugins/__init__.py (1)

75-76: Gate the new mcore plugin on an importable package (likely megatron.core) and verify plugin file exists.

import_plugin("mcore") will only import .mcore if a top-level module named mcore exists. That’s probably not what you want. If the plugin depends on Megatron-Core, gate on "megatron.core" (or whichever module actually guarantees availability) and ensure modelopt/torch/quantization/plugins/mcore.py exists.

Proposed change:

-with import_plugin("mcore"):
+with import_plugin("megatron.core"):
     from .mcore import *

Also consider adding - :meth:\mcore<modelopt.torch.quantization.plugins.mcore>`` to the plugin list in the module docstring for discoverability.

To verify:

  • Check that plugins/mcore.py exists.
  • Confirm which import string ("megatron.core" vs "mcore") should gate plugin loading in your environments.
modelopt/torch/quantization/plugins/megatron.py (2)

533-567: Sharded state dict: consider consistency for amax handling.

You special‑case _amax by inserting directly, and use make_sharded_tensors_for_checkpoint for other quantizer tensors with empty shard axes. If future bmm quantizers introduce sharded tensors (e.g., per‑channel), add shard axis mapping here to avoid shape mismatches after TP changes. No action required now, but keep in mind for NVFP4 evolution.


568-587: Redundant amax key remapping.

amax_key and expected_amax_key are identical: f"{prefix}{quantizer_name}._amax". The rename is a no‑op.

Apply:

-        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
-            full_prefix = f"{prefix}{quantizer_name}."
-            amax_key = f"{prefix}{quantizer_name}._amax"
-
-            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
-            if amax_key in state_dict:
-                expected_amax_key = f"{full_prefix}_amax"
-                state_dict[expected_amax_key] = state_dict.pop(amax_key)
+        # amax keys already match TensorQuantizer state naming; no remap needed
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 598b9ce and 93bab27.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/__init__.py (1 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
  • initialize_for_megatron (385-393)
  • get_mcore_gpt_model (133-208)
  • sharded_state_dict_test_helper (410-457)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • is_enabled (389-391)
modelopt/torch/quantization/plugins/megatron.py (5)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (61-173)
modelopt/torch/quantization/nn/modules/quant_module.py (4)
  • QuantModule (37-96)
  • _setup (118-126)
  • _setup (163-169)
  • modelopt_post_restore (40-69)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • TensorQuantizer (60-1143)
  • is_enabled (389-391)
  • reset_amax (252-256)
  • forward (872-977)
modelopt/torch/quantization/plugins/huggingface.py (17)
  • _setup (55-58)
  • _setup (161-164)
  • _setup (239-244)
  • _setup (349-350)
  • _setup (365-369)
  • _setup (388-390)
  • _setup (427-473)
  • _setup (601-612)
  • forward (71-119)
  • forward (170-174)
  • forward (255-256)
  • forward (337-338)
  • forward (352-361)
  • forward (371-383)
  • forward (393-423)
  • forward (475-480)
  • forward (643-649)
modelopt/torch/quantization/plugins/custom.py (2)
  • modelopt_post_restore (117-174)
  • _check_unsupported_states (127-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (6)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)

235-237: Note acknowledged — separating KV-cache configs is reasonable.


521-531: LGTM — targeted KV‑cache test entrypoint.


540-553: LGTM — sharded state dict test for KV‑cache quantization.

modelopt/torch/quantization/plugins/megatron.py (3)

25-25: Confirm TEDotProductAttention import path across supported Megatron‑Core/TE versions.

from megatron.core.extensions.transformer_engine import TEDotProductAttention may vary by version. If older MC/TE versions are supported, consider guarding this import (try/except) and registering conditionally.


467-476: LGTM — consistent quantizer setup for Q/K/V bmm.


520-532: LGTM — quantize Q/K/V post‑RoPE before delegating.

Comment on lines +374 to +412
def _test_kv_cache_quant_helper(config, rank, size):
"""Helper function for testing KV cache quantization with TEDotProductAttention."""
initialize_for_megatron(
tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
)

# Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
# Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
model = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=1,
hidden_size=64,
num_attention_heads=4,
vocab_size=32,
transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec
).cuda()

# Create dummy input for calibration
prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

def forward_fn(model):
return megatron_prefill(model, prompt_tokens)

# Test KV cache quantization with the given config
quantized_model = mtq.quantize(model, config, forward_fn)

# Find TEDotProductAttention modules and verify they have KV cache quantizers
te_attention_found = False
for name, module in quantized_model.named_modules():
# Check if this is a quantized TEDotProductAttention
if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
te_attention_found = True
# Verify all expected quantizers exist
assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"

# Verify K and V quantizers are enabled (main purpose of KV cache configs)
assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"

# Quick smoke test that forward still works
output = forward_fn(quantized_model)
assert output is not None, "Forward pass failed"

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix assertions: call TensorQuantizer.is_enabled as a method.

is_enabled is a method, not a property. As written, the assertions will always pass because a bound method is truthy.

Apply:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _test_kv_cache_quant_helper(config, rank, size):
"""Helper function for testing KV cache quantization with TEDotProductAttention."""
initialize_for_megatron(
tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
)
# Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
# Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
model = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=1,
hidden_size=64,
num_attention_heads=4,
vocab_size=32,
transformer_impl="modelopt", # This uses TEDotProductAttention via get_gpt_modelopt_spec
).cuda()
# Create dummy input for calibration
prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
def forward_fn(model):
return megatron_prefill(model, prompt_tokens)
# Test KV cache quantization with the given config
quantized_model = mtq.quantize(model, config, forward_fn)
# Find TEDotProductAttention modules and verify they have KV cache quantizers
te_attention_found = False
for name, module in quantized_model.named_modules():
# Check if this is a quantized TEDotProductAttention
if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
te_attention_found = True
# Verify all expected quantizers exist
assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
# Verify K and V quantizers are enabled (main purpose of KV cache configs)
assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
# Quick smoke test that forward still works
output = forward_fn(quantized_model)
assert output is not None, "Forward pass failed"
for name, module in quantized_model.named_modules():
# Check if this is a quantized TEDotProductAttention
if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
te_attention_found = True
# Verify all expected quantizers exist
assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
# Verify K and V quantizers are enabled (main purpose of KV cache configs)
assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 374 to
418, the assertions check TensorQuantizer.is_enabled as an attribute
(module.k_bmm_quantizer.is_enabled and module.v_bmm_quantizer.is_enabled) but
is_enabled is a method; change those assertions to call the method
(module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled()) so
they evaluate correctly.

Comment on lines +420 to +507
def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
"""Helper for testing KV cache quantization with sharded state dict save/load."""
# Disable output_layer quantization (same as other sharded state dict tests)
config["quant_cfg"]["*output_layer*"] = {"enable": False}

initialize_for_megatron(
tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
)

# Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
model_ref = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=2, # At least 2 layers to test multiple attention modules
hidden_size=64,
num_attention_heads=4,
vocab_size=64,
transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention
).cuda()

model_test = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=2,
hidden_size=64,
num_attention_heads=4,
vocab_size=64,
transformer_impl="modelopt",
).cuda()

prompt_tokens = torch.randint(
0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
).cuda()

def forward_fn(model):
return megatron_prefill(model, prompt_tokens)

# Quantize the reference model
model_ref = mtq.quantize(model_ref, config, forward_fn)

# CRITICAL: model_test must also be quantized with the same config
# Otherwise it won't have the KV cache quantizer keys when loading state dict
model_test = mtq.quantize(model_test, config, forward_fn)

# Verify KV cache quantizers were created
kv_quantizers_found = False
for name, module in model_ref.named_modules():
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
kv_quantizers_found = True
assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"

assert kv_quantizers_found, "No KV cache quantizers found in quantized model"

# Test sharded state dict save/load
sharded_state_dict_test_helper(
tmp_path,
model_ref,
model_test,
forward_fn,
meta_device=False,
version=None,
)

# Verify KV cache quantizers are restored correctly in model_test
for (name_ref, module_ref), (name_test, module_test) in zip(
model_ref.named_modules(), model_test.named_modules()
):
if hasattr(module_ref, "k_bmm_quantizer"):
assert hasattr(module_test, "k_bmm_quantizer"), (
f"K quantizer missing after restore in {name_test}"
)
assert hasattr(module_test, "v_bmm_quantizer"), (
f"V quantizer missing after restore in {name_test}"
)

# Check that quantizer states match
if hasattr(module_ref.k_bmm_quantizer, "_amax"):
assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
f"K quantizer _amax missing in {name_test}"
)
if module_ref.k_bmm_quantizer._amax is not None:
assert torch.allclose(
module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
), f"K quantizer _amax mismatch in {name_test}"

if hasattr(module_ref.v_bmm_quantizer, "_amax"):
assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
f"V quantizer _amax missing in {name_test}"
)
if module_ref.v_bmm_quantizer._amax is not None:
assert torch.allclose(
module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
), f"V quantizer _amax mismatch in {name_test}"


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix is_enabled checks in sharded state dict helper.

Same issue: call is_enabled().

Apply:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
"""Helper for testing KV cache quantization with sharded state dict save/load."""
# Disable output_layer quantization (same as other sharded state dict tests)
config["quant_cfg"]["*output_layer*"] = {"enable": False}
initialize_for_megatron(
tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
)
# Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
model_ref = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=2, # At least 2 layers to test multiple attention modules
hidden_size=64,
num_attention_heads=4,
vocab_size=64,
transformer_impl="modelopt", # CRITICAL: Use TEDotProductAttention
).cuda()
model_test = get_mcore_gpt_model(
tensor_model_parallel_size=size,
num_layers=2,
hidden_size=64,
num_attention_heads=4,
vocab_size=64,
transformer_impl="modelopt",
).cuda()
prompt_tokens = torch.randint(
0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
).cuda()
def forward_fn(model):
return megatron_prefill(model, prompt_tokens)
# Quantize the reference model
model_ref = mtq.quantize(model_ref, config, forward_fn)
# CRITICAL: model_test must also be quantized with the same config
# Otherwise it won't have the KV cache quantizer keys when loading state dict
model_test = mtq.quantize(model_test, config, forward_fn)
# Verify KV cache quantizers were created
kv_quantizers_found = False
for name, module in model_ref.named_modules():
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
kv_quantizers_found = True
assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
# Test sharded state dict save/load
sharded_state_dict_test_helper(
tmp_path,
model_ref,
model_test,
forward_fn,
meta_device=False,
version=None,
)
# Verify KV cache quantizers are restored correctly in model_test
for (name_ref, module_ref), (name_test, module_test) in zip(
model_ref.named_modules(), model_test.named_modules()
):
if hasattr(module_ref, "k_bmm_quantizer"):
assert hasattr(module_test, "k_bmm_quantizer"), (
f"K quantizer missing after restore in {name_test}"
)
assert hasattr(module_test, "v_bmm_quantizer"), (
f"V quantizer missing after restore in {name_test}"
)
# Check that quantizer states match
if hasattr(module_ref.k_bmm_quantizer, "_amax"):
assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
f"K quantizer _amax missing in {name_test}"
)
if module_ref.k_bmm_quantizer._amax is not None:
assert torch.allclose(
module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
), f"K quantizer _amax mismatch in {name_test}"
if hasattr(module_ref.v_bmm_quantizer, "_amax"):
assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
f"V quantizer _amax missing in {name_test}"
)
if module_ref.v_bmm_quantizer._amax is not None:
assert torch.allclose(
module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
), f"V quantizer _amax mismatch in {name_test}"
for name, module in model_ref.named_modules():
if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
kv_quantizers_found = True
assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 420-513,
the assertions that verify KV cache quantizers use the attribute access
.is_enabled instead of calling the method; update those assertions to call
module.k_bmm_quantizer.is_enabled() and module.v_bmm_quantizer.is_enabled() so
the checks invoke the method rather than reference a non-callable attribute.

coderabbitai bot added a commit that referenced this pull request Sep 25, 2025
Docstrings generation was requested by @kaix-nv.

* #375 (comment)

The following files were modified:

* `modelopt/torch/quantization/plugins/megatron.py`
* `tests/gpu/torch/quantization/plugins/test_megatron.py`
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 25, 2025

Note

Generated docstrings for this pull request at #376

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 25, 2025

Note

Unit test generation is an Early Access feature. Expect some limitations and changes as we gather feedback and continue to improve it.


Generating unit tests... This may take up to 20 minutes.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 25, 2025

UTG Post-Process Complete

No new issues were detected in the generated code and all check runs have completed. The unit test generation process has completed successfully.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 25, 2025

Creating a PR to put the unit tests in...

The changes have been created in this pull request: View PR

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 93bab27 and 5fc56fe.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/plugins/__init__.py (1 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/init.py
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-09-18T20:15:04.615Z
Learnt from: realAsma
PR: NVIDIA/TensorRT-Model-Optimizer#332
File: modelopt/torch/quantization/algorithms.py:323-326
Timestamp: 2025-09-18T20:15:04.615Z
Learning: In modelopt/torch/quantization/algorithms.py, the `_is_auto_quantize_module` method requires `isinstance(module, QuantModule)` because some modules like MCore Column/Row Parallel Linear are `QuantModule` but not `QuantLinearConvBase`. The check ensures all quantization-capable modules are included in AutoQuantize search.

Applied to files:

  • tests/gpu/torch/quantization/plugins/test_megatron.py
🧬 Code graph analysis (2)
tests/gpu/torch/quantization/plugins/test_megatron.py (3)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
  • initialize_for_megatron (385-393)
  • get_mcore_gpt_model (133-208)
  • sharded_state_dict_test_helper (410-457)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • is_enabled (389-391)
modelopt/torch/quantization/plugins/megatron.py (5)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (61-173)
modelopt/torch/quantization/nn/modules/quant_module.py (4)
  • QuantModule (37-96)
  • _setup (118-126)
  • _setup (163-169)
  • modelopt_post_restore (40-69)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • TensorQuantizer (60-1143)
  • is_enabled (389-391)
  • reset_amax (252-256)
  • forward (872-977)
modelopt/torch/quantization/plugins/huggingface.py (17)
  • _setup (55-58)
  • _setup (161-164)
  • _setup (239-244)
  • _setup (349-350)
  • _setup (365-369)
  • _setup (388-390)
  • _setup (427-473)
  • _setup (601-612)
  • forward (71-119)
  • forward (170-174)
  • forward (255-256)
  • forward (337-338)
  • forward (352-361)
  • forward (371-383)
  • forward (393-423)
  • forward (475-480)
  • forward (643-649)
modelopt/torch/quantization/plugins/custom.py (2)
  • modelopt_post_restore (117-174)
  • _check_unsupported_states (127-133)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: linux
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (3)
tests/gpu/torch/quantization/plugins/test_megatron.py (2)

409-411: Fix assertions: call TensorQuantizer.is_enabled()

Line 410 still treats is_enabled as an attribute, so the assertion always passes because a bound method is truthy. Please invoke it.

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"

467-468: Fix assertions: call TensorQuantizer.is_enabled()

Line 467 repeats the same issue—is_enabled must be invoked. Without the call these checks never fail even when the quantizers are disabled.

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
modelopt/torch/quantization/plugins/megatron.py (1)

483-488: Calibrator device/dtype inference must handle param/buffer-less modules

In _calibrate_quantizers (Lines 486-488), param = next(iter(self.parameters()), None) falls back to torch.device("cuda") and torch.float16 whenever the module has no parameters. Many wrapper modules expose only buffers (or can be instantiated on CPU/meta first), so this default hits real code paths and will crash on hosts without CUDA. Please probe both parameters and buffers, and only prefer CUDA if it is actually available.

-        param = next(iter(self.parameters()), None)
-        device = param.device if param is not None else torch.device("cuda")
-        dtype = param.dtype if param is not None else torch.float16
+        param = next(self.parameters(recurse=False), None)
+        buf = next(self.buffers(recurse=False), None)
+        if param is not None:
+            device, dtype = param.device, param.dtype
+        elif buf is not None:
+            device, dtype = buf.device, buf.dtype
+        else:
+            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            dtype = torch.float16 if device.type == "cuda" else torch.float32

This matches the safe pattern already used elsewhere in the plugin and avoids hard-crashing on CPU/meta setups.

import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.extensions.transformer_engine import TEDotProductAttention
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Guard TEDotProductAttention import when TE is unavailable

Line 25 introduces an unconditional import of TEDotProductAttention. In environments where Transformer Engine is not installed (which is a supported configuration for the rest of this plugin), this raises ModuleNotFoundError during import and breaks all Megatron quantization paths that previously worked without TE. Please wrap the import in a try/except and only register _QuantTEDotProductAttention when the symbol is available.

-from megatron.core.extensions.transformer_engine import TEDotProductAttention
+try:
+    from megatron.core.extensions.transformer_engine import TEDotProductAttention
+except ImportError:
+    TEDotProductAttention = None

And guard the registration/class definition behind if TEDotProductAttention is not None: so the module remains usable without TE.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
from megatron.core.extensions.transformer_engine import TEDotProductAttention
try:
from megatron.core.extensions.transformer_engine import TEDotProductAttention
except ImportError:
TEDotProductAttention = None
🤖 Prompt for AI Agents
In modelopt/torch/quantization/plugins/megatron.py around line 25, the
unconditional import of TEDotProductAttention will raise ModuleNotFoundError
when Transformer Engine (TE) is not installed; wrap the import in a try/except
that sets TEDotProductAttention = None on ImportError, then only define and
register the _QuantTEDotProductAttention class (and any related registration
calls) inside an if TEDotProductAttention is not None: block so the module
remains importable and Megatron quantization works without TE.

Comment on lines +423 to +418
config["quant_cfg"]["*output_layer*"] = {"enable": False}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Do not mutate shared config constants in-place

Line 423 mutates the shared config object (mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG) by updating quant_cfg. Because those constants are reused across parametrized tests (and potentially by library consumers), this causes cross-test contamination and unexpected behavior later in the suite. Make a copy before modifying.

-def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+    config = copy.deepcopy(config)
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
-    config["quant_cfg"]["*output_layer*"] = {"enable": False}
+    config["quant_cfg"]["*output_layer*"] = {"enable": False}
🤖 Prompt for AI Agents
In tests/gpu/torch/quantization/plugins/test_megatron.py around lines 423-424,
the test mutates the shared config constant by assigning
config["quant_cfg"]["*output_layer*"] = {"enable": False}; instead, avoid
in-place mutation by making a copy (use copy.deepcopy(config) or dict deepcopy)
into a local variable and mutate that copy (or create a shallow copy of
quant_cfg and update it) so the original mtq.FP8_KV_CFG / mtq.NVFP4_KV_CFG
remain unchanged; replace subsequent uses in the test with the copied config.

@kaix-nv kaix-nv requested a review from jingyu-ml September 25, 2025 22:30
Copy link
Contributor

@jingyu-ml jingyu-ml left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kaix-nv kaix-nv requested a review from ChenhanYu October 10, 2025 19:06
@kaix-nv kaix-nv force-pushed the kaix/kvcache_mcore branch from 5fc56fe to 0ecf711 Compare October 10, 2025 20:50
@kaix-nv kaix-nv enabled auto-merge (squash) October 10, 2025 20:51
@kaix-nv kaix-nv disabled auto-merge October 10, 2025 20:53
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (3)
modelopt/torch/quantization/plugins/megatron.py (1)

25-25: Guard TEDotProductAttention import when TE is unavailable

Line 25 introduces an unconditional import of TEDotProductAttention. In environments where Transformer Engine is not installed (which is a supported configuration for the rest of this plugin), this raises ModuleNotFoundError during import and breaks all Megatron quantization paths that previously worked without TE. Please wrap the import in a try/except and only register _QuantTEDotProductAttention when the symbol is available.

-from megatron.core.extensions.transformer_engine import TEDotProductAttention
+try:
+    from megatron.core.extensions.transformer_engine import TEDotProductAttention
+except ImportError:
+    TEDotProductAttention = None

And guard the registration/class definition behind if TEDotProductAttention is not None: so the module remains usable without TE.

tests/gpu/torch/quantization/plugins/test_megatron.py (2)

368-412: Fix assertions: call TensorQuantizer.is_enabled as a method.

is_enabled is a method, not a property. As written, the assertions at lines 404-405 will always pass because a bound method is truthy.

Apply:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"

414-506: Fix is_enabled checks and config mutation

Two issues:

  1. Lines 461-462: is_enabled must be called as a method, not accessed as a property.
  2. Line 417: Mutating the shared config object causes cross-test contamination.

Apply:

 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
+    config = copy.deepcopy(config)
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}

And:

-            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
-            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+            assert module.k_bmm_quantizer.is_enabled(), f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled(), f"V quantizer not enabled in {name}"
🧹 Nitpick comments (1)
modelopt/torch/quantization/plugins/megatron.py (1)

485-528: Consider using recurse=False for device detection

The device detection logic is safe with the current implementation using next(iter(self.parameters()), None). However, for clarity and to avoid unintended recursion into child modules (like the quantizers themselves), consider using recurse=False:

-        param = next(iter(self.parameters()), None)
+        param = next(self.parameters(recurse=False), None)
+        buf = next(self.buffers(recurse=False), None)

Note: The is_enabled() method call at line 525 is correct (past issue resolved).

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5fc56fe and 0ecf711.

📒 Files selected for processing (4)
  • CHANGELOG.rst (1 hunks)
  • modelopt/torch/quantization/plugins/__init__.py (1 hunks)
  • modelopt/torch/quantization/plugins/megatron.py (3 hunks)
  • tests/gpu/torch/quantization/plugins/test_megatron.py (2 hunks)
✅ Files skipped from review due to trivial changes (1)
  • CHANGELOG.rst
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/quantization/plugins/init.py
🧰 Additional context used
🧬 Code graph analysis (2)
modelopt/torch/quantization/plugins/megatron.py (4)
modelopt/torch/quantization/model_calib.py (1)
  • max_calibrate (62-174)
modelopt/torch/quantization/nn/modules/quant_module.py (2)
  • QuantModule (37-96)
  • modelopt_post_restore (40-69)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (4)
  • TensorQuantizer (65-1111)
  • is_enabled (395-397)
  • reset_amax (257-262)
  • forward (847-946)
modelopt/torch/quantization/plugins/custom.py (2)
  • modelopt_post_restore (117-174)
  • _check_unsupported_states (127-133)
tests/gpu/torch/quantization/plugins/test_megatron.py (5)
tests/_test_utils/torch_dist/plugins/megatron_common.py (3)
  • initialize_for_megatron (392-400)
  • get_mcore_gpt_model (134-209)
  • sharded_state_dict_test_helper (417-464)
modelopt/torch/utils/plugins/megatron_generate.py (1)
  • megatron_prefill (41-130)
modelopt/torch/quantization/model_quant.py (1)
  • quantize (136-231)
modelopt/torch/trace/symbols.py (1)
  • named_modules (444-447)
modelopt/torch/quantization/nn/modules/tensor_quantizer.py (1)
  • is_enabled (395-397)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: code-quality
  • GitHub Check: build-docs
🔇 Additional comments (9)
modelopt/torch/quantization/plugins/megatron.py (6)

37-38: LGTM: Public quantization imports

The new imports for max_calibrate, QuantModule, QuantModuleRegistry, and TensorQuantizer are appropriate for implementing KV cache quantization support.


479-483: LGTM: Quantizer initialization

The _setup method correctly initializes three TensorQuantizer instances for Q, K, and V tensors.


530-541: LGTM: Forward pass quantization

The forward method correctly applies post-RoPE quantization to Q, K, and V tensors before invoking the parent TEDotProductAttention.forward(). This aligns with the KV cache quantization design.


543-576: LGTM: Sharded state dict implementation

The sharded_state_dict method correctly handles:

  • Non-quantizer parameters (lines 548-550)
  • Quantizer _amax values (lines 553-560)
  • Other quantizer states via make_sharded_tensors_for_checkpoint (lines 562-574)

This follows the established pattern for distributed checkpointing in Megatron models.


578-596: LGTM: State dict loading logic

The _load_from_state_dict method correctly:

  • Remaps _amax keys to match TensorQuantizer expectations (lines 580-587)
  • Reshapes quantizer states to fit local shapes (lines 589-594)

This handles distributed checkpoint restore scenarios properly.


598-630: LGTM: Post-restore calibration

The modelopt_post_restore method appropriately:

  • Validates quantizer states and warns about unsupported entries (lines 602-612)
  • Triggers calibration when _amax is missing (lines 614-630)

This ensures quantizers are properly configured after restoration, especially when tensor parallelism changes between save and load.

tests/gpu/torch/quantization/plugins/test_megatron.py (3)

233-234: LGTM: Test separation rationale

The comment appropriately explains why KV cache configs are tested separately—they require transformer_impl="modelopt" for TEDotProductAttention, not the "local" implementation used in other tests.


508-524: LGTM: KV cache quantization test

The test correctly spawns a process to verify KV cache quantization with TEDotProductAttention. The test coverage includes:

  • Model creation with transformer_impl="modelopt"
  • Quantizer presence verification
  • Forward pass smoke test

527-546: LGTM: Sharded state dict test for KV cache

The test properly validates the complete save/load workflow for KV-quantized models with distributed checkpointing, including quantizer state preservation across the cycle.

@kaix-nv kaix-nv force-pushed the kaix/kvcache_mcore branch from 0ecf711 to 2fed9db Compare October 10, 2025 20:55
@kaix-nv kaix-nv enabled auto-merge (squash) October 10, 2025 20:59
@kaix-nv kaix-nv force-pushed the kaix/kvcache_mcore branch from 2fed9db to 0482373 Compare October 10, 2025 22:55
Signed-off-by: Kai Xu <[email protected]>
@kaix-nv kaix-nv force-pushed the kaix/kvcache_mcore branch from 0482373 to d2c05f2 Compare October 10, 2025 22:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants