Docstring generation was requested by @kaix-nv.
* #375 (comment)
The following files were modified:
* `modelopt/torch/quantization/plugins/megatron.py`
* `tests/gpu/torch/quantization/plugins/test_megatron.py`
In `modelopt/torch/quantization/plugins/megatron.py`:

```diff
@@ -469,13 +477,19 @@ class _QuantTEDotProductAttention(QuantModule):
     """Quantized version of TEDotProductAttention for Megatron models with KV cache quantization."""
 
     def _setup(self):
-        """Initialize quantizers for Q, K, V tensors."""
+        """
+        Create and attach three TensorQuantizer instances as q_bmm_quantizer, k_bmm_quantizer, and v_bmm_quantizer for quantizing query, key, and value tensors.
+        """
         self.q_bmm_quantizer = TensorQuantizer()
         self.k_bmm_quantizer = TensorQuantizer()
         self.v_bmm_quantizer = TensorQuantizer()
 
     def _calibrate_quantizers(self):
-        """Calibrate quantizers with minimal dummy tensors."""
+        """
+        Calibrate the module's Q/K/V tensor quantizers using minimal dummy inputs.
+
+        Creates a tiny float16 dummy tensor shaped according to the attention QKV layout (either "sbhd" or "bshd", determined from self.config.apply_rope_fusion) and uses it to compute and store `_amax` values for any enabled q_bmm_quantizer, k_bmm_quantizer, or v_bmm_quantizer that does not yet have an `_amax`. Calibration is performed only for quantizers that are enabled and lack existing scale information.
```
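To make the calibration flow easier to picture, here is a small illustrative sketch of the dummy-tensor logic the new `_calibrate_quantizers` docstring describes. It is not the plugin's actual code: the function name, tensor sizes, and the mapping from `apply_rope_fusion` to a layout are assumptions.

```python
import torch


def _dummy_qkv_amax(apply_rope_fusion: bool, num_heads: int = 1, head_dim: int = 8) -> torch.Tensor:
    """Illustrative only: build a tiny fp16 tensor in an attention QKV layout and
    reduce it to the absolute maximum a max calibrator would store as `_amax`."""
    # "sbhd" = (seq, batch, heads, head_dim); "bshd" = (batch, seq, heads, head_dim).
    # Which layout corresponds to which apply_rope_fusion value is assumed here.
    layout = "sbhd" if apply_rope_fusion else "bshd"
    shape = (2, 1, num_heads, head_dim) if layout == "sbhd" else (1, 2, num_heads, head_dim)
    dummy = torch.randn(*shape, dtype=torch.float16)
    return dummy.abs().amax()
```

Per the docstring, a quantizer that is enabled but has no `_amax` yet would adopt a value derived this way, while quantizers that already carry scale information are left untouched.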
```diff
+        Adjust quantizer entries in a loaded state dict to match this module's expected keys and tensor shapes before delegating to the parent loader.
+
+        This method:
+        - Renames per-quantizer `_amax` keys from the `{prefix}{quantizer_name}._amax` form saved in the checkpoint to the key format expected by this module's local TensorQuantizer instances.
+        - Reshapes any remaining quantizer state tensors (keys containing `_quantizer` but not `_amax`) to match the corresponding tensor shapes in this module's `state_dict`.
+        - Calls the superclass `_load_from_state_dict` with the adjusted `state_dict`.
+
+        Parameters:
+            state_dict (dict): The incoming state dictionary being loaded; modified in-place to align quantizer keys and shapes.
+            prefix (str): The prefix applied to keys for this module in `state_dict`.
```
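The renaming and reshaping described above can be pictured with the following rough sketch. It is not the method's real body: the `rename_amax_key` callable stands in for the module-specific key mapping that the diff does not show, and `expected` plays the role of the module's own `state_dict()`.

```python
from typing import Callable


def align_quantizer_entries(
    state_dict: dict,
    prefix: str,
    expected: dict,
    rename_amax_key: Callable[[str], str],
) -> dict:
    """Illustrative only: rename per-quantizer `_amax` keys and reshape the
    remaining quantizer tensors to the shapes the module expects."""
    for key in list(state_dict.keys()):
        if not key.startswith(prefix) or "_quantizer" not in key:
            continue
        if "_amax" in key:
            # Map the saved amax key to whatever key the local TensorQuantizer expects.
            state_dict[rename_amax_key(key)] = state_dict.pop(key)
        elif key in expected:
            # Keys containing `_quantizer` but not `_amax`: match the expected shape.
            state_dict[key] = state_dict[key].reshape(expected[key].shape)
    return state_dict
```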
"""Restore quantizer states after model loading."""
633
+
"""
634
+
Perform post-restore validation for attention quantizers and trigger calibration if needed.
635
+
636
+
Checks each of the instance's Q/K/V BMM quantizers (if present and enabled) for unsupported saved state keys and emits a warning identifying the provided `name` when such keys are found. If any enabled quantizer lacks a stored `_amax` value, schedules and runs quantizer calibration by calling self._calibrate_quantizers().
637
+
638
+
Parameters:
639
+
name (str): Human-readable identifier for the module being restored; included in warning messages to help locate the layer.
640
+
"""
590
641
super().modelopt_post_restore(name)
591
642
592
643
def_check_unsupported_states(quantizer):
644
+
"""
645
+
Check a quantizer's saved state keys and warn about any unsupported entries.
646
+
647
+
Inspects quantizer.state_dict() (if present) and emits a warning for each key other than `_amax` and `_pre_quant_scale` indicating that restoring that key is not supported.
648
+
649
+
Parameters:
650
+
quantizer: An object with a `state_dict()` method (typically a TensorQuantizer) whose saved state keys will be validated.
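Taken together, these two docstrings describe a post-restore pass that warns about unsupported quantizer state and falls back to calibration when `_amax` is missing. A compact illustrative version (hypothetical names; `is_enabled` and `_calibrate_quantizers` are assumed attributes of the real classes) could look like this:

```python
import warnings

from torch import nn


def post_restore_check(module: nn.Module, name: str) -> None:
    """Illustrative only: warn about unsupported saved quantizer state and trigger
    calibration when an enabled quantizer has no stored `_amax`."""
    needs_calibration = False
    for attr in ("q_bmm_quantizer", "k_bmm_quantizer", "v_bmm_quantizer"):
        quantizer = getattr(module, attr, None)
        if quantizer is None or not getattr(quantizer, "is_enabled", False):
            continue
        # Only `_amax` and `_pre_quant_scale` are supported restore targets.
        for key in quantizer.state_dict().keys():
            if key not in ("_amax", "_pre_quant_scale"):
                warnings.warn(f"{name}.{attr}: restoring state '{key}' is not supported")
        if getattr(quantizer, "_amax", None) is None:
            needs_calibration = True
    if needs_calibration:
        module._calibrate_quantizers()  # re-derive amax from dummy inputs
```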
"""Helper function for testing KV cache quantization with TEDotProductAttention."""
380
+
"""
381
+
Verify that TEDotProductAttention modules receive KV-cache quantization and remain functional after quantization.
382
+
383
+
Quantizes a minimal GPT model (built with transformer_impl="modelopt") using the provided `config`, checks that each TEDotProductAttention-like module exposes `k_bmm_quantizer` and `v_bmm_quantizer`, asserts those quantizers are enabled, and performs a smoke forward pass to ensure the quantized model runs.
384
+
385
+
Parameters:
386
+
config: Quantization configuration used to quantize the model (e.g., FP8_KV_CFG or NVFP4_KV_CFG).
387
+
rank (int): Process rank in the distributed test invocation.
388
+
size (int): Tensor model parallel size used to construct the model.
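As a rough outline of what this helper checks (not the repo's actual test code; the function name and the `forward_fn` callable are illustrative, and `config` would be one of the KV-cache configs named above):

```python
import torch

import modelopt.torch.quantization as mtq


def check_kv_cache_quantizers(model: torch.nn.Module, config: dict, forward_fn) -> None:
    """Illustrative only: quantize, assert that attention modules carry enabled
    K/V BMM quantizers, and run a smoke forward pass."""
    mtq.quantize(model, config, forward_loop=forward_fn)

    found = False
    for _, module in model.named_modules():
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            found = True
            # `is_enabled` is assumed to be exposed by the quantizer objects.
            assert module.k_bmm_quantizer.is_enabled
            assert module.v_bmm_quantizer.is_enabled
    assert found, "no TEDotProductAttention-like module gained KV quantizers"

    forward_fn(model)  # the quantized model must still run
```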
"""Helper for testing KV cache quantization with sharded state dict save/load."""
444
+
"""
445
+
Validate that KV-cache quantizers on TEDotProductAttention modules are created and correctly preserved across sharded state_dict save/load.
446
+
447
+
This helper initializes Megatron, constructs two small GPT models using transformer_impl="modelopt" (TEDotProductAttention), quantizes both models with the provided config (KV quantizers must exist on both before checkpointing), runs a sharded save/load roundtrip, and asserts that k_bmm_quantizer and v_bmm_quantizer instances are present and that their internal `_amax` tensors (when present) match between the reference and restored models.
448
+
449
+
Parameters:
450
+
tmp_path (pathlib.Path): Temporary directory to write sharded checkpoints.
451
+
config (dict): Quantization configuration dictionary to use for mtq.quantize.
452
+
rank (int): Distributed process rank for this helper.
453
+
size (int): Tensor-model-parallel size / world size used to initialize Megatron.
454
+
"""
422
455
# Disable output_layer quantization (same as other sharded state dict tests)
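The `_amax` comparison at the end of the roundtrip can be sketched as below (hypothetical function name; it covers only the comparison step, not the Megatron setup or the sharded save/load itself):

```python
import torch


def assert_kv_amax_matches(model_ref: torch.nn.Module, model_restored: torch.nn.Module) -> None:
    """Illustrative only: compare `_amax` buffers of K/V BMM quantizers between a
    reference model and a model restored from a sharded checkpoint."""
    ref_modules = dict(model_ref.named_modules())
    for name, restored in model_restored.named_modules():
        if not name.endswith(("k_bmm_quantizer", "v_bmm_quantizer")):
            continue
        ref_amax = getattr(ref_modules[name], "_amax", None)
        restored_amax = getattr(restored, "_amax", None)
        assert (ref_amax is None) == (restored_amax is None)
        if ref_amax is not None:
            assert torch.allclose(ref_amax, restored_amax)
```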
```diff
+        Run megatron_prefill with the predefined `prompt_tokens` on the given model.
+
+        Parameters:
+            model: The Megatron model to execute the prefill pass on.
+
+        Returns:
+            The outputs produced by `megatron_prefill` for the provided model.
+        """
         return megatron_prefill(model, prompt_tokens)
 
     # Quantize the reference model
```
```diff
@@ -519,13 +561,10 @@ def forward_fn(model):
     ],
 )
 def test_kv_cache_quant(config):
-    """Verify KV cache quantization works correctly with TEDotProductAttention.
-
-    This test ensures TEDotProductAttention is properly registered and gets the
-    expected q/k/v_bmm_quantizers when using KV cache configs.
-
-    Note: This test requires Transformer Engine to be installed since TEDotProductAttention
-    is only available with transformer_impl="modelopt" or "transformer_engine" (not "local").
+    """
+    Verify that TEDotProductAttention modules gain the expected KV-cache quantizers when using KV cache configurations.
+
+    Runs the KV-cache quantization smoke test (via a single-process multiprocess spawn) and requires Transformer Engine or modelopt since TEDotProductAttention is not available with the local transformer implementation.
```
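The "single-process multiprocess spawn" mentioned here can be approximated with plain `torch.multiprocessing`; the repo uses its own spawn utility, so this is only an illustration of the invocation shape:

```python
import torch.multiprocessing as mp


def _worker(rank: int, size: int) -> None:
    # Stand-in for the KV-cache quantization helper, which receives the process
    # rank and tensor-model-parallel size in the real test.
    assert rank == 0 and size == 1


def run_single_process_job() -> None:
    """Illustrative only: launch the helper in a single spawned process."""
    mp.spawn(_worker, args=(1,), nprocs=1)


if __name__ == "__main__":
    run_single_process_job()
```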