
Commit 13042fa

fix GQA
Signed-off-by: weimingc <[email protected]>
1 parent 9bd8e41 commit 13042fa

File tree

1 file changed: +66 -10 lines changed


modelopt/torch/export/quant_utils.py

Lines changed: 66 additions & 10 deletions
@@ -923,10 +923,19 @@ def all_items_same(item_list):
    return all(x == item_list[0] for x in item_list)


+# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
-    # format: (list of target modules, tuple of (linear_pqs_fuse_to, linear_pqs_from), dim to fuse)
-    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj"), "input"),
-    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj"), "output"),
+    # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
+    # Mathematical equivalence:
+    # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
+    # After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
+    # Note: for GQA/MQA models, the scale is first averaged within each KV head group (see below)
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
+    # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
+    # Mathematical equivalence:
+    # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
+    # After:  down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
]


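For reference, the equivalence stated in the new comments can be checked numerically. The snippet below is an illustrative sketch only (it is not part of this commit, and all tensor names are made up for the example); it verifies that folding a per-channel scale into v_proj's output rows matches applying the same scale to o_proj's input:

import torch

torch.manual_seed(0)
tokens, hidden, v_out = 4, 8, 8           # toy sizes; no GQA, so v_out matches o_proj's input dim
x = torch.randn(tokens, hidden)           # input to v_proj
attn = torch.softmax(torch.randn(tokens, tokens), dim=-1)  # stand-in attention weights
v_w = torch.randn(v_out, hidden)          # v_proj.weight
o_w = torch.randn(hidden, v_out)          # o_proj.weight
scale = torch.rand(v_out) + 0.5           # o_proj's pre_quant_scale

# Before: apply the scale to o_proj's input activations
out_before = ((attn @ (x @ v_w.T)) * scale) @ o_w.T
# After: fold the scale into v_proj's output channels (rows of v_proj.weight)
out_after = (attn @ (x @ (v_w * scale.view(-1, 1)).T)) @ o_w.T

assert torch.allclose(out_before, out_after, atol=1e-5)

The MLP mapping relies on the same identity, with the scale folded into up_proj's output rows instead of v_proj's.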
@@ -947,23 +956,70 @@ def pattern_fuse_prequant(model: torch.nn.Module):
        for module_map in PQS_FUSE_MODULE_MAPPING:
            target_module_list = module_map[0]
            linear_pair = module_map[1]
-            dim_to_fuse = module_map[2]
            if any(module_name in type(module).__name__ for module_name in target_module_list):
                linear_to = module.get_submodule(linear_pair[0])
                linear_from = module.get_submodule(linear_pair[1])
                if hasattr(linear_from, "input_quantizer") and hasattr(
                    linear_from.input_quantizer, "_pre_quant_scale"
                ):
                    pre_quant_scale = linear_from.input_quantizer._pre_quant_scale
-                    # check if we need to apply to the last dimension or the first dimension
-                    pre_quant_scale = (
-                        pre_quant_scale.view(-1, 1)
-                        if dim_to_fuse == "output"
-                        else pre_quant_scale.view(1, -1)
+
+                    # For GQA/MQA models, we average the pre_quant_scale within each KV head group
+                    if pre_quant_scale.numel() != linear_to.weight.shape[0]:
+                        if "attention" not in type(module).__name__.lower():
+                            continue
+                        else:
+                            config = module.config
+                            num_kv_heads = config.num_key_value_heads
+                            kv_head_dim = linear_to.weight.shape[0] // num_kv_heads
+                            n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                            # Reshape: (num_kv_heads, n_rep, kv_head_dim), then average over n_rep
+                            averaged_scale = pre_quant_scale.view(
+                                num_kv_heads, n_rep, kv_head_dim
+                            ).mean(dim=1)
+
+                            # To update o_proj, repeat the averaged scale back to the original shape
+                            repeated_scale = (
+                                averaged_scale.unsqueeze(1)  # (num_kv_heads, 1, kv_head_dim)
+                                .expand(num_kv_heads, n_rep, kv_head_dim)  # (num_kv_heads, n_rep, kv_head_dim)
+                                .reshape(-1)  # (num_kv_heads * n_rep * kv_head_dim,)
+                            )
+
+                            def _update_pre_quant_scale(module, new_pre_quant_scale):
+                                old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+                                module.weight = nn.Parameter(
+                                    module.weight
+                                    * old_pre_quant_scale.to(
+                                        dtype=module.weight.dtype, device=module.weight.device
+                                    )
+                                    / new_pre_quant_scale.to(
+                                        dtype=module.weight.dtype, device=module.weight.device
+                                    )
+                                )
+                                module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+
+                                # Redo weight statistics collection
+                                module.weight_quantizer.reset_amax()
+                                enable_stats_collection(module.weight_quantizer)
+                                module.weight_quantizer(module.weight)
+                                finish_stats_collection(module.weight_quantizer)
+
+                            # Update o_proj's pre_quant_scale
+                            _update_pre_quant_scale(linear_from, repeated_scale)
+
+                            # Use the averaged scale (flattened) for v_proj fusion
+                            pre_quant_scale = averaged_scale.reshape(-1)
+
+                    # Fuse the pre_quant_scale into the v_proj weight (linear_to)
+                    # linear_to.weight shape: (out_features, in_features)
+                    # We scale the output dimension (the first dimension)
+                    linear_to.weight = torch.nn.Parameter(
+                        linear_to.weight * pre_quant_scale.view(-1, 1)
                    )
-                    linear_to.weight = torch.nn.Parameter(linear_to.weight * pre_quant_scale)
                    if hasattr(linear_to, "bias") and linear_to.bias is not None:
                        linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale)
+
                    delattr(linear_from.input_quantizer, "_pre_quant_scale")
                    setattr(linear_from, "fused_with_prequant", True)
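The GQA averaging and repeat-back added above can be illustrated in isolation. This is a minimal sketch with assumed toy sizes (num_kv_heads=2, n_rep=2, kv_head_dim=16, matching the shape comments in the diff), not part of the commit:

import torch

num_kv_heads, n_rep, kv_head_dim = 2, 2, 16
# o_proj's pre_quant_scale has one entry per query-head channel
pre_quant_scale = torch.rand(num_kv_heads * n_rep * kv_head_dim)  # (64,)

# Average over the n_rep query heads that share each KV head
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)  # (2, 16)

# Flattened, this matches v_proj's output dimension (num_kv_heads * kv_head_dim)
v_proj_scale = averaged_scale.reshape(-1)  # (32,)

# Repeated back to the original layout, it becomes o_proj's new (uniform-per-group) pre_quant_scale
repeated_scale = (
    averaged_scale.unsqueeze(1)
    .expand(num_kv_heads, n_rep, kv_head_dim)
    .reshape(-1)
)  # (64,)

assert repeated_scale.shape == pre_quant_scale.shape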

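The helper _update_pre_quant_scale keeps o_proj numerically unchanged when its scale is replaced, by rescaling the weight with old_scale / new_scale. A minimal check of that identity (illustrative only, with assumed shapes; the real helper also re-runs weight statistics collection):

import torch

torch.manual_seed(0)
in_features, out_features = 64, 16
x = torch.randn(4, in_features)
w = torch.randn(out_features, in_features)   # o_proj-like weight
old_scale = torch.rand(in_features) + 0.5    # original pre_quant_scale
new_scale = torch.rand(in_features) + 0.5    # averaged-and-repeated replacement

# The quantizer multiplies the input by pre_quant_scale before the linear layer
out_old = (x * old_scale) @ w.T

# Compensating the weight by old_scale / new_scale (broadcast over input columns)
# leaves the output unchanged under the new scale
w_new = w * (old_scale / new_scale)
out_new = (x * new_scale) @ w_new.T

assert torch.allclose(out_old, out_new, atol=1e-5)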
0 commit comments