Commit 4efc618

add kv cache quantization for mcore using bmm_quantizers
Signed-off-by: Kai Xu <[email protected]>
1 parent: 08fb23f

File tree

3 files changed, +326 -1 lines changed

  modelopt/torch/quantization/plugins/__init__.py
  modelopt/torch/quantization/plugins/megatron.py
  tests/gpu/torch/quantization/plugins/test_megatron.py
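
For orientation, here is a minimal usage sketch of the feature this commit adds, mirroring the new tests below. It is illustrative only: model, prompt_tokens, and megatron_prefill are assumed to come from an existing Megatron-Core setup (the tests build them via get_mcore_gpt_model and the shared test utilities).

    import modelopt.torch.quantization as mtq

    def forward_fn(model):
        # Any short calibration forward pass works; the new tests run a prefill on dummy tokens.
        return megatron_prefill(model, prompt_tokens)

    # The model must be built with transformer_impl="modelopt" (or "transformer_engine") so that
    # attention uses TEDotProductAttention; a KV cache config then attaches q/k/v_bmm_quantizers
    # to each attention module during quantization.
    model = mtq.quantize(model, mtq.FP8_KV_CFG, forward_fn)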

modelopt/torch/quantization/plugins/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -71,3 +71,6 @@
 
 with import_plugin("trl"):
     from .trl import *
+
+with import_plugin("mcore"):
+    from .mcore import *

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 149 additions & 1 deletion
@@ -26,17 +26,19 @@
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
+from megatron.core.extensions.transformer_engine import TEDotProductAttention
 
 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState
 
-from ..nn import QuantModuleRegistry, TensorQuantizer
+from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
+from ..model_calib import max_calibrate
 
 __all__ = []

@@ -462,3 +464,149 @@ class _RealQuantMegatronRowParallelLinear(
 
     def forward(self, input, *args, **kwargs):
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)
+
+
+@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
+class _QuantTEDotProductAttention(QuantModule):
+    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""
+
+    def _setup(self):
+        """Initialize quantizers for Q, K, V tensors."""
+        self.q_bmm_quantizer = TensorQuantizer()
+        self.k_bmm_quantizer = TensorQuantizer()
+        self.v_bmm_quantizer = TensorQuantizer()
+
+    def _calibrate_quantizers(self):
+        """Calibrate quantizers with minimal dummy tensors."""
+        # Get device from parent module parameters
+        device = next(self.parameters()).device if self.parameters() else torch.device('cuda')
+
+        # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
+        batch_size = 1
+        seq_len = 1
+
+        # Get dimensions from config
+        num_heads = self.config.num_attention_heads
+        head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads
+
+        # Determine tensor format (default to sbhd if not specified)
+        apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
+        qkv_format = "bshd" if apply_rope_fusion else "sbhd"
+
+        if qkv_format == "sbhd":
+            dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
+        else:
+            dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
+
+        # Calibrate each quantizer
+        quantizers = [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]
+
+        for _, quantizer in quantizers:
+            if quantizer is not None and quantizer.is_enabled:
+                if not hasattr(quantizer, "_amax") or quantizer._amax is None:
+                    quantizer.reset_amax()
+                    max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
+
+    def forward(self, query, key, value, *args, **kwargs):
+        """Apply post-RoPE quantization to KV cache.
+
+        TEDotProductAttention receives Q, K, V after RoPE is applied,
+        so we quantize them directly for KV cache quantization.
+        """
+        # Quantize Q, K, V
+        query = self.q_bmm_quantizer(query)
+        key = self.k_bmm_quantizer(key)
+        value = self.v_bmm_quantizer(value)
+
+        return super().forward(query, key, value, *args, **kwargs)
+
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Create a sharded state dictionary for distributed checkpointing."""
+        sharded_state_dict = {}
+
+        # First add non-quantizer parameters
+        for k, v in self.state_dict(prefix="", keep_vars=True).items():
+            if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
+                sharded_state_dict[prefix + k] = v
+
+        # Process _amax in bmm_quantizers
+        for name, quantizer in [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]:
+            if hasattr(quantizer, "_amax") and quantizer._amax is not None:
+                amax_key = f"{prefix}{name}._amax"
+                sharded_state_dict[amax_key] = quantizer._amax
+
+        # Process other quantizer parameters in bmm_quantizers
+        quantizer_state_dict = {}
+        for k, v in self.state_dict(prefix="", keep_vars=True).items():
+            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
+                quantizer_state_dict[k] = v
+
+        if quantizer_state_dict:
+            sharded_state_dict.update(
+                **make_sharded_tensors_for_checkpoint(
+                    quantizer_state_dict, prefix, {}, sharded_offsets
+                )
+            )
+
+        return sharded_state_dict
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        """Handle loading state dict for quantizers."""
+        for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
+            full_prefix = f"{prefix}{quantizer_name}."
+            amax_key = f"{prefix}{quantizer_name}._amax"
+
+            # If amax is in state_dict, rename it to the format expected by TensorQuantizer
+            if amax_key in state_dict:
+                expected_amax_key = f"{full_prefix}_amax"
+                state_dict[expected_amax_key] = state_dict.pop(amax_key)
+
+        # Handle other quantizer states
+        for k in list(state_dict.keys()):
+            if "_quantizer" in k and "_amax" not in k:
+                name = k.split(prefix)[-1] if prefix else k
+                if name in self.state_dict():
+                    state_dict[k] = state_dict[k].view_as(self.state_dict()[name])
+
+        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+    def modelopt_post_restore(self, name=""):
+        """Restore quantizer states after model loading."""
+        super().modelopt_post_restore(name)
+
+        def _check_unsupported_states(quantizer):
+            if not hasattr(quantizer, "state_dict"):
+                return
+
+            for k in quantizer.state_dict().keys():
+                if k not in ["_amax", "_pre_quant_scale"]:
+                    warnings.warn(
+                        f"Restore of {k} for {name} is not supported. The restore of this layer might be "
+                        f"incorrect. Please implement a custom restore for {k}."
+                    )
+
+        calibration_needed = False
+
+        for quantizer_name, quantizer in [
+            ("q_bmm_quantizer", self.q_bmm_quantizer),
+            ("k_bmm_quantizer", self.k_bmm_quantizer),
+            ("v_bmm_quantizer", self.v_bmm_quantizer),
+        ]:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+                continue
+
+            _check_unsupported_states(quantizer)
+
+            if not hasattr(quantizer, "_amax") or quantizer._amax is None:
+                calibration_needed = True
+
+        if calibration_needed:
+            self._calibrate_quantizers()
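
To make the flow above concrete, the following toy sketch (not ModelOpt code; ToyPerTensorQuantizer and toy_kv_quant_attention are invented for illustration) shows the post-RoPE pattern the new forward implements: Q, K, and V are fake-quantized by independent per-tensor quantizers right before the attention computation, so the cached K/V values are the quantized ones.

    # Toy illustration of amax-calibrated per-tensor fake quantization applied to Q/K/V.
    # An int8-style grid is used purely as a stand-in; the real KV cache configs use FP8/NVFP4.
    import torch
    import torch.nn.functional as F

    class ToyPerTensorQuantizer(torch.nn.Module):
        """Stand-in for TensorQuantizer: amax-calibrated per-tensor quantize-dequantize."""

        def __init__(self):
            super().__init__()
            self.amax = None

        def forward(self, x):
            if self.amax is None:  # "calibration": record the largest magnitude seen
                self.amax = x.abs().amax().clamp_min(1e-8)
            scale = 127.0 / self.amax
            return (x * scale).round().clamp(-127, 127) / scale  # quantize-dequantize

    def toy_kv_quant_attention(q, k, v, q_quant, k_quant, v_quant):
        # Same ordering as _QuantTEDotProductAttention.forward: quantize after RoPE,
        # immediately before the attention kernel sees Q, K, V.
        q, k, v = q_quant(q), k_quant(k), v_quant(v)
        return F.scaled_dot_product_attention(q, k, v)

    q = torch.randn(1, 4, 8, 16)  # (batch, heads, seq, head_dim)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = toy_kv_quant_attention(
        q, k, v, ToyPerTensorQuantizer(), ToyPerTensorQuantizer(), ToyPerTensorQuantizer()
    )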

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 174 additions & 0 deletions
@@ -17,6 +17,7 @@
 
 import pytest
 import torch
+import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -230,6 +231,8 @@ def forward_fn(model):
         mtq.W4A8_AWQ_BETA_CFG,
         mtq.NVFP4_DEFAULT_CFG,
         mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
+        # Note: KV cache configs (FP8_KV_CFG, NVFP4_KV_CFG) are tested separately in test_kv_cache_quant
+        # They require TEDotProductAttention which needs transformer_impl="modelopt", not "local"
     ],
 )
 @pytest.mark.parametrize("compress", [False, True])
@@ -361,3 +364,174 @@ def forward_fn(model):
 def test_fp8_real_quantize():
     size = torch.cuda.device_count()
     spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl")
+
+
+def _test_kv_cache_quant_helper(config, rank, size):
+    """Helper function for testing KV cache quantization with TEDotProductAttention."""
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
+
+    # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
+    # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
+    model = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        num_layers=1,
+        hidden_size=64,
+        num_attention_heads=4,
+        vocab_size=32,
+        transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
+    ).cuda()
+
+    # Create dummy input for calibration
+    prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
+
+    def forward_fn(model):
+        return megatron_prefill(model, prompt_tokens)
+
+    # Test KV cache quantization with the given config
+    quantized_model = mtq.quantize(model, config, forward_fn)
+
+    # Find TEDotProductAttention modules and verify they have KV cache quantizers
+    te_attention_found = False
+    for name, module in quantized_model.named_modules():
+        # Check if this is a quantized TEDotProductAttention
+        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+            te_attention_found = True
+            # Verify all expected quantizers exist
+            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
+
+            # Verify K and V quantizers are enabled (main purpose of KV cache configs)
+            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+
+    assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
+
+    # Quick smoke test that forward still works
+    output = forward_fn(quantized_model)
+    assert output is not None, "Forward pass failed"
+
+
+def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
+    """Helper for testing KV cache quantization with sharded state dict save/load."""
+    # Disable output_layer quantization (same as other sharded state dict tests)
+    config["quant_cfg"]["*output_layer*"] = {"enable": False}
+
+    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
+
+    # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
+    model_ref = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        num_layers=2,  # At least 2 layers to test multiple attention modules
+        hidden_size=64,
+        num_attention_heads=4,
+        vocab_size=64,
+        transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
+    ).cuda()
+
+    model_test = get_mcore_gpt_model(
+        tensor_model_parallel_size=size,
+        num_layers=2,
+        hidden_size=64,
+        num_attention_heads=4,
+        vocab_size=64,
+        transformer_impl="modelopt",
+    ).cuda()
+
+    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
+
+    def forward_fn(model):
+        return megatron_prefill(model, prompt_tokens)
+
+    # Quantize the reference model
+    model_ref = mtq.quantize(model_ref, config, forward_fn)
+
+    # CRITICAL: model_test must also be quantized with the same config
+    # Otherwise it won't have the KV cache quantizer keys when loading state dict
+    model_test = mtq.quantize(model_test, config, forward_fn)
+
+    # Verify KV cache quantizers were created
+    kv_quantizers_found = False
+    for name, module in model_ref.named_modules():
+        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+            kv_quantizers_found = True
+            assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
+            assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
+
+    assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
+
+    # Test sharded state dict save/load
+    sharded_state_dict_test_helper(
+        tmp_path,
+        model_ref,
+        model_test,
+        forward_fn,
+        meta_device=False,
+        version=None,
+    )
+
+    # Verify KV cache quantizers are restored correctly in model_test
+    for (name_ref, module_ref), (name_test, module_test) in zip(
+        model_ref.named_modules(), model_test.named_modules()
+    ):
+        if hasattr(module_ref, 'k_bmm_quantizer'):
+            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
+            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
+
+            # Check that quantizer states match
+            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
+                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+                if module_ref.k_bmm_quantizer._amax is not None:
+                    assert torch.allclose(
+                        module_ref.k_bmm_quantizer._amax,
+                        module_test.k_bmm_quantizer._amax
+                    ), f"K quantizer _amax mismatch in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
+                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+                if module_ref.v_bmm_quantizer._amax is not None:
+                    assert torch.allclose(
+                        module_ref.v_bmm_quantizer._amax,
+                        module_test.v_bmm_quantizer._amax
+                    ), f"V quantizer _amax mismatch in {name_test}"
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.FP8_KV_CFG,
+        mtq.NVFP4_KV_CFG,
+    ],
+)
+def test_kv_cache_quant(config):
+    """Verify KV cache quantization works correctly with TEDotProductAttention.
+
+    This test ensures TEDotProductAttention is properly registered and gets the
+    expected q/k/v_bmm_quantizers when using KV cache configs.
+
+    Note: This test requires Transformer Engine to be installed since TEDotProductAttention
+    is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
+    """
+    spawn_multiprocess_job(
+        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
+    )
+
+
+@pytest.mark.parametrize(
+    "config",
+    [
+        mtq.FP8_KV_CFG,
+        mtq.NVFP4_KV_CFG,
+    ],
+)
+def test_kv_cache_sharded_state_dict(tmp_path, config):
+    """Test KV cache quantization with sharded state dict save/load.
+
+    This test verifies the complete workflow of saving and loading KV cache quantized
+    models with distributed checkpointing, ensuring quantizer states are properly
+    preserved across the save/load cycle.
+    """
+    size = min(2, torch.cuda.device_count())  # Use 2 GPUs if available, else 1
+    spawn_multiprocess_job(
+        size=size,
+        job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
+        backend="nccl"
+    )
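
Assuming a GPU environment with Megatron-Core and Transformer Engine installed, the new tests can be selected by keyword, for example: pytest tests/gpu/torch/quantization/plugins/test_megatron.py -k kv_cache (this matches both test_kv_cache_quant and test_kv_cache_sharded_state_dict).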
