@@ -938,9 +938,8 @@ def all_items_same(item_list):
 ]
 
 
-# TODO: make this more general instead of rule based
-def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
-    """Fuse pre_quant_scale to the linear weights.
+def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
+    """Fuse pre_quant_scale to the linear weights if possible.
 
     For example, we can fuse the pre_quant_scale of o_proj to the output_dimension of v_proj, such that
     the results are mathematically equivalent to the following::
@@ -955,26 +954,13 @@ def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
 
     Args:
         model: The model to fuse pre_quant_scale to.
-        fuse_mismatch_dim: If True, fuse the pre_quant_scale even if dimension between pre_quant_scale
+        fuse_grouped_heads: If True, fuse the pre_quant_scale even if the dimension between pre_quant_scale
             and linear weights is not the same. This is useful for GQA/MQA models but may lead to accuracy
             drop.
 
     Note:
-        This is an experimental feature, and it might mess up the quantization errors
-        of fused linear modules.
+        ``fuse_grouped_heads`` is useful for GQA/MQA models but may lead to an accuracy drop.
     """
-    # For MoE models, let's first resmooth the w1 and w3 in experts to get the average pre_quant_scale
-    for _, module in model.named_modules():
-        if (
-            hasattr(module, "experts")
-            and "Qwen3MoeSparseMoeBlock".lower() in type(module).__name__.lower()
-        ):
-            linear_list = []
-            linear_list.extend([getattr(expert, "up_proj") for expert in module.experts])
-            linear_list.extend([getattr(expert, "gate_proj") for expert in module.experts])
-            preprocess_linear_fusion(linear_list, resmooth_only=True)
-
-    # import pdb; pdb.set_trace()
     # Fuse pre_quant_scale to the linear weights
     for _, module in model.named_modules():
         for module_map in PQS_FUSE_MODULE_MAPPING:
@@ -988,10 +974,10 @@ def pattern_fuse_prequant(model: torch.nn.Module, fuse_mismatch_dim=False):
             ):
                 pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
 
-                # for GQA/MQA models, we apply averaging to the pre_quant_scale for shared head groups
+                # for GQA/MQA models, we can apply averaging to the pre_quant_scale for shared head groups
                 if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
                     if (
-                        not fuse_mismatch_dim
+                        not fuse_grouped_heads
                         or "attention" not in type(module).__name__.lower()
                     ):
                         warn(
@@ -1041,7 +1027,7 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
                     # Use averaged scale (flattened) for v_proj fusion
                     pre_quant_scale = averaged_scale.reshape(-1)
 
-                # Fuse the pre_quant_scale to v_proj weight
+                # Fuse the pre_quant_scale into the weight
                 linear_fuse_into.weight = torch.nn.Parameter(
                     linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
                 )
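
For reference, here is a minimal sketch of why folding o_proj's pre_quant_scale into the output channels of v_proj's weight is mathematically equivalent (module names and dimensions are illustrative, not taken from the diff; the equivalence extends through attention because the softmax weighting mixes tokens per head, not channels, when the head counts match):

```python
import torch

torch.manual_seed(0)
hidden = 8
x = torch.randn(4, hidden)

# Stand-ins for v_proj / o_proj and o_proj's per-channel pre_quant_scale.
v_proj = torch.nn.Linear(hidden, hidden, bias=False)
o_proj = torch.nn.Linear(hidden, hidden, bias=False)
pre_quant_scale = torch.rand(hidden) + 0.5

# Reference path: o_proj's input is multiplied element-wise by its pre_quant_scale.
ref = o_proj(pre_quant_scale * v_proj(x))

# Fused path: fold the scale into v_proj's output channels by scaling the rows of
# its weight, mirroring `linear_fuse_into.weight * pre_quant_scale.view(-1, 1)`.
fused_v = torch.nn.Linear(hidden, hidden, bias=False)
fused_v.weight = torch.nn.Parameter(v_proj.weight * pre_quant_scale.view(-1, 1))
out = o_proj(fused_v(x))

assert torch.allclose(ref, out, atol=1e-6)
```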
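When `fuse_grouped_heads` is set and the scale length does not match v_proj's output dimension (GQA/MQA), the scale is averaged over the query heads that share a KV head before fusing. A hedged sketch of one plausible layout for that averaging (head counts and the exact reshape are assumptions, since that code is elided from the hunks shown):

```python
import torch

# Illustrative GQA shapes (assumed): 8 query heads share 2 KV heads.
num_q_heads, num_kv_heads, head_dim = 8, 2, 16
n_rep = num_q_heads // num_kv_heads

# o_proj's pre_quant_scale has one entry per o_proj input channel.
o_scale = torch.rand(num_q_heads * head_dim) + 0.5

# v_proj only has num_kv_heads * head_dim output channels, so an exact fusion is
# impossible; average the scale over each group of query heads sharing a KV head.
averaged_scale = o_scale.view(num_kv_heads, n_rep, head_dim).mean(dim=1)

# Flatten for the weight fusion, as in `pre_quant_scale = averaged_scale.reshape(-1)`.
fuse_scale = averaged_scale.reshape(-1)
assert fuse_scale.numel() == num_kv_heads * head_dim
```

Because each per-query-head scale is replaced by its group average, the fusion is only approximate in this case, which is why the docstring warns about a possible accuracy drop.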