
Commit 6020e94

fix moe fusion
Signed-off-by: weimingc <[email protected]>
1 parent a5a6e39 commit 6020e94

File tree

3 files changed (+14, -27 lines)

modelopt/torch/export/quant_utils.py

Lines changed: 7 additions & 21 deletions
@@ -938,9 +938,8 @@ def all_items_same(item_list):
 ]
 
 
-# TODO: make this more general instead of rule based
-def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
-    """Fuse pre_quant_scale to the linear weights.
+def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
+    """Fuse pre_quant_scale to the linear weights if possible.
 
     For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that
     the results are mathematically equivalent to the following::
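For reference, a minimal sketch of the equivalence the docstring refers to, with toy dimensions and no attention mixing (the real code also handles per-head shapes and GQA/MQA grouping):

```python
import torch

torch.manual_seed(0)
hidden = 8                                     # toy size for illustration only
x = torch.randn(2, hidden)                     # stand-in for v_proj's input

v_proj = torch.nn.Linear(hidden, hidden, bias=False)
o_proj = torch.nn.Linear(hidden, hidden, bias=False)
s = torch.rand(hidden) + 0.5                   # pre_quant_scale applied to o_proj's input

# Reference: the scale multiplies o_proj's input activations at runtime
y_ref = o_proj(v_proj(x) * s)

# Fused: fold s into v_proj's output channels (rows of its weight) instead;
# scaling output channel i of v_proj equals scaling row i of its weight
v_fused = torch.nn.Linear(hidden, hidden, bias=False)
v_fused.weight = torch.nn.Parameter(v_proj.weight * s.view(-1, 1))
y_fused = o_proj(v_fused(x))

assert torch.allclose(y_ref, y_fused, atol=1e-6)
```

Because attention's softmax weights mix tokens along the sequence dimension rather than across channels, the per-channel scale commutes with that mixing, so the same identity holds with a full attention block between v_proj and o_proj.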
@@ -955,26 +954,13 @@ def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
 
     Args:
         model: The model to fuse pre_quant_scale to.
-        fuse_mismatch_dim: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale
+        fuse_grouped_heads: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale
             and linear weights is not the same. This is useful for GQA/MQA models but may lead to accuracy
             drop.
 
     Note:
-        This is an experimental feature, and it might mess up the quantization errors
-        of fused linear modules.
+        Fuse_grouped_heads is useful for GQA/MQA models but may lead to accuracy drop.
     """
-    # For MoE models, let's first resmooth the w1 and w3 in experts to get the average pre_quant_scale
-    for _, module in model.named_modules():
-        if (
-            hasattr(module, "experts")
-            and "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower()
-        ):
-            linear_list = []
-            linear_list.extend([getattr(expert, "up_proj") for expert in module.experts])
-            linear_list.extend([getattr(expert, "gate_proj") for expert in module.experts])
-            preprocess_linear_fusion(linear_list, resmooth_only=True)
-
-    # import pdb; pdb.set_trace()
     # Fuse pre_quant_scale to the linear weights
     for _, module in model.named_modules():
         for module_map in PQS_FUSE_MODULE_MAPPING:
@@ -988,10 +974,10 @@ def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
                 ):
                     pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
 
-                    # for GQA/MQA models, we apply averaging to the pre_quant_scale for shared head groups
+                    # for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups
                     if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
                         if (
-                            not fuse_mismatch_dim
+                            not fuse_grouped_heads
                             or "attention" not in type(module).__name__.lower()
                         ):
                             warn(
@@ -1041,7 +1027,7 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
                         # Use averaged scale (flattened) for v_proj fusion
                         pre_quant_scale = averaged_scale.reshape(-1)
 
-                        # Fuse the pre_quant_scale to v_proj weight
+                        # Fuse the pre_quant_scale to weight
                         linear_fuse_into.weight = torch.nn.Parameter(
                             linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
                         )
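In the GQA/MQA case handled by the two hunks above, o_proj's input has num_attention_heads * head_dim channels while v_proj only has num_kv_heads * head_dim output channels, so the scale is averaged over the query heads that share each KV head before being fused. A rough sketch with made-up sizes; the exact head layout used for the grouping is an assumption here:

```python
import torch

# Hypothetical GQA sizes: 4 query heads share 2 KV heads
num_heads, num_kv_heads, head_dim = 4, 2, 8
group = num_heads // num_kv_heads

# pre_quant_scale of o_proj's input, one entry per query-head channel
pqs = torch.rand(num_heads * head_dim) + 0.5

# Average over each group of query heads so the scale matches v_proj's
# output dimension (num_kv_heads * head_dim); averaging is an approximation,
# which is why fuse_grouped_heads=True may cost some accuracy
averaged_scale = pqs.view(num_kv_heads, group, head_dim).mean(dim=1)
pre_quant_scale = averaged_scale.reshape(-1)            # flattened, as in the hunk above

v_weight = torch.randn(num_kv_heads * head_dim, 16)     # (out_features, in_features)
fused_weight = v_weight * pre_quant_scale.view(-1, 1)   # fold into v_proj's weight rows
```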

modelopt/torch/export/unified_export_hf.py

Lines changed: 4 additions & 3 deletions
@@ -57,14 +57,14 @@
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
+    fuse_prequant_to_linear,
     get_activation_scaling_factor,
     get_quant_config,
     get_quantization_format,
     get_weight_block_size,
     get_weight_scaling_factor,
     get_weight_scaling_factor_2,
     maybe_transpose_expert_weight_dimensions,
-    pattern_fuse_prequant,
     postprocess_state_dict,
     preprocess_linear_fusion,
     to_quantized_weight,
@@ -107,6 +107,9 @@ def _output_hook(module, input, output):
     fused_linears = {}
     module_names = set()
 
+    # Fuse pre_quant_scale to the linear weights if possible
+    fuse_prequant_to_linear(model)
+
     for name, module in model.named_modules():
         module_names.add(name)
 
@@ -174,8 +177,6 @@ def _output_hook(module, input, output):
             # Pre quant scale of modules is already updated to avg_pre_quant_scale
             fuse_prequant_layernorm(output_to_layernorm[tensor], modules)
 
-    pattern_fuse_prequant(model)
-
     # The dummy forward may not be able to activate all the experts.
     # Process experts by naming rules like experts.0, experts.1, etc.
     for name, modules_fused in fused_linears.items():
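The net effect of this file's change is that the fusion now runs once, up front, rather than after the layernorm-fusion pass. A simplified outline of the new ordering (the wrapper name below is hypothetical; the real routine in unified_export_hf.py does considerably more):

```python
def _preprocess_for_export(model):  # hypothetical name, for illustration only
    # Fold eligible pre_quant_scale tensors into upstream linear weights
    # before any per-module bookkeeping or the dummy forward pass
    fuse_prequant_to_linear(model)

    for name, module in model.named_modules():
        ...  # collect module names, register output hooks, etc.

    # ... dummy forward, fuse_prequant_layernorm(...), expert post-processing ...
```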

tests/gpu/torch/export/test_quant_utils.py

Lines changed: 3 additions & 3 deletions
@@ -21,7 +21,7 @@
 from transformers import LlamaConfig, LlamaForCausalLM
 
 import modelopt.torch.quantization as mtq
-from modelopt.torch.export.quant_utils import pattern_fuse_prequant
+from modelopt.torch.export.quant_utils import fuse_prequant_to_linear
 
 
 def get_tiny_llama(attention_heads=4, key_value_heads=4):
@@ -74,7 +74,7 @@ def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
     ]
 
     # Apply fusion
-    pattern_fuse_prequant(model, fuse_mismatch_dim=True)
+    fuse_prequant_to_linear(model, fuse_grouped_heads=True)
 
     # Check if pre_quant_scale and fused_with_prequant flag are removed correctly
     for target_module_name in traget_module_name_list:
@@ -172,7 +172,7 @@ def test_pattern_fuse_prequant_moe(quant_config):
     output_before_fuse = model(dummy_input)
 
     # Apply fusion (fuse_mismatch_dim only needed for GQA/MQA attention, not for MLP)
-    pattern_fuse_prequant(model)
+    fuse_prequant_to_linear(model)
 
     # Check if down_proj's pre_quant_scale was removed and fused into up_proj
     for name, module in moe_down_proj_modules:
