Commit 93bab27

add kv cache quantization for mcore using bmm_quantizers

Signed-off-by: Kai Xu <[email protected]>
1 parent aeff9c8
File tree: 3 files changed, +85 -67 lines
modelopt/torch/quantization/plugins/__init__.py
1 addition & 1 deletion
@@ -73,4 +73,4 @@
     from .trl import *

 with import_plugin("mcore"):
-    from .mcore import *
+    from .mcore import *

modelopt/torch/quantization/plugins/megatron.py
26 additions & 17 deletions
@@ -22,23 +22,23 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import torch
+from megatron.core.extensions.transformer_engine import TEDotProductAttention
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
-from megatron.core.extensions.transformer_engine import TEDotProductAttention

 from modelopt.torch.opt.plugins.megatron import (
     _MegatronMLP,
     register_modelopt_extra_state_callbacks,
 )
 from modelopt.torch.utils.distributed import ParallelState

+from ..model_calib import max_calibrate
 from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
-from ..model_calib import max_calibrate

 __all__ = []
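The hunk above only reorders imports; the class that consumes them sits outside the shown context. As orientation only, here is a minimal sketch of how a TEDotProductAttention quantization wrapper is typically registered in this plugin, assuming the usual modelopt QuantModuleRegistry decorator pattern; the class name _QuantTEDotProductAttention is hypothetical.

from megatron.core.extensions.transformer_engine import TEDotProductAttention

from modelopt.torch.quantization.nn import QuantModule, QuantModuleRegistry, TensorQuantizer


# Hypothetical sketch, not code from this commit: register a wrapper so that
# converting a Megatron model replaces TEDotProductAttention with a quantized
# variant carrying one TensorQuantizer per attention operand.
@QuantModuleRegistry.register({TEDotProductAttention: "TEDotProductAttention"})
class _QuantTEDotProductAttention(QuantModule):
    def _setup(self):
        # KV cache configs enable the k/v quantizers; q is available as well.
        self.q_bmm_quantizer = TensorQuantizer()
        self.k_bmm_quantizer = TensorQuantizer()
        self.v_bmm_quantizer = TensorQuantizer()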

@@ -477,24 +477,32 @@ def _setup(self):
     def _calibrate_quantizers(self):
         """Calibrate quantizers with minimal dummy tensors."""
         # Get device from parent module parameters
-        device = next(self.parameters()).device if self.parameters() else torch.device('cuda')
-
+        device = next(self.parameters()).device if self.parameters() else torch.device("cuda")
+
         # TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
         batch_size = 1
         seq_len = 1
-
+
         # Get dimensions from config
         num_heads = self.config.num_attention_heads
-        head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads
-
+        head_dim = (
+            self.config.kv_channels
+            if hasattr(self.config, "kv_channels")
+            else self.config.hidden_size // num_heads
+        )
+
         # Determine tensor format (default to sbhd if not specified)
-        apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
+        apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
         qkv_format = "bshd" if apply_rope_fusion else "sbhd"

         if qkv_format == "sbhd":
-            dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
+            )
         else:
-            dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
+            dummy_tensor = torch.randn(
+                batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
+            )

         # Calibrate each quantizer
         quantizers = [
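The hunk is cut off at `quantizers = [`; the loop body lies outside the shown context. Below is a hedged sketch of how the dummy tensor could drive calibration, assuming max_calibrate (imported above) accepts a module plus a forward loop and a distributed_sync flag; this is a guess, not the commit's verbatim code.

# Hypothetical continuation of the truncated method above: run each enabled
# bmm quantizer over the dummy tensor once so an amax estimate exists before
# real activations arrive.
quantizers = [
    ("q_bmm_quantizer", self.q_bmm_quantizer),
    ("k_bmm_quantizer", self.k_bmm_quantizer),
    ("v_bmm_quantizer", self.v_bmm_quantizer),
]
for _name, quantizer in quantizers:
    if quantizer.is_enabled:
        # Treat the quantizer itself as a tiny "model": the forward loop just
        # pushes the dummy tensor through it while max calibration is active.
        max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)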
@@ -511,7 +519,7 @@ def _calibrate_quantizers(self):

     def forward(self, query, key, value, *args, **kwargs):
         """Apply post-RoPE quantization to KV cache.
-
+
         TEDotProductAttention receives Q, K, V after RoPE is applied,
         so we quantize them directly for KV cache quantization.
         """
@@ -525,7 +533,7 @@ def forward(self, query, key, value, *args, **kwargs):
     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Create a sharded state dictionary for distributed checkpointing."""
         sharded_state_dict = {}
-
+
         # First add non-quantizer parameters
         for k, v in self.state_dict(prefix="", keep_vars=True).items():
             if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
@@ -542,10 +550,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
542550
sharded_state_dict[amax_key] = quantizer._amax
543551

544552
# Process other quantizer parameters in bmm_quantizers
545-
quantizer_state_dict = {}
546-
for k, v in self.state_dict(prefix="", keep_vars=True).items():
547-
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
548-
quantizer_state_dict[k] = v
553+
quantizer_state_dict = {
554+
k: v
555+
for k, v in self.state_dict(prefix="", keep_vars=True).items()
556+
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
557+
}
549558

550559
if quantizer_state_dict:
551560
sharded_state_dict.update(
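The `sharded_state_dict.update(` call is cut off by the diff context. A plausible continuation, assuming the collected quantizer tensors are wrapped with megatron-core's make_sharded_tensors_for_checkpoint (imported above); the exact arguments are an assumption, not the commit's verbatim code.

# Hypothetical continuation of the truncated call: wrap the non-amax quantizer
# tensors so distributed checkpointing saves and restores them like any other
# module state.
sharded_state_dict.update(
    make_sharded_tensors_for_checkpoint(
        quantizer_state_dict, prefix, sharded_offsets=sharded_offsets
    )
)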
@@ -584,7 +593,7 @@ def _check_unsupported_states(quantizer):
     if not hasattr(quantizer, "state_dict"):
         return

-    for k in quantizer.state_dict().keys():
+    for k in quantizer.state_dict():
         if k not in ["_amax", "_pre_quant_scale"]:
             warnings.warn(
                 f"Restore of {k} for {name} is not supported. The restore of this layer might be "

tests/gpu/torch/quantization/plugins/test_megatron.py
58 additions & 49 deletions
@@ -17,7 +17,6 @@

 import pytest
 import torch
-import torch.nn as nn
 from _test_utils.import_helper import skip_if_no_megatron
 from _test_utils.torch_dist.dist_utils import spawn_multiprocess_job
 from _test_utils.torch_dist.plugins.megatron_common import (
@@ -374,8 +373,10 @@ def test_fp8_real_quantize():

 def _test_kv_cache_quant_helper(config, rank, size):
     """Helper function for testing KV cache quantization with TEDotProductAttention."""
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Use existing infrastructure to create a minimal GPT model with TEDotProductAttention
     # Note: transformer_impl must be "modelopt" or "transformer_engine" (not "local") to get TEDotProductAttention
     model = get_mcore_gpt_model(
@@ -386,43 +387,45 @@ def _test_kv_cache_quant_helper(config, rank, size):
         vocab_size=32,
         transformer_impl="modelopt",  # This uses TEDotProductAttention via get_gpt_modelopt_spec
     ).cuda()
-
+
     # Create dummy input for calibration
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()
-
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Test KV cache quantization with the given config
     quantized_model = mtq.quantize(model, config, forward_fn)
-
+
     # Find TEDotProductAttention modules and verify they have KV cache quantizers
     te_attention_found = False
     for name, module in quantized_model.named_modules():
         # Check if this is a quantized TEDotProductAttention
-        if hasattr(module, 'q_bmm_quantizer') and hasattr(module, 'k_bmm_quantizer'):
+        if hasattr(module, "q_bmm_quantizer") and hasattr(module, "k_bmm_quantizer"):
             te_attention_found = True
             # Verify all expected quantizers exist
-            assert hasattr(module, 'v_bmm_quantizer'), f"Missing v_bmm_quantizer in {name}"
-
+            assert hasattr(module, "v_bmm_quantizer"), f"Missing v_bmm_quantizer in {name}"
+
             # Verify K and V quantizers are enabled (main purpose of KV cache configs)
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert te_attention_found, "No TEDotProductAttention with KV cache quantizers found in model"
-
+
     # Quick smoke test that forward still works
     output = forward_fn(quantized_model)
     assert output is not None, "Forward pass failed"
-
+

 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     """Helper for testing KV cache quantization with sharded state dict save/load."""
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}
-
-    initialize_for_megatron(tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED)
-
+
+    initialize_for_megatron(
+        tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
+    )
+
     # Create GPT models with TEDotProductAttention (transformer_impl="modelopt")
     model_ref = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
@@ -432,7 +435,7 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",  # CRITICAL: Use TEDotProductAttention
     ).cuda()
-
+
     model_test = get_mcore_gpt_model(
         tensor_model_parallel_size=size,
         num_layers=2,
@@ -441,29 +444,31 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
         vocab_size=64,
         transformer_impl="modelopt",
     ).cuda()
-
-    prompt_tokens = torch.randint(0, model_ref.vocab_size, (2, model_ref.max_sequence_length)).cuda()
-
+
+    prompt_tokens = torch.randint(
+        0, model_ref.vocab_size, (2, model_ref.max_sequence_length)
+    ).cuda()
+
     def forward_fn(model):
         return megatron_prefill(model, prompt_tokens)
-
+
     # Quantize the reference model
     model_ref = mtq.quantize(model_ref, config, forward_fn)
-
+
     # CRITICAL: model_test must also be quantized with the same config
     # Otherwise it won't have the KV cache quantizer keys when loading state dict
     model_test = mtq.quantize(model_test, config, forward_fn)
-
+
     # Verify KV cache quantizers were created
     kv_quantizers_found = False
     for name, module in model_ref.named_modules():
-        if hasattr(module, 'k_bmm_quantizer') and hasattr(module, 'v_bmm_quantizer'):
+        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             kv_quantizers_found = True
             assert module.k_bmm_quantizer.is_enabled, f"K quantizer not enabled in {name}"
             assert module.v_bmm_quantizer.is_enabled, f"V quantizer not enabled in {name}"
-
+
     assert kv_quantizers_found, "No KV cache quantizers found in quantized model"
-
+
     # Test sharded state dict save/load
     sharded_state_dict_test_helper(
         tmp_path,
@@ -473,32 +478,38 @@ def forward_fn(model):
         meta_device=False,
         version=None,
     )
-
+
     # Verify KV cache quantizers are restored correctly in model_test
     for (name_ref, module_ref), (name_test, module_test) in zip(
         model_ref.named_modules(), model_test.named_modules()
     ):
-        if hasattr(module_ref, 'k_bmm_quantizer'):
-            assert hasattr(module_test, 'k_bmm_quantizer'), f"K quantizer missing after restore in {name_test}"
-            assert hasattr(module_test, 'v_bmm_quantizer'), f"V quantizer missing after restore in {name_test}"
-
+        if hasattr(module_ref, "k_bmm_quantizer"):
+            assert hasattr(module_test, "k_bmm_quantizer"), (
+                f"K quantizer missing after restore in {name_test}"
+            )
+            assert hasattr(module_test, "v_bmm_quantizer"), (
+                f"V quantizer missing after restore in {name_test}"
+            )
+
             # Check that quantizer states match
-            if hasattr(module_ref.k_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.k_bmm_quantizer, '_amax'), f"K quantizer _amax missing in {name_test}"
+            if hasattr(module_ref.k_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.k_bmm_quantizer, "_amax"), (
+                    f"K quantizer _amax missing in {name_test}"
+                )
                 if module_ref.k_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.k_bmm_quantizer._amax,
-                        module_test.k_bmm_quantizer._amax
+                        module_ref.k_bmm_quantizer._amax, module_test.k_bmm_quantizer._amax
                     ), f"K quantizer _amax mismatch in {name_test}"
-
-            if hasattr(module_ref.v_bmm_quantizer, '_amax'):
-                assert hasattr(module_test.v_bmm_quantizer, '_amax'), f"V quantizer _amax missing in {name_test}"
+
+            if hasattr(module_ref.v_bmm_quantizer, "_amax"):
+                assert hasattr(module_test.v_bmm_quantizer, "_amax"), (
+                    f"V quantizer _amax missing in {name_test}"
+                )
                 if module_ref.v_bmm_quantizer._amax is not None:
                     assert torch.allclose(
-                        module_ref.v_bmm_quantizer._amax,
-                        module_test.v_bmm_quantizer._amax
+                        module_ref.v_bmm_quantizer._amax, module_test.v_bmm_quantizer._amax
                     ), f"V quantizer _amax mismatch in {name_test}"
-
+

 @pytest.mark.parametrize(
     "config",
@@ -509,16 +520,14 @@ def forward_fn(model):
 )
 def test_kv_cache_quant(config):
     """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
+
+    This test ensures TEDotProductAttention is properly registered and gets the
     expected q/k/v_bmm_quantizers when using KV cache configs.
-
+
     Note: This test requires Transformer Engine to be installed since TEDotProductAttention
     is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
     """
-    spawn_multiprocess_job(
-        size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")


 @pytest.mark.parametrize(
@@ -530,7 +539,7 @@ def test_kv_cache_quant(config):
 )
 def test_kv_cache_sharded_state_dict(tmp_path, config):
     """Test KV cache quantization with sharded state dict save/load.
-
+
     This test verifies the complete workflow of saving and loading KV cache quantized
     models with distributed checkpointing, ensuring quantizer states are properly
     preserved across the save/load cycle.
@@ -539,5 +548,5 @@ def test_kv_cache_sharded_state_dict(tmp_path, config):
     spawn_multiprocess_job(
         size=size,
         job=partial(_test_kv_cache_sharded_state_dict_helper, tmp_path, config),
-        backend="nccl"
+        backend="nccl",
     )
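The parametrized `config` values for both tests sit outside the shown hunks. For illustration only, a KV-cache-style modelopt quant_cfg could look roughly like the dictionary below and would be passed to mtq.quantize exactly like the helpers above; the name and exact fields are assumptions, not the configs the tests actually use.

# Hypothetical example config, for illustration only.  The "*_bmm_quantizer"
# wildcards target the k/v quantizers this commit attaches to
# TEDotProductAttention; num_bits=(4, 3) selects FP8 E4M3 in modelopt's
# convention, and everything else is left unquantized.
EXAMPLE_KV_CACHE_CFG = {
    "quant_cfg": {
        "*k_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
        "*v_bmm_quantizer": {"num_bits": (4, 3), "axis": None, "enable": True},
        "default": {"enable": False},
    },
    "algorithm": "max",
}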
