1 file changed: +11 -5 lines changed

@@ -943,13 +943,19 @@ def pattern_fuse_prequant(model: torch.nn.Module):
943943 """Fuse pre_quant_scale to the linear weights.
944944
945945 For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that
946- The results are mathematically equivalent to the following:
946+ the results are mathematically equivalent to the following: :
947947
948- out_proj.input = (attn_weights @ v_proj.output)
949- out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
950- = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight
948+ out_proj.input = (attn_weights @ v_proj.output)
949+ out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
950+ = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight
951951
952- Note: This is an experimental feature, and it might mess up the quantization errors of fused linear modules.
952+ For GQA/MQA models where v_proj output dimension < o_proj input dimension,
953+ the pre_quant_scale is averaged across the repeated head groups and then the
954+ o_proj's pre_quant_scale is updated to maintain mathematical equivalence.
955+
956+ Note:
957+ This is an experimental feature, and it might mess up the quantization errors
958+ of fused linear modules.
953959 """
954960 for _ , module in model .named_modules ():
955961 for module_map in PQS_FUSE_MODULE_MAPPING :