
Commit e8a009b ("fix doc")
Signed-off-by: weimingc <[email protected]>
Parent: f298096

File tree: 1 file changed (+11, −5)


modelopt/torch/export/quant_utils.py

Lines changed: 11 additions & 5 deletions
```diff
@@ -943,13 +943,19 @@ def pattern_fuse_prequant(model: torch.nn.Module):
     """Fuse pre_quant_scale to the linear weights.
 
     For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that
-    The results are mathematically equivalent to the following:
+    the results are mathematically equivalent to the following::
 
-    out_proj.input = (attn_weights @ v_proj.output)
-    out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
-                    = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight
+        out_proj.input = (attn_weights @ v_proj.output)
+        out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
+                        = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight
 
-    Note: This is an experimental feature, and it might mess up the quantization errors of fused linear modules.
+    For GQA/MQA models where v_proj output dimension < o_proj input dimension,
+    the pre_quant_scale is averaged across the repeated head groups and then the
+    o_proj's pre_quant_scale is updated to maintain mathematical equivalence.
+
+    Note:
+        This is an experimental feature, and it might mess up the quantization errors
+        of fused linear modules.
     """
     for _, module in model.named_modules():
         for module_map in PQS_FUSE_MODULE_MAPPING:
```
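
The equivalence the docstring states rests on a matmul commuting with a per-channel scale applied along its last dimension. A minimal numerical sketch of that identity, using illustrative stand-in tensors (`attn_weights`, `v_output`, `o_weight`, `pre_quant_scale` here are assumptions, not identifiers from modelopt):

```python
import torch

torch.manual_seed(0)
seq, hidden = 4, 8
attn_weights = torch.softmax(torch.randn(seq, seq), dim=-1)
v_output = torch.randn(seq, hidden)      # stand-in for v_proj.output
o_weight = torch.randn(hidden, hidden)   # stand-in for out_proj.weight
pre_quant_scale = torch.rand(hidden)     # per-channel input scale of out_proj

# Baseline: apply pre_quant_scale to out_proj's input.
out_ref = ((attn_weights @ v_output) * pre_quant_scale) @ o_weight
# Fused: fold pre_quant_scale into v_proj's output channels instead.
out_fused = (attn_weights @ (v_output * pre_quant_scale)) @ o_weight

torch.testing.assert_close(out_ref, out_fused)  # equal up to float rounding
```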
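
For the GQA/MQA path, the new docstring text says the scale is averaged across the repeated head groups, with o_proj's pre_quant_scale updated to keep the math exact. A hedged sketch of that bookkeeping under an assumed head layout (consecutive query heads share a kv head); shapes and names are illustrative, not the actual quant_utils.py code:

```python
import torch

num_heads, num_kv_heads, head_dim = 8, 2, 4  # assumed GQA geometry
group = num_heads // num_kv_heads            # query heads per kv head

pre_quant_scale = torch.rand(num_heads * head_dim)  # o_proj's input scale

# Average over each group of repeated heads; this part can be fused into
# v_proj, whose output has only num_kv_heads * head_dim channels.
grouped = pre_quant_scale.view(num_kv_heads, group, head_dim)
fused_scale = grouped.mean(dim=1)            # (num_kv_heads, head_dim)

# Residual scale left on o_proj so that fused * residual == original scale.
residual = (grouped / fused_scale.unsqueeze(1)).reshape(-1)

restored = residual.view(num_kv_heads, group, head_dim) * fused_scale.unsqueeze(1)
torch.testing.assert_close(restored.reshape(-1), pre_quant_scale)
```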
