
Commit 1f77518

add kv cache quantization for mcore using bmm_quantizers
Signed-off-by: Kai Xu <[email protected]>
1 parent 4efc618 commit 1f77518
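
The tests below exercise the new path by calling `mtq.quantize` on a Megatron-Core GPT model with a KV-cache config and a prefill forward function. A minimal usage sketch along those lines; the config symbol `FP8_KV_CFG` and the `quantize_kv_cache` helper are assumptions for illustration, not part of this commit:

```python
import modelopt.torch.quantization as mtq


def quantize_kv_cache(model, calib_batches, kv_cache_cfg=None):
    """Hypothetical helper: enable post-RoPE KV cache quantization on an mcore model."""
    config = kv_cache_cfg or mtq.FP8_KV_CFG  # assumed config name; use your KV-cache config

    def forward_fn(m):
        # Run a few calibration batches so the new q/k/v_bmm_quantizers can
        # record amax statistics during max calibration. `calib_batches` is
        # assumed to yield inputs the model accepts directly.
        for batch in calib_batches:
            m(batch)

    # After this call, TEDotProductAttention modules carry q/k/v_bmm_quantizer
    # attributes, as the tests in this commit verify.
    return mtq.quantize(model, config, forward_fn)
```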

File tree

3 files changed (+98 -71 lines changed):

modelopt/torch/quantization/plugins/__init__.py
modelopt/torch/quantization/plugins/megatron.py
tests/gpu/torch/quantization/plugins/test_megatron.py

modelopt/torch/quantization/plugins/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -73,4 +73,4 @@
     from .trl import *
 
 with import_plugin("mcore"):
-    from .mcore import *
+    from .mcore import *

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 39 additions & 21 deletions

@@ -22,23 +22,23 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+from megatron.core.extensions.transformer_engine import TEDotProductAttention
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
-from megatron.core.extensions.transformer_engine import TEDotProductAttention
 
 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState
 
+from ..model_calib import max_calibrate
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
-from ..model_calib import max_calibrate
 
 __all__ = []
 
@@ -468,7 +468,13 @@ def forward(self, input, *args, **kwargs):
 
 @QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
 class _QuantTEDotProductAttention(QuantModule):
-    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""
+    """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization.
+
+    This class adds KV cache quantization support to Transformer Engine's TEDotProductAttention
+    module used in Megatron-Core models. It introduces three quantizers (q_bmm_quantizer,
+    k_bmm_quantizer, v_bmm_quantizer) that quantize the query, key, and value tensors after
+    RoPE has been applied.
+    """
 
     def _setup(self):
         """Initialize quantizers for Q, K, V tensors."""
@@ -478,25 +484,35 @@ def _setup(self):
 
     def _calibrate_quantizers(self):
         """Calibrate quantizers with minimal dummy tensors."""
-        # Get device from parent module parameters
-        device = next(self.parameters()).device if self.parameters() else torch.device('cuda')
-
+        # Get device and dtype from the parent module's parameters
+        param = next(iter(self.parameters()), None)
+        device = param.device if param is not None else torch.device("cuda")
+        dtype = param.dtype if param is not None else torch.float16
+
         # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
         batch_size = 1
         seq_len = 1
-
+
         # Get dimensions from config
         num_heads = self.config.num_attention_heads
-        head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads
-
+        head_dim = (
+            self.config.kv_channels
+            if hasattr(self.config, "kv_channels")
+            else self.config.hidden_size // num_heads
+        )
+
         # Determine tensor format (default to sbhd if not specified)
-        apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
+        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
         qkv_format = "bshd" if apply_rope_fusion else "sbhd"
 
         if qkv_format == "sbhd":
-            dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                seq_len, batch_size, num_heads, head_dim, device=device, dtype=dtype
+            )
         else:
-            dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                batch_size, seq_len, num_heads, head_dim, device=device, dtype=dtype
+            )
 
         # Calibrate each quantizer
         quantizers = [
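
The hunk above builds the dummy calibration tensor in either "sbhd" or "bshd" layout depending on `apply_rope_fusion`. A quick standalone illustration of the two layouts; the head and dim values below are arbitrary examples, not taken from the commit:

```python
import torch

# The diff uses batch_size = seq_len = 1 and reads num_heads / head_dim from the
# Megatron config; concrete values here are for illustration only.
batch_size, seq_len, num_heads, head_dim = 1, 1, 8, 64

# apply_rope_fusion=False -> "sbhd": (seq, batch, heads, head_dim)
sbhd = torch.randn(seq_len, batch_size, num_heads, head_dim)
# apply_rope_fusion=True  -> "bshd": (batch, seq, heads, head_dim)
bshd = torch.randn(batch_size, seq_len, num_heads, head_dim)

print(sbhd.shape, bshd.shape)  # same sizes here, but the axis meaning differs
```
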
@@ -506,14 +522,14 @@ def _calibrate_quantizers(self):
         ]
 
         for _, quantizer in quantizers:
-            if quantizer is not None and quantizer.is_enabled:
+            if quantizer is not None and quantizer.is_enabled():
                 if not hasattr(quantizer, "_amax") or quantizer._amax is None:
                     quantizer.reset_amax()
                 max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)
 
     def forward(self, query, key, value, *args, **kwargs):
         """Apply post-RoPE quantization to KV cache.
-
+
         TEDotProductAttention receives Q, K, V after RoPE is applied,
         so we quantize them directly for KV cache quantization.
         """
@@ -527,7 +543,7 @@
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Create a sharded state dictionary for distributed checkpointing."""
         sharded_state_dict = {}
-
+
         # First add non-quantizer parameters
         for k, v in self.state_dict(prefix="", keep_vars=True).items():
             if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
@@ -544,10 +560,11 @@
                 sharded_state_dict[amax_key] = quantizer._amax
 
         # Process other quantizer parameters in bmm_quantizers
-        quantizer_state_dict = {}
-        for k, v in self.state_dict(prefix="", keep_vars=True).items():
-            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
-                quantizer_state_dict[k] = v
+        quantizer_state_dict = {
+            k: v
+            for k, v in self.state_dict(prefix="", keep_vars=True).items()
+            if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
+        }
 
         if quantizer_state_dict:
             sharded_state_dict.update(
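
The comprehension added above keeps only quantizer tensors other than `_amax` (amax entries are turned into sharded tensors separately, just before). A toy run with made-up keys showing what survives the filter:

```python
import torch

state = {
    "k_bmm_quantizer._amax": torch.tensor(1.0),              # handled separately above
    "k_bmm_quantizer._pre_quant_scale": torch.tensor(0.5),   # hypothetical extra entry
    "some_other.weight": torch.randn(2, 2),                  # not a quantizer tensor
}

quantizer_state_dict = {
    k: v
    for k, v in state.items()
    if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
}

assert list(quantizer_state_dict) == ["k_bmm_quantizer._pre_quant_scale"]
```
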
@@ -583,10 +600,11 @@ def modelopt_post_restore(self, name=""):
         super().modelopt_post_restore(name)
 
         def _check_unsupported_states(quantizer):
+            """Check for unsupported quantizer states and warn if found."""
            if not hasattr(quantizer, "state_dict"):
                return
 
-            for k in quantizer.state_dict().keys():
+            for k in quantizer.state_dict():
                if k not in ["_amax", "_pre_quant_scale"]:
                    warnings.warn(
                        f"Restore of {k} for {name} is not supported. The restore of this layer might be "
@@ -600,7 +618,7 @@ def _check_unsupported_states(quantizer):
             ("k_bmm_quantizer", self.k_bmm_quantizer),
             ("v_bmm_quantizer", self.v_bmm_quantizer),
         ]:
-            if not hasattr(self, quantizer_name) or not quantizer.is_enabled:
+            if not hasattr(self, quantizer_name) or not quantizer.is_enabled():
                 continue
 
             _check_unsupported_states(quantizer)

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 58 additions & 49 deletions

@@ -17,7 +17,6 @@
 
 import pytest
 import torch
-import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -368,8 +367,10 @@ def test_fp8_real_quantize():
 
 def _test_kv_cache_quant_helper(config, rank, size):
     """Helper function for testing KV cache quantization with TEDotProductAttention."""
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
     # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
     model = get_mcore_gpt_model(
@@ -380,43 +381,45 @@
         vocab_size=32,
         transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
     ).cuda()
-
+
     # Create dummy input for calibration
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Test KV cache quantization with the given config
     quantized_model = mtq.quantize(model, config, forward_fn)
-
+
     # Find TEDotProductAttention modules and verify they have KV cache quantizers
     te_attention_found = False
     for name, module in quantized_model.named_modules():
         # Check if this is a quantized TEDotProductAttention
-        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
             te_attention_found = True
             # Verify all expected quantizers exist
-            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
-
+            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
+
             # Verify K and V quantizers are enabled (main purpose of KV cache configs)
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
-
+
     # Quick smoke test that forward still works
     output = forward_fn(quantized_model)
     assert output is not None, "Forward pass failed"
-
+
 
 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
-
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
     model_ref = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
@@ -426,7 +429,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
     ).cuda()
-
+
     model_test = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -435,29 +438,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",
     ).cuda()
-
-    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
-
+
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-
+
     # CRITICAL: model_test must also be quantized with the same config
     # Otherwise it won't have the KV cache quantizer keys when loading state dict
     model_test = mtq.quantize(model_test, config, forward_fn)
-
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
-        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             kv_quantizers_found = True
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
-
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
@@ -467,32 +472,38 @@ def forward_fn(model):
         meta_device=False,
         version=None,
     )
-
+
     # Verify KV cache quantizers are restored correctly in model_test
     for (name_ref, module_ref), (name_test, module_test) in zip(
         model_ref.named_modules(), model_test.named_modules()
     ):
-        if hasattr(module_ref, 'k_bmm_quantizer'):
-            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
-            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
-
+        if hasattr(module_ref, "k_bmm_quantizer"):
+            assert hasattr(module_test, "k_bmm_quantizer"), (
+                f"K quantizer missing after restore in {name_test}"
+            )
+            assert hasattr(module_test, "v_bmm_quantizer"), (
+                f"V quantizer missing after restore in {name_test}"
+            )
+
             # Check that quantizer states match
-            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
+                    f"K quantizer _amax missing in {name_test}"
+                )
                 if module_ref.k_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.k_bmm_quantizer._amax,
-                        module_test.k_bmm_quantizer._amax
+                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                     ), f"K quantizer _amax mismatch in {name_test}"
-
-            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
+                    f"V quantizer _amax missing in {name_test}"
+                )
                 if module_ref.v_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.v_bmm_quantizer._amax,
-                        module_test.v_bmm_quantizer._amax
+                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                     ), f"V quantizer _amax mismatch in {name_test}"
-
+
 
 @pytest.mark.parametrize(
     "config",
@@ -503,16 +514,14 @@ def forward_fn(model):
 )
 def test_kv_cache_quant(config):
     """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
+
+    This test ensures TEDotProductAttention is properly registered and gets the
     expected q/k/v_bmm_quantizers when using KV cache configs.
-
+
     Note: This test requires Transformer Engine to be installed since TEDotProductAttention
     is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
     """
-    spawn_multiprocess_job(
-        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")
 
 
 @pytest.mark.parametrize(
@@ -524,7 +533,7 @@ def test_kv_cache_quant(config):
 )
 def test_kv_cache_sharded_state_dict(tmp_path, config):
     """Test KV cache quantization with sharded state dict save/load.
-
+
     This test verifies the complete workflow of saving and loading KV cache quantized
     models with distributed checkpointing, ensuring quantizer states are properly
     preserved across the save/load cycle.
@@ -533,5 +542,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config):
     spawn_multiprocess_job(
         size=size,
         job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
-        backend="nccl"
+        backend="nccl",
     )
