137 changes: 134 additions & 3 deletions modelopt/torch/export/quant_utils.py
@@ -478,7 +478,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames

    if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
        return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
        return QUANTIZATION_NVFP4_AWQ
    assert input_quantizer is not None, (
        f"input_quantizer is None for {quantizer_attr_names}"
@@ -923,18 +923,149 @@ def all_items_same(item_list):
    return all(x == item_list[0] for x in item_list)


# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
    # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
    # Mathematical equivalence:
    # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
    # After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
    # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
    # Mathematical equivalence:
    # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
    # After:  down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
]
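
The per-channel equivalence documented in this mapping can be sanity-checked numerically. The following minimal sketch (illustrative shapes and names, not part of the PR) shows that scaling o_proj's input activations matches folding the same scale into v_proj's output rows:

```python
import torch

torch.manual_seed(0)
seq, hidden = 4, 8                 # illustrative sizes; v_proj out dim == o_proj in dim (MHA case)
x = torch.randn(seq, hidden)       # input to v_proj
attn = torch.randn(seq, seq)       # attention weights applied to v_proj's output
v_w = torch.randn(hidden, hidden)  # v_proj.weight, (out_features, in_features)
o_w = torch.randn(hidden, hidden)  # o_proj.weight
scale = torch.rand(hidden) + 0.5   # o_proj's pre_quant_scale (one entry per input channel)

# Before fusion: pre_quant_scale is applied to o_proj's input
before = ((attn @ (x @ v_w.T)) * scale) @ o_w.T

# After fusion: the same scale is folded into v_proj's output rows
after = (attn @ (x @ (v_w * scale.view(-1, 1)).T)) @ o_w.T

assert torch.allclose(before, after, atol=1e-4)
```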


def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
"""Fuse pre_quant_scale to the linear weights if possible.

For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that
the results are mathematically equivalent to the following::

out_proj.input = (attn_weights @ v_proj.output)
out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
= attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight

For GQA/MQA models where v_proj output dimension < o_proj input dimension,
the pre_quant_scale is averaged across the repeated head groups and then the
o_proj's pre_quant_scale is updated to maintain mathematical equivalence.

Args:
model: The model to fuse pre_quant_scale to.
fuse_grouped_heads: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale
and linear weights is not the same. This is useful for GQA/MQA models but may lead to accuracy
drop.

Note:
Fuse_grouped_heads is useful for GQA/MQA models but may lead to accuracy drop.
"""
# Fuse pre_quant_scale to the linear weights
for _, module in model.named_modules():
for module_map in PQS_FUSE_MODULE_MAPPING:
target_module_list = module_map[0]
linear_pair = module_map[1]
if any(module_name in type(module).__name__ for module_name in target_module_list):
linear_fuse_into = module.get_submodule(linear_pair[0])
linear_pqs_from = module.get_submodule(linear_pair[1])
if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
linear_pqs_from.input_quantizer, "_pre_quant_scale"
):
pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale

# for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups
if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
if (
not fuse_grouped_heads
or "attention" not in type(module).__name__.lower()
):
warn(
f"Skipping pattern fuse prequant for {type(module).__name__}"
f"pqs dim {pre_quant_scale.numel()} != out_ch dim {linear_fuse_into.weight.shape[-2]}"
)
continue
config = module.config
num_kv_heads = config.num_key_value_heads
kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim

# Reshape:(num_kv_heads, n_rep, kv_head_dim)
[Collaborator comment] what's n_rep here?
                        averaged_scale = pre_quant_scale.view(
                            num_kv_heads, n_rep, kv_head_dim
                        ).mean(dim=1)

                        # To update o_proj, we need to repeat back to original shape
                        repeated_scale = (
                            averaged_scale.unsqueeze(1)
                            .expand(num_kv_heads, n_rep, kv_head_dim)
                            .reshape(-1)
                        )

                        def _update_pre_quant_scale(module, new_pre_quant_scale):
[Collaborator comment] can we merge duplicated code with line 1090?
                            old_pre_quant_scale = module.input_quantizer._pre_quant_scale
                            module.weight = nn.Parameter(
                                module.weight
                                * old_pre_quant_scale.to(
[Collaborator comment] do we want to cast to fp32 for this manipulation?
                                    dtype=module.weight.dtype, device=module.weight.device
                                )
                                / new_pre_quant_scale.to(
                                    dtype=module.weight.dtype, device=module.weight.device
                                )
                            )
                            module.input_quantizer.pre_quant_scale = new_pre_quant_scale

                            # Redo weights collection
                            module.weight_quantizer.reset_amax()
                            enable_stats_collection(module.weight_quantizer)
                            module.weight_quantizer(module.weight)
                            finish_stats_collection(module.weight_quantizer)

                        # Update o_proj's pre_quant_scale
                        _update_pre_quant_scale(linear_pqs_from, repeated_scale)

                        # Use averaged scale (flattened) for v_proj fusion
                        pre_quant_scale = averaged_scale.reshape(-1)

                    # Fuse the pre_quant_scale to weight
                    linear_fuse_into.weight = torch.nn.Parameter(
                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
                    )
                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
                        linear_fuse_into.bias = torch.nn.Parameter(
                            linear_fuse_into.bias * pre_quant_scale
                        )

                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
                    setattr(linear_pqs_from, "fused_with_prequant", True)
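
In the GQA/MQA branch above, n_rep is the number of query-head groups that share a single KV head (o_proj's input width divided by v_proj's output width). A shape-only sketch, with illustrative sizes rather than values from a real config, of how the scale is averaged per KV head, folded into v_proj, and repeated back for o_proj:

```python
import torch

# Illustrative GQA sizes (assumptions for this sketch, not from a real model config)
num_kv_heads, kv_head_dim = 2, 4
num_attention_heads = 4
n_rep = num_attention_heads // num_kv_heads  # query heads sharing each KV head -> 2

# o_proj's pre_quant_scale has one entry per o_proj input channel (16 here),
# while v_proj only has num_kv_heads * kv_head_dim = 8 output channels.
pre_quant_scale = torch.rand(num_attention_heads * kv_head_dim) + 0.5

# Average across the n_rep repeated groups of each KV head -> shape (2, 4)
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)

# Flattened (8 entries), this is what gets folded into v_proj's output rows;
# repeating it back to 16 entries keeps o_proj's own pre_quant_scale consistent.
repeated_scale = (
    averaged_scale.unsqueeze(1).expand(num_kv_heads, n_rep, kv_head_dim).reshape(-1)
)

print(averaged_scale.reshape(-1).shape, repeated_scale.shape)  # torch.Size([8]) torch.Size([16])
```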


def fuse_prequant_layernorm(
    layernorm_module: torch.nn.Module,
    modules: list[torch.Tensor],
):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.

    original:
        layernorm_output = (normalization(input) * weight) + bias
        layernorm_output_scaled = layernorm_output * pre_quant_scale

    fused:
        fused_weight = weight * avg_pre_quant_scale
        fused_bias = bias * avg_pre_quant_scale
        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
    """
    layernorm_module.weight = torch.nn.Parameter(
        layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
    )
    if hasattr(layernorm_module, "bias"):
        layernorm_module.bias = torch.nn.Parameter(
            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
        )
    # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
    for module in modules:
        delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
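
The equivalence stated in this docstring can be checked with a plain functional layer norm; the snippet below is a standalone illustration with made-up tensors and does not touch the quantizer objects used above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden = 8
x = torch.randn(4, hidden)
weight = torch.rand(hidden) + 0.5       # layernorm weight
bias = torch.rand(hidden) * 0.2 - 0.1   # layernorm bias
avg_pre_quant_scale = torch.rand(hidden) + 0.5

# original: the layernorm output is scaled by the (averaged) pre_quant_scale
before = F.layer_norm(x, (hidden,), weight, bias) * avg_pre_quant_scale

# fused: the scale is folded into the layernorm's weight and bias
fused_weight = weight * avg_pre_quant_scale
fused_bias = bias * avg_pre_quant_scale
after = F.layer_norm(x, (hidden,), fused_weight, fused_bias)

assert torch.allclose(before, after, atol=1e-6)
```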


def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
5 changes: 5 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -57,6 +57,7 @@
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
    fuse_prequant_layernorm,
    fuse_prequant_to_linear,
    get_activation_scaling_factor,
    get_quant_config,
    get_quantization_format,
@@ -106,6 +107,10 @@ def _output_hook(module, input, output):
    fused_linears = {}
    module_names = set()

    # Fuse pre_quant_scale to the linear weights if possible
    if "NVFP4_AWQ" in quantization_format:
        fuse_prequant_to_linear(model)

    for name, module in model.named_modules():
        module_names.add(name)
193 changes: 193 additions & 0 deletions tests/gpu/torch/export/test_quant_utils.py
@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

pytest.importorskip("transformers")

from transformers import LlamaConfig, LlamaForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export.quant_utils import fuse_prequant_to_linear


def get_tiny_llama(attention_heads=4, key_value_heads=4):
"""Create a tiny Llama model for testing."""
config = LlamaConfig(
hidden_size=64,
intermediate_size=128,
num_hidden_layers=2,
num_attention_heads=attention_heads,
num_key_value_heads=key_value_heads,
max_position_embeddings=128,
vocab_size=256,
)
return LlamaForCausalLM(config)


@pytest.mark.parametrize(
    "quant_config",
    [
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_AWQ_LITE_CFG,
    ],
)
@pytest.mark.parametrize(
    "attention_kv_heads_pair",
    [
        (4, 4),  # MHA
        (4, 2),  # GQA
        (4, 1),  # MQA
    ],
)
def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
"""Test pattern_fuse_prequant on modules from a tiny Llama model."""
model = get_tiny_llama(attention_kv_heads_pair[0], attention_kv_heads_pair[1]).to("cuda")

# Quantize the model
dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
mtq.quantize(model, quant_config, lambda m: m(dummy_input))

# Run forward pass before fusion
model.eval()
with torch.no_grad():
output_before_fuse = model(dummy_input)

traget_module_name_list = [
"model.layers.0.self_attn.o_proj",
"model.layers.0.mlp.down_proj",
"model.layers.1.self_attn.o_proj",
"model.layers.1.mlp.down_proj",
]

# Apply fusion
fuse_prequant_to_linear(model, fuse_grouped_heads=True)

# Check if pre_quant_scale and fused_with_prequant flag are removed correctly
for target_module_name in traget_module_name_list:
target_module = model.get_submodule(target_module_name)

# Verify pre_quant_scale was removed
assert not hasattr(target_module.input_quantizer, "_pre_quant_scale"), (
f"{target_module_name}: pre_quant_scale should be removed after fusion"
)

# Verify fused_with_prequant flag was set
assert (
hasattr(target_module, "fused_with_prequant") and target_module.fused_with_prequant
), f"{target_module_name}: fused_with_prequant flag should be set"

# Verify output is close to the original output
with torch.no_grad():
output_after_fuse = model(dummy_input)
# There will be some small difference due to quantization errors after pre_quant_scale fusion to the weights
assert torch.allclose(
output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
), "Output should be the same before and after fusion"


@pytest.mark.parametrize(
    "quant_config",
    [
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_AWQ_LITE_CFG,
    ],
)
def test_pattern_fuse_prequant_moe(quant_config):
"""Test pattern_fuse_prequant on Qwen3 MoE sparse MLP."""
pytest.importorskip("transformers")
from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM

# Create a tiny Qwen3MoE model for testing
config = Qwen3MoeConfig(
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=4,
num_experts=4,
num_experts_per_tok=2,
max_position_embeddings=128,
vocab_size=256,
shared_expert_intermediate_size=256,
)
model = Qwen3MoeForCausalLM(config).to("cuda")

# Quantize the model
dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
mtq.quantize(model, quant_config, lambda m: m(dummy_input))

# Collect MoE expert modules to verify (down_proj should be fused)
moe_down_proj_modules = []
moe_gate_proj_modules = []
moe_up_proj_modules = []
for name, module in model.named_modules():
if "mlp" in name and "experts" in name:
if "gate_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
moe_gate_proj_modules.append((name, module))
elif "down_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
moe_down_proj_modules.append((name, module))
elif "up_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
moe_up_proj_modules.append((name, module))

# Verify experts have pre_quant_scale before fusion
for name, module in moe_gate_proj_modules:
if hasattr(module, "input_quantizer"):
assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
f"{name}: gate_proj should have pre_quant_scale before fusion"
)

for name, module in moe_up_proj_modules:
if hasattr(module, "input_quantizer"):
assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
f"{name}: up_proj should have pre_quant_scale before fusion"
)

for name, module in moe_down_proj_modules:
if hasattr(module, "input_quantizer"):
assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
f"{name}: down_proj should have pre_quant_scale before fusion"
)

# Run forward pass before fusion
model.eval()
with torch.no_grad():
output_before_fuse = model(dummy_input)

# Apply fusion (fuse_mismatch_dim only needed for GQA/MQA attention, not for MLP)
fuse_prequant_to_linear(model)

# Check if down_proj's pre_quant_scale was removed and fused into up_proj
for name, module in moe_down_proj_modules:
if hasattr(module, "input_quantizer"):
# Verify pre_quant_scale was removed from down_proj
assert not hasattr(module.input_quantizer, "_pre_quant_scale"), (
f"{name}: down_proj pre_quant_scale should be removed after fusion"
)
# Verify fused_with_prequant flag was set
assert hasattr(module, "fused_with_prequant") and module.fused_with_prequant, (
f"{name}: down_proj should have fused_with_prequant flag set"
)

# Verify output is close to the original output
with torch.no_grad():
output_after_fuse = model(dummy_input)

# There will be some difference due to quantization errors after pre_quant_scale fusion
assert torch.allclose(
output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
), "Output should be similar before and after Qwen3 MoE fusion"