2 changes: 1 addition & 1 deletion modelopt/torch/quantization/plugins/__init__.py
@@ -73,4 +73,4 @@
from .trl import *

with import_plugin("mcore"):
from .mcore import *
from .mcore import *
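The import_plugin("mcore") guard above keeps the Megatron-Core plugin optional: if the backing package is missing, the plugin import is skipped instead of failing. A minimal sketch of that guarded-import pattern (an illustrative stand-in, not modelopt's actual import_plugin implementation):

import warnings
from contextlib import contextmanager


@contextmanager
def optional_plugin(name):
    # Suppress ImportError so an uninstalled backend only skips its plugin.
    try:
        yield
    except ImportError:
        warnings.warn(f"Plugin '{name}' unavailable; skipping its quantization hooks.")


# Usage mirroring the guarded import above:
with optional_plugin("mcore"):
    import megatron.core  # noqa: F401  # tolerated if Megatron-Core is absent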
43 changes: 26 additions & 17 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -22,23 +22,23 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import torch
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
from megatron.core.utils import get_tensor_model_parallel_group_if_none
from megatron.core.extensions.transformer_engine import TEDotProductAttention

from modelopt.torch.opt.plugins.megatron import (
_MegatronMLP,
register_modelopt_extra_state_callbacks,
)
from modelopt.torch.utils.distributed import ParallelState

from ..model_calib import max_calibrate
from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
from ..nn.modules.quant_linear import RealQuantLinear
from ..qtensor import QTensorWrapper
from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
from ..model_calib import max_calibrate

__all__ = []

@@ -477,24 +477,32 @@ def _setup(self):
def _calibrate_quantizers(self):
"""Calibrate quantizers with minimal dummy tensors."""
# Get device from parent module parameters
device = next(self.parameters()).device if self.parameters() else torch.device('cuda')
device = next(self.parameters()).device if self.parameters() else torch.device("cuda")

# TEDotProductAttention expects format 'sbhd' or 'bshd' depending on rope_fusion
batch_size = 1
seq_len = 1

# Get dimensions from config
num_heads = self.config.num_attention_heads
head_dim = self.config.kv_channels if hasattr(self.config, 'kv_channels') else self.config.hidden_size // num_heads

head_dim = (
self.config.kv_channels
if hasattr(self.config, "kv_channels")
else self.config.hidden_size // num_heads
)

# Determine tensor format (default to sbhd if not specified)
apply_rope_fusion = getattr(self.config, 'apply_rope_fusion', False)
apply_rope_fusion = getattr(self.config, "apply_rope_fusion", False)
qkv_format = "bshd" if apply_rope_fusion else "sbhd"

if qkv_format == "sbhd":
dummy_tensor = torch.randn(seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16)
dummy_tensor = torch.randn(
seq_len, batch_size, num_heads, head_dim, device=device, dtype=torch.float16
)
else:
dummy_tensor = torch.randn(batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16)
dummy_tensor = torch.randn(
batch_size, seq_len, num_heads, head_dim, device=device, dtype=torch.float16
)
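# A minimal sketch of how the dummy tensors above can drive calibration,
# assuming max_calibrate accepts a forward_loop callable; the names and
# wiring here are illustrative, and the actual loop below may differ:
#
#     def _forward_loop(model):
#         model(dummy_tensor, dummy_tensor, dummy_tensor)
#
#     max_calibrate(self, forward_loop=_forward_loop)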

# Calibrate each quantizer
quantizers = [
@@ -511,7 +519,7 @@ def _calibrate_quantizers(self):

def forward(self, query, key, value, *args, **kwargs):
"""Apply post-RoPE quantization to KV cache.

TEDotProductAttention receives Q, K, V after RoPE is applied,
so we quantize them directly for KV cache quantization.
"""
@@ -525,7 +533,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
"""Create a sharded state dictionary for distributed checkpointing."""
sharded_state_dict = {}

# First add non-quantizer parameters
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and v is not None and "_quantizer" not in k:
@@ -542,10 +550,11 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
sharded_state_dict[amax_key] = quantizer._amax

# Process other quantizer parameters in bmm_quantizers
quantizer_state_dict = {}
for k, v in self.state_dict(prefix="", keep_vars=True).items():
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k:
quantizer_state_dict[k] = v
quantizer_state_dict = {
k: v
for k, v in self.state_dict(prefix="", keep_vars=True).items()
if isinstance(v, torch.Tensor) and "_quantizer" in k and "_amax" not in k
}
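# The collected quantizer tensors are typically wrapped for distributed
# checkpointing with the MCore helper imported above; a minimal sketch, with
# the argument wiring assumed rather than taken from this code:
#
#     sharded_state_dict.update(
#         make_sharded_tensors_for_checkpoint(
#             quantizer_state_dict, prefix, sharded_offsets=sharded_offsets
#         )
#     )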

if quantizer_state_dict:
sharded_state_dict.update(
@@ -584,7 +593,7 @@ def _check_unsupported_states(quantizer):
if not hasattr(quantizer, "state_dict"):
return

for k in quantizer.state_dict().keys():
for k in quantizer.state_dict():
if k not in ["_amax", "_pre_quant_scale"]:
warnings.warn(
f"Restore of {k} for {name} is not supported. The restore of this layer might be "