
Commit c1d624e

Support kv cache quantization for mcore using bmm_quantizers (NVIDIA#375)
Signed-off-by: Kai Xu <[email protected]>
1 parent 41c66bb commit c1d624e

7 files changed: +378 -19 lines changed

CHANGELOG.rst

Lines changed: 8 additions & 0 deletions

@@ -1,5 +1,13 @@
 Model Optimizer Changelog (Linux)
 =================================
+0.41 (2025-12-xx)
+^^^^^^^^^^^^^^^^^
+
+**Deprecations**
+
+**New Features**
+
+- Add FP8/NVFP4 KV cache quantization support for Megatron Core models.
 
 0.40 (2025-12-xx)
 ^^^^^^^^^^^^^^^^^

examples/llm_ptq/example_utils.py

Lines changed: 1 addition & 16 deletions

@@ -20,7 +20,6 @@
 import sys
 import warnings
 from pathlib import Path
-from typing import Any
 
 import torch
 import transformers
@@ -159,7 +158,7 @@ def build_quant_cfg(
 
     # Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
     if enable_quant_kv_cache:
-        quant_cfg = apply_kv_cache_quant(
+        quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
             quant_cfg,
             getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
         )
@@ -403,20 +402,6 @@ def is_enc_dec(model_type) -> bool:
     return model_type in ["t5", "bart", "whisper"]
 
 
-def apply_kv_cache_quant(quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str, Any]):
-    """Apply quantization to the kv cache of the model."""
-    # Update KV cache related bmm quantizers
-    # If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case
-    quant_cfg["quant_cfg"] = quant_cfg.get("quant_cfg", {"default": {"enable": False}})
-    quant_cfg["quant_cfg"].update(kv_cache_quant_cfg)
-
-    # Set default algorithm for kv cache quantization if not provided.
-    if not quant_cfg.get("algorithm"):
-        quant_cfg["algorithm"] = "max"
-
-    return quant_cfg
-
-
 def _resolve_model_path(model_name_or_path: str, trust_remote_code: bool = False) -> str:
     """Resolve a model name or path to a local directory path.

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 2 deletions

@@ -23,7 +23,6 @@
 import torch
 from accelerate.hooks import remove_hook_from_module
 from example_utils import (
-    apply_kv_cache_quant,
     build_quant_cfg,
     copy_custom_model_files,
     get_model,
@@ -86,8 +85,10 @@
 KV_QUANT_CFG_CHOICES = {
     "none": "none",
     "fp8": "FP8_KV_CFG",
+    "fp8_affine": "FP8_AFFINE_KV_CFG",
     "nvfp4": "NVFP4_KV_CFG",
     "nvfp4_affine": "NVFP4_AFFINE_KV_CFG",
+    "nvfp4_rotate": "NVFP4_KV_ROTATE_CFG",
 }
 
 mto.enable_huggingface_checkpointing()
@@ -257,7 +258,7 @@ def main(args):
         )
         quant_cfg = QUANT_CFG_CHOICES[args.qformat]
         if args.kv_cache_qformat != "none":
-            quant_cfg = apply_kv_cache_quant(
+            quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(
                 quant_cfg, getattr(mtq, KV_QUANT_CFG_CHOICES[args.kv_cache_qformat])["quant_cfg"]
             )
 

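For illustration, a minimal sketch of the merge the script performs when --kv_cache_qformat fp8 is combined with an FP8 weight/activation format. FP8_DEFAULT_CFG stands in for whatever QUANT_CFG_CHOICES[args.qformat] resolves to and is an assumption, not part of this diff:

import modelopt.torch.quantization as mtq

# Assumption: FP8_DEFAULT_CFG plays the role of QUANT_CFG_CHOICES[args.qformat].
quant_cfg = mtq.FP8_DEFAULT_CFG
kv_cfg = mtq.FP8_KV_CFG["quant_cfg"]  # KV_QUANT_CFG_CHOICES["fp8"] from the table above

quant_cfg = mtq.utils.update_quant_cfg_with_kv_cache_quant(quant_cfg, kv_cfg)
# quant_cfg["quant_cfg"] now holds the FP8 weight/activation rules plus the KV cache
# bmm-quantizer rules, and quant_cfg["algorithm"] falls back to "max" if it was unset.

Note that the helper updates the config dict in place, so copying a shared module-level config before merging may be prudent.
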
modelopt/torch/quantization/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -24,3 +24,4 @@
 from .conversion import *
 from .model_quant import *
 from .nn.modules.quant_module import QuantModuleRegistry
+from .utils import update_quant_cfg_with_kv_cache_quant

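With this re-export, the helper is reachable both at the package level and through the utils submodule (the two spellings used in example_utils.py and hf_ptq.py above); a quick sanity check, assuming modelopt is installed:

import modelopt.torch.quantization as mtq

# Both attribute paths resolve to the same function object after this change.
assert mtq.update_quant_cfg_with_kv_cache_quant is mtq.utils.update_quant_cfg_with_kv_cache_quant
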
modelopt/torch/quantization/plugins/megatron.py

Lines changed: 165 additions & 0 deletions

@@ -37,6 +37,7 @@
 )
 from modelopt.torch.utils.distributed import ParallelState
 
+from ..model_calib import max_calibrate
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
@@ -45,6 +46,7 @@
 try:
     from megatron.core.extensions.transformer_engine import (
         TEColumnParallelGroupedLinear,
+        TEDotProductAttention,
         TERowParallelGroupedLinear,
     )
 
@@ -590,6 +592,169 @@ def _setup(self):
         self.linear_fc1.parallel_state = self.parallel_state
         self.linear_fc2.parallel_state = self.parallel_state
 
+@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
+class _QuantTEDotProductAttention(QuantModule):
+    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization.
+
+    This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention
+    module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer,
+    k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after
+    RoPE has been applied.
+    """
+
+    def _setup(self):
+        """Initialize quantizers for Q, K, V tensors."""
+        self.q_bmm_quantizer = TensorQuantizer()
+        self.k_bmm_quantizer = TensorQuantizer()
+        self.v_bmm_quantizer = TensorQuantizer()
+
+    def _calibrate_quantizers(self):
+        """Calibrate quantizers with minimal dummy tensors."""
+        # Get device and dtype from the parent module's parameters
+        param = next(iter(self.parameters()), None)
+        device = param.device if param is not None else torch.device("cuda")
+        dtype = param.dtype if param is not None else torch.float16
+
+        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
+        batch_size = 1
+        seq_len = 1
+
+        # Get dimensions from config
+        num_heads = self.config.num_attention_heads
+        head_dim = (
+            self.config.kv_channels
+            if hasattr(self.config, "kv_channels")
+            else self.config.hidden_size // num_heads
+        )
+
+        # Determine tensor format (default to sbhd if not specified)
+        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
+        qkv_format = "bshd" if apply_rope_fusion else "sbhd"
+
+        if qkv_format == "sbhd":
+            dummy_tensor = torch.randn(
+                seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
+            )
+        else:
+            dummy_tensor = torch.randn(
+                batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
+            )
+
+        # Calibrate each quantizer
+        quantizers = [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]
+
+        for _, quantizer in quantizers:
+            if quantizer is not None and quantizer.is_enabled():
+                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
+                    quantizer.reset_amax()
+                max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
+
+    def forward(self, query, key, value, *args, **kwargs):
+        """Apply post-RoPE quantization to KV cache.
+
+        TEDotProductAttention receives Q, K, V after RoPE is applied,
+        so we quantize them directly for KV cache quantization.
+        """
+        # Quantize Q, K, V
+        query = self.q_bmm_quantizer(query)
+        key = self.k_bmm_quantizer(key)
+        value = self.v_bmm_quantizer(value)
+
+        return super().forward(query, key, value, *args, **kwargs)
+
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Create a sharded state dictionary for distributed checkpointing."""
+        sharded_state_dict = {}
+
+        # First add non-quantizer parameters
+        for k, v in self.state_dict(prefix="", keep_vars=True).items():
+            if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
+                sharded_state_dict[prefix + k] = v
+
+        # Process _amax in bmm_quantizers
+        for name, quantizer in [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]:
+            if hasattr(quantizer, "_amax") and quantizer._amax is not None:
+                amax_key = f"{prefix}{name}._amax"
+                sharded_state_dict[amax_key] = quantizer._amax
+
+        # Process other quantizer parameters in bmm_quantizers
+        quantizer_state_dict = {
+            k: v
+            for k, v in self.state_dict(prefix="", keep_vars=True).items()
+            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
+        }
+
+        if quantizer_state_dict:
+            sharded_state_dict.update(
+                **make_sharded_tensors_for_checkpoint(
+                    quantizer_state_dict, prefix, {}, sharded_offsets
+                )
+            )
+
+        return sharded_state_dict
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        """Handle loading state dict for quantizers."""
+        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
+            full_prefix = f"{prefix}{quantizer_name}."
+            amax_key = f"{prefix}{quantizer_name}._amax"
+
+            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
+            if amax_key in state_dict:
+                expected_amax_key = f"{full_prefix}_amax"
+                state_dict[expected_amax_key] = state_dict.pop(amax_key)
+
+        # Handle other quantizer states
+        for k in list(state_dict.keys()):
+            if "_quantizer" in k and "_amax" not in k:
+                name = k.split(prefix)[-1] if prefix else k
+                if name in self.state_dict():
+                    state_dict[k] = state_dict[k].view_as(self.state_dict()[name])
+
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def modelopt_post_restore(self, name=""):
+        """Restore quantizer states after model loading."""
+        super().modelopt_post_restore(name)
+
+        def _check_unsupported_states(quantizer):
+            """Check for unsupported quantizer states and warn if found."""
+            if not hasattr(quantizer, "state_dict"):
+                return
+
+            for k in quantizer.state_dict():
+                if k not in ["_amax", "_pre_quant_scale"]:
+                    warnings.warn(
+                        f"Restore of {k} for {name} is not supported. The restore of this layer might be "
+                        f"incorrect. Please implement a custom restore for {k}."
+                    )
+
+        calibration_needed = False
+
+        for quantizer_name, quantizer in [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
+                continue
+
+            _check_unsupported_states(quantizer)
+
+            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
+                calibration_needed = True
+
+        if calibration_needed:
+            self._calibrate_quantizers()
+
 
 @QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
 class _QuantMoELayer(QuantModule):

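Putting the pieces together, the registration above means a Megatron-Core model whose attention is TEDotProductAttention picks up q/k/v bmm quantizers when converted with mtq.quantize. The sketch below shows the intended flow; get_mcore_gpt_model and get_calib_batches are hypothetical stand-ins for the user's model and calibration data, and FP8_DEFAULT_CFG is an assumed base format:

import modelopt.torch.quantization as mtq

# Hypothetical setup: building the Megatron-Core model and calibration data is
# outside the scope of this commit.
model = get_mcore_gpt_model()        # assumed to contain TEDotProductAttention modules
calib_batches = get_calib_batches()  # assumed iterable of model inputs

# Fold FP8 KV cache rules into a base format (FP8_KV_CFG comes from the choices above).
cfg = mtq.update_quant_cfg_with_kv_cache_quant(
    mtq.FP8_DEFAULT_CFG, mtq.FP8_KV_CFG["quant_cfg"]
)

def forward_loop(m):
    # Max calibration records amax on the q/k/v bmm quantizers installed on each
    # _QuantTEDotProductAttention during these forward passes.
    for batch in calib_batches:
        m(**batch)

# mtq.quantize replaces TEDotProductAttention with _QuantTEDotProductAttention via
# QuantModuleRegistry and then calibrates the inserted quantizers.
model = mtq.quantize(model, cfg, forward_loop)
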
modelopt/torch/quantization/utils.py

Lines changed: 17 additions & 1 deletion

@@ -19,7 +19,7 @@
 
 from collections import namedtuple
 from contextlib import ExitStack, contextmanager, nullcontext
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 import torch
 import torch.nn as nn
@@ -43,6 +43,7 @@
     "is_quantized_row_parallel_linear",
     "reduce_amax",
    "replace_function",
+    "update_quant_cfg_with_kv_cache_quant",
     "weight_attr_names",
 ]
 
@@ -703,3 +704,18 @@ def fsdp2_aware_weight_update(root_model, modules_to_update, reshard=True):
     if reshard:
         with enable_fake_quant(root_module):
             root_module.reshard()
+
+
+def update_quant_cfg_with_kv_cache_quant(
+    quant_cfg: dict[str, Any], kv_cache_quant_cfg: dict[str, Any]
+) -> dict[str, Any]:
+    """Update the quant_cfg with the kv cache quant_cfg."""
+    # If quant_cfg["quant_cfg"] is None, it corresponds to only kv cache quantization case
+    quant_cfg["quant_cfg"] = quant_cfg.get("quant_cfg", {"default": {"enable": False}})
+    quant_cfg["quant_cfg"].update(kv_cache_quant_cfg)
+
+    # Set default algorithm for kv cache quantization if not provided.
+    if not quant_cfg.get("algorithm"):
+        quant_cfg["algorithm"] = "max"
+    print_rank_0(f"Updated quant_cfg with KV cache quantization: {quant_cfg}")
+    return quant_cfg

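A small self-contained illustration of the helper's behavior, using toy rule dicts in place of the real KV cache configs such as FP8_KV_CFG:

import modelopt.torch.quantization as mtq

# Toy stand-ins for real quantizer rules.
base_cfg = {"quant_cfg": {"*weight_quantizer": {"num_bits": (4, 3)}}}
kv_rules = {"*k_bmm_quantizer": {"num_bits": (4, 3)}, "*v_bmm_quantizer": {"num_bits": (4, 3)}}

merged = mtq.update_quant_cfg_with_kv_cache_quant(base_cfg, kv_rules)
assert "*k_bmm_quantizer" in merged["quant_cfg"]
assert merged["algorithm"] == "max"  # defaulted because base_cfg carried no "algorithm"

# KV-cache-only case: with no base quant_cfg, everything else stays disabled by default.
kv_only = mtq.update_quant_cfg_with_kv_cache_quant({}, kv_rules)
assert kv_only["quant_cfg"]["default"] == {"enable": False}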