
Commit 9784508

📝 Add docstrings to kaix/kvcache_mcore
Docstrings generation was requested by @kaix-nv.

* #375 (comment)

The following files were modified:

* `modelopt/torch/quantization/plugins/megatron.py`
* `tests/gpu/torch/quantization/plugins/test_megatron.py`
1 parent 93bab27 commit 9784508

2 files changed: +120 additions, -23 deletions

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 68 additions & 9 deletions
@@ -461,6 +461,14 @@ class _RealQuantMegatronRowParallelLinear(
     _scale_tensor_shard_axis = 1

     def forward(self, input, *args, **kwargs):
+        """
+        Compute the forward pass using the row-parallel linear implementation.
+
+        Forwards all positional and keyword arguments to the row-parallel parent implementation.
+
+        Returns:
+            torch.Tensor: The output activations produced by the linear layer.
+        """
         return _MegatronRowParallelLinear.forward(self, input, *args, **kwargs)


@@ -469,13 +477,19 @@ class _QuantTEDotProductAttention(QuantModule):
     """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""

     def _setup(self):
-        """Initialize quantizers for Q, K, V tensors."""
+        """
+        Create and attach three TensorQuantizer instances as q_bmm_quantizer, k_bmm_quantizer, and v_bmm_quantizer for quantizing query, key, and value tensors.
+        """
         self.q_bmm_quantizer = TensorQuantizer()
         self.k_bmm_quantizer = TensorQuantizer()
         self.v_bmm_quantizer = TensorQuantizer()

     def _calibrate_quantizers(self):
-        """Calibrate quantizers with minimal dummy tensors."""
+        """
+        Calibrate the module's Q/K/V tensor quantizers using minimal dummy inputs.
+
+        Creates a tiny float16 dummy tensor shaped according to the attention QKV layout (either "sbhd" or "bshd", determined from self.config.apply_rope_fusion) and uses it to compute and store `_amax` values for any enabled q_bmm_quantizer, k_bmm_quantizer, or v_bmm_quantizer that does not yet have an `_amax`. Calibration is performed only for quantizers that are enabled and lack existing scale information.
+        """
         # Get device from parent module parameters
         device = next(self.parameters()).device if self.parameters() else torch.device("cuda")

@@ -518,10 +532,16 @@ def _calibrate_quantizers(self):
             max_calibrate(quantizer, lambda q: q(dummy_tensor), distributed_sync=False)

     def forward(self, query, key, value, *args, **kwargs):
-        """Apply post-RoPE quantization to KV cache.
-
-        TEDotProductAttention receives Q, K, V after RoPE is applied,
-        so we quantize them directly for KV cache quantization.
+        """
+        Quantize the provided query, key, and value tensors for the KV cache and forward them to the base attention implementation.
+
+        Parameters:
+            query (Tensor): Query tensor (already rotated by RoPE) to be quantized and used for attention.
+            key (Tensor): Key tensor (already rotated by RoPE) to be quantized and used for attention.
+            value (Tensor): Value tensor to be quantized and used for attention.
+
+        Returns:
+            The output of the parent attention `forward` called with the quantized query, key, and value.
         """
         # Quantize Q, K, V
         query = self.q_bmm_quantizer(query)
@@ -531,7 +551,20 @@ def forward(self, query, key, value, *args, **kwargs):
         return super().forward(query, key, value, *args, **kwargs)

     def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
-        """Create a sharded state dictionary for distributed checkpointing."""
+        """
+        Builds a sharded state dictionary containing non-quantizer parameters and bmm-quantizer state for distributed checkpointing.
+
+        Parameters:
+            prefix (str): Key prefix to prepend to returned state keys.
+            sharded_offsets (tuple): Offsets describing shard positions for sharded tensors (passed to make_sharded_tensors_for_checkpoint).
+            metadata: Ignored by this implementation (kept for API compatibility).
+
+        Returns:
+            state_dict (dict): Mapping from checkpoint keys to tensors, including:
+                - Non-quantizer module tensors (prefixed).
+                - Per-quantizer `_amax` entries for q/k/v bmm quantizers when present.
+                - Other quantizer tensors processed into sharded tensors via the checkpoint helper.
+        """
         sharded_state_dict = {}

         # First add non-quantizer parameters
@@ -566,7 +599,18 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         return sharded_state_dict

     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
-        """Handle loading state dict for quantizers."""
+        """
+        Adjust quantizer entries in a loaded state dict to match this module's expected keys and tensor shapes before delegating to the parent loader.
+
+        This method:
+        - Renames per-quantizer `_amax` entries (stored as `{prefix}{quantizer_name}._amax`) to the key format expected by this module's local TensorQuantizer instances.
+        - Reshapes any remaining quantizer state tensors (keys containing `_quantizer` but not `_amax`) to match the corresponding tensor shapes in this module's `state_dict`.
+        - Calls the superclass `_load_from_state_dict` with the adjusted `state_dict`.
+
+        Parameters:
+            state_dict (dict): The incoming state dictionary being loaded; modified in-place to align quantizer keys and shapes.
+            prefix (str): The prefix applied to keys for this module in `state_dict`.
+        """
         for quantizer_name in ["q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"]:
             full_prefix = f"{prefix}{quantizer_name}."
             amax_key = f"{prefix}{quantizer_name}._amax"
@@ -586,10 +630,25 @@ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
         super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

     def modelopt_post_restore(self, name=""):
-        """Restore quantizer states after model loading."""
+        """
+        Perform post-restore validation for attention quantizers and trigger calibration if needed.
+
+        Checks each of the instance's Q/K/V BMM quantizers (if present and enabled) for unsupported saved state keys and emits a warning identifying the provided `name` when such keys are found. If any enabled quantizer lacks a stored `_amax` value, schedules and runs quantizer calibration by calling self._calibrate_quantizers().
+
+        Parameters:
+            name (str): Human-readable identifier for the module being restored; included in warning messages to help locate the layer.
+        """
         super().modelopt_post_restore(name)

         def _check_unsupported_states(quantizer):
+            """
+            Check a quantizer's saved state keys and warn about any unsupported entries.
+
+            Inspects quantizer.state_dict() (if present) and emits a warning for each key other than `_amax` and `_pre_quant_scale` indicating that restoring that key is not supported.
+
+            Parameters:
+                quantizer: An object with a `state_dict()` method (typically a TensorQuantizer) whose saved state keys will be validated.
+            """
             if not hasattr(quantizer, "state_dict"):
                 return

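The docstrings above describe a simple pattern: `_setup` attaches q/k/v TensorQuantizers, and `forward` quantizes the post-RoPE Q, K, and V tensors before handing them to the parent attention. The following minimal sketch illustrates that pattern outside of Megatron; `SimpleAmaxQuantizer` and `KVCacheQuantAttention` are hypothetical stand-ins, not modelopt classes, and the amax-based fake quantization is only an assumption for illustration.

import torch


class SimpleAmaxQuantizer:
    """Hypothetical stand-in for TensorQuantizer: records an amax and fake-quantizes to an int8 range."""

    def __init__(self):
        self.amax = None          # analogous to TensorQuantizer's _amax buffer
        self.enabled = True

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return x
        if self.amax is None:     # "calibration": remember the largest magnitude seen
            self.amax = x.abs().max()
        scale = self.amax.clamp_min(1e-8) / 127.0
        return (x / scale).round().clamp(-127, 127) * scale


class KVCacheQuantAttention(torch.nn.Module):
    """Sketch of the post-RoPE KV-cache quantization pattern described in the docstrings above."""

    def __init__(self, inner_attention):
        super().__init__()
        self.inner_attention = inner_attention
        self.q_bmm_quantizer = SimpleAmaxQuantizer()
        self.k_bmm_quantizer = SimpleAmaxQuantizer()
        self.v_bmm_quantizer = SimpleAmaxQuantizer()

    def forward(self, query, key, value, *args, **kwargs):
        # Q/K/V arrive here with RoPE already applied, so quantizing at this point
        # quantizes exactly what would be written into the KV cache.
        query = self.q_bmm_quantizer(query)
        key = self.k_bmm_quantizer(key)
        value = self.v_bmm_quantizer(value)
        return self.inner_attention(query, key, value, *args, **kwargs)

Wrapping any attention callable, for example a plain scaled-dot-product function, is enough to exercise this wrapper and observe the recorded amax values.
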
tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 52 additions & 14 deletions
@@ -367,12 +367,26 @@ def forward_fn(model):


 def test_fp8_real_quantize():
+    """
+    Runs the FP8 real quantization memory reduction test across all available CUDA devices.
+
+    Spawns a multiprocess job (NCCL backend) that executes _test_fp8_real_quantize_helper on each detected GPU.
+    """
     size = torch.cuda.device_count()
     spawn_multiprocess_job(size=size, job=_test_fp8_real_quantize_helper, backend="nccl")


 def _test_kv_cache_quant_helper(config, rank, size):
-    """Helper function for testing KV cache quantization with TEDotProductAttention."""
+    """
+    Verify that TEDotProductAttention modules receive KV-cache quantization and remain functional after quantization.
+
+    Quantizes a minimal GPT model (built with transformer_impl="modelopt") using the provided `config`, checks that each TEDotProductAttention-like module exposes `k_bmm_quantizer` and `v_bmm_quantizer`, asserts those quantizers are enabled, and performs a smoke forward pass to ensure the quantized model runs.
+
+    Parameters:
+        config: Quantization configuration used to quantize the model (e.g., FP8_KV_CFG or NVFP4_KV_CFG).
+        rank (int): Process rank in the distributed test invocation.
+        size (int): Tensor model parallel size used to construct the model.
+    """
     initialize_for_megatron(
         tensor_model_parallel_size=size, pipeline_model_parallel_size=1, seed=SEED
     )
@@ -392,6 +406,15 @@ def _test_kv_cache_quant_helper(config, rank, size):
     prompt_tokens = torch.randint(0, model.vocab_size, (2, model.max_sequence_length)).cuda()

     def forward_fn(model):
+        """
+        Run megatron_prefill with the predefined `prompt_tokens` on the given model.
+
+        Parameters:
+            model: The Megatron model to execute the prefill pass on.
+
+        Returns:
+            The outputs produced by `megatron_prefill` for the provided model.
+        """
         return megatron_prefill(model, prompt_tokens)

     # Test KV cache quantization with the given config
@@ -418,7 +441,17 @@ def forward_fn(model):


 def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
-    """Helper for testing KV cache quantization with sharded state dict save/load."""
+    """
+    Validate that KV-cache quantizers on TEDotProductAttention modules are created and correctly preserved across sharded state_dict save/load.
+
+    This helper initializes Megatron, constructs two small GPT models using transformer_impl="modelopt" (TEDotProductAttention), quantizes both models with the provided config (KV quantizers must exist on both before checkpointing), runs a sharded save/load roundtrip, and asserts that k_bmm_quantizer and v_bmm_quantizer instances are present and that their internal `_amax` tensors (when present) match between the reference and restored models.
+
+    Parameters:
+        tmp_path (pathlib.Path): Temporary directory to write sharded checkpoints.
+        config (dict): Quantization configuration dictionary to use for mtq.quantize.
+        rank (int): Distributed process rank for this helper.
+        size (int): Tensor-model-parallel size / world size used to initialize Megatron.
+    """
     # Disable output_layer quantization (same as other sharded state dict tests)
     config["quant_cfg"]["*output_layer*"] = {"enable": False}

@@ -450,6 +483,15 @@ def _test_kv_cache_sharded_state_dict_helper(tmp_path, config, rank, size):
     ).cuda()

     def forward_fn(model):
+        """
+        Run megatron_prefill with the predefined `prompt_tokens` on the given model.
+
+        Parameters:
+            model: The Megatron model to execute the prefill pass on.
+
+        Returns:
+            The outputs produced by `megatron_prefill` for the provided model.
+        """
         return megatron_prefill(model, prompt_tokens)

     # Quantize the reference model
@@ -519,13 +561,10 @@ def forward_fn(model):
     ],
 )
 def test_kv_cache_quant(config):
-    """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
-    expected q/k/v_bmm_quantizers when using KV cache configs.
-
-    Note: This test requires Transformer Engine to be installed since TEDotProductAttention
-    is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
+    """
+    Verify that TEDotProductAttention modules gain the expected KV-cache quantizers when using KV cache configurations.
+
+    Runs the KV-cache quantization smoke test via a single-process multiprocess spawn. Requires Transformer Engine, since TEDotProductAttention is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
     """
     spawn_multiprocess_job(size=1, job=partial(_test_kv_cache_quant_helper, config), backend="nccl")

@@ -538,11 +577,10 @@ def test_kv_cache_quant(config):

 )
 def test_kv_cache_sharded_state_dict(tmp_path, config):
-    """Test KV cache quantization with sharded state dict save/load.
-
-    This test verifies the complete workflow of saving and loading KV cache quantized
-    models with distributed checkpointing, ensuring quantizer states are properly
-    preserved across the save/load cycle.
+    """
+    Run a sharded-state-dict save/load test for KV-cache quantized models.
+
+    Spawns up to 2 processes using NCCL to execute the sharded-state-dict helper with the provided temporary path and quantization config.
     """
     size = min(2, torch.cuda.device_count()) # Use 2 GPUs if available, else 1
     spawn_multiprocess_job(
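
The test helpers follow the usual quantize-then-assert flow. The sketch below condenses that flow under a few assumptions: it uses mtq.quantize(model, config, forward_loop), which the tests also rely on, while has_kv_quantizers and quantize_and_check_kv_cache are hypothetical names that stand in for the real per-module assertions, the Megatron model construction, and the NCCL multiprocess spawn.

import modelopt.torch.quantization as mtq


def has_kv_quantizers(module) -> bool:
    """Hypothetical helper: true if a module carries the KV-cache bmm quantizers."""
    return hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer")


def quantize_and_check_kv_cache(model, config, calib_forward):
    """Quantize `model` with a KV-cache config and assert that attention modules got quantizers."""
    model = mtq.quantize(model, config, calib_forward)  # calibration runs via calib_forward
    attention_modules = [
        m for m in model.modules() if type(m).__name__.endswith("DotProductAttention")
    ]
    assert attention_modules, "expected at least one attention module"
    assert all(has_kv_quantizers(m) for m in attention_modules)
    return model

In the actual tests this logic runs inside spawn_multiprocess_job with the NCCL backend, on a Megatron GPT model built with transformer_impl="modelopt", using megatron_prefill as the calibration forward pass.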
