Commit 16ad77f

pattern-based fusion
Signed-off-by: weimingc <[email protected]>
1 parent 35f90d0 commit 16ad77f

File tree

2 files changed: +65 -3 lines changed

modelopt/torch/export/quant_utils.py

Lines changed: 62 additions & 3 deletions
@@ -478,7 +478,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
 
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
         return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
@@ -923,18 +923,77 @@ def all_items_same(item_list):
     return all(x == item_list[0] for x in item_list)
 
 
+PQS_FUSE_MODULE_MAPPING = [
+    # format: (list of target modules, tuple of (linear_pqs_fuse_to, linear_pqs_from), dim to fuse)
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj"), "input"),
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj"), "output"),
+]
+
+
+# TODO: make this more general instead of rule based
+def pattern_fuse_prequant(model: torch.nn.Module):
+    """Fuse pre_quant_scale into the weights of the preceding linear module.
+
+    For example, the pre_quant_scale of o_proj can be fused into the output dimension of
+    v_proj, and the result is mathematically equivalent:
+
+        o_proj.input = attn_weights @ v_proj.output
+        o_proj.output = (o_proj.input * pre_quant_scale) * o_proj.weight
+                      = attn_weights @ (v_proj.output * pre_quant_scale) * o_proj.weight
+
+    Note: This is an experimental feature; it may increase the quantization error of the
+    fused linear modules.
+    """
+    for _, module in model.named_modules():
+        for module_map in PQS_FUSE_MODULE_MAPPING:
+            target_module_list = module_map[0]
+            linear_pair = module_map[1]
+            dim_to_fuse = module_map[2]
+            if any(module_name in type(module).__name__ for module_name in target_module_list):
+                linear_to = module.get_submodule(linear_pair[0])
+                linear_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_from, "input_quantizer") and hasattr(
+                    linear_from.input_quantizer, "_pre_quant_scale"
+                ):
+                    pre_quant_scale = linear_from.input_quantizer._pre_quant_scale
+                    # check if we need to apply to the last dimension or the first dimension
+                    pre_quant_scale = (
+                        pre_quant_scale.view(-1, 1)
+                        if dim_to_fuse == "output"
+                        else pre_quant_scale.view(1, -1)
+                    )
+                    linear_to.weight = torch.nn.Parameter(linear_to.weight * pre_quant_scale)
+                    if hasattr(linear_to, "bias") and linear_to.bias is not None:
+                        linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale)
+                    delattr(linear_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_from, "fused_with_prequant", True)
+
+
 def fuse_prequant_layernorm(
     layernorm_module: torch.nn.Module,
     modules: list[torch.Tensor],
 ):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.
+
+    original:
+        layernorm_output = (normalization(input) * weight) + bias
+        layernorm_output_scaled = layernorm_output * pre_quant_scale
+
+    fused:
+        fused_weight = weight * avg_pre_quant_scale
+        fused_bias = bias * avg_pre_quant_scale
+        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
+    """
     layernorm_module.weight = torch.nn.Parameter(
         layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
     )
+    if hasattr(layernorm_module, "bias"):
+        layernorm_module.bias = torch.nn.Parameter(
+            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
+        )
     # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
     for module in modules:
         delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
 
 
 def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
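For reference, here is a minimal sketch (not part of the commit) of the equivalence that pattern_fuse_prequant relies on for the ("up_proj", "down_proj") entry: folding down_proj's pre_quant_scale into the output dimension of up_proj leaves the result unchanged. The shapes and tensor values below are illustrative only.

import torch

# Illustrative shapes only; real models use hidden_size / intermediate_size.
hidden, inter = 8, 16
x = torch.randn(4, hidden)
up_proj = torch.nn.Linear(hidden, inter, bias=False)
down_proj = torch.nn.Linear(inter, hidden, bias=False)
# Stand-in for down_proj.input_quantizer._pre_quant_scale
pre_quant_scale = torch.rand(inter) + 0.5

# Before fusion: the scale is applied to down_proj's input.
reference = down_proj(up_proj(x) * pre_quant_scale)

# After fusion: the scale is folded into up_proj's output dimension (rows of its
# weight), mirroring the view(-1, 1) branch in pattern_fuse_prequant above.
up_proj.weight = torch.nn.Parameter(up_proj.weight * pre_quant_scale.view(-1, 1))
fused = down_proj(up_proj(x))

assert torch.allclose(reference, fused, atol=1e-6)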

modelopt/torch/export/unified_export_hf.py

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,7 @@
     get_weight_scaling_factor,
     get_weight_scaling_factor_2,
     maybe_transpose_expert_weight_dimensions,
+    pattern_fuse_prequant,
     postprocess_state_dict,
     preprocess_linear_fusion,
     to_quantized_weight,
@@ -173,6 +174,8 @@ def _output_hook(module, input, output):
             # Pre quant scale of modules is already updated to avg_pre_quant_scale
             fuse_prequant_layernorm(output_to_layernorm[tensor], modules)
 
+    pattern_fuse_prequant(model)
+
     # The dummy forward may not be able to activate all the experts.
     # Process experts by naming rules like experts.0, experts.1, etc.
     for name, modules_fused in fused_linears.items():
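Not part of the diff: after pattern_fuse_prequant(model) runs in the export path, the modules whose pre_quant_scale was folded away can be spotted through the fused_with_prequant flag that both fusion helpers set. A small sketch, assuming model is the quantized HF model being exported:

# Hypothetical post-fusion check; lists modules whose pre_quant_scale was fused away.
for name, submodule in model.named_modules():
    if getattr(submodule, "fused_with_prequant", False):
        print(f"{name}: pre_quant_scale folded into the preceding module")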
