@@ -923,10 +923,19 @@ def all_items_same(item_list):
    return all(x == item_list[0] for x in item_list)


+# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
-    # format: (list of target modules, tuple of (linear_pqs_fuse_to, linear_pqs_from), dim to fuse)
-    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj"), "input"),
-    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj"), "output"),
+    # Attention: Fuse o_proj's pre_quant_scale into v_proj's output dimension
+    # Mathematical equivalence:
+    #   Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
+    #   After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
+    # Note: for GQA/MQA models, the pre_quant_scale is first averaged across the query
+    # groups that share each KV head (see the GQA handling in pattern_fuse_prequant below)
933+ (["LlamaAttention" , "Qwen3Attention" , "Qwen3MoeAttention" ], ("v_proj" , "o_proj" )),
934+ # MLP: Fuse down_proj's pre_quant_scale into up_proj's output dimension
935+ # Mathematical equivalence:
936+ # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
937+ # After: down_proj_out = {[act_fn(self.gate_proj(x)) * (up_proj(x) * scale)]} @ down_proj.W^T
938+ (["LlamaMLP" , "Qwen3MLP" , "Qwen3MoeMLP" ], ("up_proj" , "down_proj" )),
930939]
931940
932941
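The equivalences stated in the new comments are easy to sanity-check outside the diff. The snippet below is a standalone sketch, not part of the commit: toy shapes, random weights, a single attention head without GQA, and SiLU standing in for act_fn. It checks that folding the per-channel scale into v_proj's / up_proj's output rows matches applying it to o_proj's / down_proj's input.

# Standalone sanity check (toy shapes, random weights) -- not part of the commit.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq, hidden, inter = 4, 8, 16
x = torch.randn(seq, hidden)
scale_attn = torch.rand(hidden) + 0.5   # pre_quant_scale on o_proj's input
scale_mlp = torch.rand(inter) + 0.5     # pre_quant_scale on down_proj's input

# Attention path (single head, no GQA): attn @ (V * s) == (attn @ V) * s
attn = torch.softmax(torch.randn(seq, seq), dim=-1)
w_v = torch.randn(hidden, hidden)       # v_proj.weight: (out_features, in_features)
before = (attn @ (x @ w_v.T)) * scale_attn                   # scale applied to o_proj's input
after = attn @ (x @ (w_v * scale_attn.view(-1, 1)).T)        # scale fused into v_proj's rows
assert torch.allclose(before, after, atol=1e-5, rtol=1e-4)

# MLP path: [act_fn(gate(x)) * up(x)] * s == act_fn(gate(x)) * [up(x) * s]
w_gate = torch.randn(inter, hidden)
w_up = torch.randn(inter, hidden)
before = (F.silu(x @ w_gate.T) * (x @ w_up.T)) * scale_mlp                   # scale on down_proj's input
after = F.silu(x @ w_gate.T) * (x @ (w_up * scale_mlp.view(-1, 1)).T)        # scale fused into up_proj's rows
assert torch.allclose(before, after, atol=1e-5, rtol=1e-4)
print("fusion equivalences hold on toy tensors")

The attention-path identity only holds when every query group sharing a KV head sees the same scale, which is why the GQA branch in the second hunk averages the scale first.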
@@ -947,23 +956,70 @@ def pattern_fuse_prequant(model: torch.nn.Module):
        for module_map in PQS_FUSE_MODULE_MAPPING:
            target_module_list = module_map[0]
            linear_pair = module_map[1]
-            dim_to_fuse = module_map[2]
            if any(module_name in type(module).__name__ for module_name in target_module_list):
                linear_to = module.get_submodule(linear_pair[0])
                linear_from = module.get_submodule(linear_pair[1])
                if hasattr(linear_from, "input_quantizer") and hasattr(
                    linear_from.input_quantizer, "_pre_quant_scale"
                ):
                    pre_quant_scale = linear_from.input_quantizer._pre_quant_scale
-                    # check if we need to apply to the last dimension or the first dimension
-                    pre_quant_scale = (
-                        pre_quant_scale.view(-1, 1)
-                        if dim_to_fuse == "output"
-                        else pre_quant_scale.view(1, -1)
+
+                    # For GQA/MQA models, average the pre_quant_scale across the query groups
+                    # that share each KV head so it matches v_proj's output dimension
+                    if pre_quant_scale.numel() != linear_to.weight.shape[0]:
+                        if "attention" not in type(module).__name__.lower():
+                            continue
+                        else:
+                            config = module.config
+                            num_kv_heads = config.num_key_value_heads
+                            kv_head_dim = linear_to.weight.shape[0] // num_kv_heads
+                            n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                            # Reshape to (num_kv_heads, n_rep, kv_head_dim) and average over
+                            # the n_rep query groups
+                            averaged_scale = pre_quant_scale.view(
+                                num_kv_heads, n_rep, kv_head_dim
+                            ).mean(dim=1)
+
+                            # To update o_proj, repeat the averaged scale back to the original shape
+                            repeated_scale = (
+                                averaged_scale.unsqueeze(1)  # (num_kv_heads, 1, kv_head_dim)
+                                .expand(num_kv_heads, n_rep, kv_head_dim)  # (num_kv_heads, n_rep, kv_head_dim)
+                                .reshape(-1)  # (num_kv_heads * n_rep * kv_head_dim,)
+                            )
+
+                            def _update_pre_quant_scale(module, new_pre_quant_scale):
+                                old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+                                # Fold old_scale / new_scale into the weight so the module's
+                                # output is unchanged by the scale swap
+                                module.weight = nn.Parameter(
+                                    module.weight
+                                    * old_pre_quant_scale.to(
+                                        dtype=module.weight.dtype, device=module.weight.device
+                                    )
+                                    / new_pre_quant_scale.to(
+                                        dtype=module.weight.dtype, device=module.weight.device
+                                    )
+                                )
+                                module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+
+                                # Redo weight calibration (amax collection) for the updated weight
+                                module.weight_quantizer.reset_amax()
+                                enable_stats_collection(module.weight_quantizer)
+                                module.weight_quantizer(module.weight)
+                                finish_stats_collection(module.weight_quantizer)
+
+                            # Update o_proj's pre_quant_scale
+                            _update_pre_quant_scale(linear_from, repeated_scale)
+
+                            # Use averaged scale (flattened) for v_proj fusion
+                            pre_quant_scale = averaged_scale.reshape(-1)
+
+                    # Fuse the pre_quant_scale into linear_to's weight (v_proj / up_proj);
+                    # weight shape is (out_features, in_features), so scale the output (first) dimension
+                    linear_to.weight = torch.nn.Parameter(
+                        linear_to.weight * pre_quant_scale.view(-1, 1)
                    )
-                    linear_to.weight = torch.nn.Parameter(linear_to.weight * pre_quant_scale)
                    if hasattr(linear_to, "bias") and linear_to.bias is not None:
                        linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale)
+
                    delattr(linear_from.input_quantizer, "_pre_quant_scale")
                    setattr(linear_from, "fused_with_prequant", True)

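To make the GQA branch concrete, here is another standalone sketch, not part of the commit. Variable names mirror the diff, but the shapes and the plain-tensor stand-in for o_proj are made up. It shows that averaging the scale over query groups, re-expanding it, and folding old_scale / new_scale into o_proj's weight (what _update_pre_quant_scale does) leaves o_proj's output unchanged, while the averaged scale now has one entry per v_proj output channel.

# Standalone sketch of the GQA branch -- not part of the commit; shapes are made up.
import torch

torch.manual_seed(0)
num_kv_heads, n_rep, kv_head_dim = 2, 2, 16          # v_proj out = 32, o_proj in = 64
o_in = num_kv_heads * n_rep * kv_head_dim
old_scale = torch.rand(o_in) + 0.5                   # o_proj's original per-input-channel pre_quant_scale

# Average over the n_rep query groups sharing each KV head, then repeat back for o_proj
averaged_scale = old_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)             # (num_kv_heads, kv_head_dim)
repeated_scale = (
    averaged_scale.unsqueeze(1).expand(num_kv_heads, n_rep, kv_head_dim).reshape(-1)      # (o_in,)
)

# Folding old_scale / repeated_scale into o_proj's weight keeps its output unchanged
o_w = torch.randn(48, o_in)                          # stand-in for o_proj.weight: (out_features, in_features)
attn_out = torch.randn(3, o_in)                      # stand-in for o_proj's input
out_before = (attn_out * old_scale) @ o_w.T
out_after = (attn_out * repeated_scale) @ (o_w * (old_scale / repeated_scale)).T
assert torch.allclose(out_before, out_after, atol=1e-4)

# The averaged scale matches v_proj's output dimension, so it can be fused there
assert averaged_scale.reshape(-1).numel() == num_kv_heads * kv_head_dim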