@@ -478,7 +478,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
 
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
         return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
@@ -923,18 +923,77 @@ def all_items_same(item_list):
     return all(x == item_list[0] for x in item_list)
 
 
+PQS_FUSE_MODULE_MAPPING = [
+    # format: (list of target modules, tuple of (linear_pqs_fuse_to, linear_pqs_from), dim to fuse)
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj"), "input"),
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj"), "output"),
+]
+
+
+# TODO: make this more general instead of rule based
+def pattern_fuse_prequant(model: torch.nn.Module):
+    """Fuse pre_quant_scale into the weights of the preceding linear layer.
+
+    For example, the pre_quant_scale of o_proj can be fused into the output dimension of v_proj,
+    which is mathematically equivalent:
+
+        o_proj.input = attn_weights @ v_proj.output
+        o_proj.output = (o_proj.input * pre_quant_scale) @ o_proj.weight.T
+                      = (attn_weights @ (v_proj.output * pre_quant_scale)) @ o_proj.weight.T
+
+    Note: this is an experimental feature; it may affect the quantization error of the fused linear modules.
+    """
+    for _, module in model.named_modules():
+        for module_map in PQS_FUSE_MODULE_MAPPING:
+            target_module_list = module_map[0]
+            linear_pair = module_map[1]
+            dim_to_fuse = module_map[2]
+            if any(module_name in type(module).__name__ for module_name in target_module_list):
+                linear_to = module.get_submodule(linear_pair[0])
+                linear_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_from, "input_quantizer") and hasattr(
+                    linear_from.input_quantizer, "_pre_quant_scale"
+                ):
+                    pre_quant_scale = linear_from.input_quantizer._pre_quant_scale
+                    # reshape so the scale broadcasts over the output or the input dim of the weight
+                    pre_quant_scale = (
+                        pre_quant_scale.view(-1, 1)
+                        if dim_to_fuse == "output"
+                        else pre_quant_scale.view(1, -1)
+                    )
964+ linear_to .weight = torch .nn .Parameter (linear_to .weight * pre_quant_scale )
965+ if hasattr (linear_to , "bias" ) and linear_to .bias is not None :
966+ linear_to .bias = torch .nn .Parameter (linear_to .bias * pre_quant_scale )
+                    delattr(linear_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_from, "fused_with_prequant", True)
+
+
 def fuse_prequant_layernorm(
     layernorm_module: torch.nn.Module,
     modules: list[torch.Tensor],
 ):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted.
+
+    original:
+        layernorm_output = (normalization(input) * weight) + bias
+        layernorm_output_scaled = layernorm_output * pre_quant_scale
+
+    fused:
+        fused_weight = weight * avg_pre_quant_scale
+        fused_bias = bias * avg_pre_quant_scale
+        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
+    """
     layernorm_module.weight = torch.nn.Parameter(
         layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
     )
+    if getattr(layernorm_module, "bias", None) is not None:
+        layernorm_module.bias = torch.nn.Parameter(
+            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
+        )
     # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
     for module in modules:
         delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
 
 
 def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
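
As a sanity check on the algebra in pattern_fuse_prequant's docstring, here is a minimal standalone sketch with plain torch tensors. All names and shapes are made up for illustration, grouped-query attention is ignored (so the scale length matches the v projection's output width), and the scale is folded into the output features (rows) of the v weight as the docstring describes; no modelopt modules are involved.

import torch

torch.manual_seed(0)
seq, hidden = 4, 8
x = torch.randn(seq, hidden)                  # attention input (toy shape)
w_v = torch.randn(hidden, hidden)             # v_proj weight, [out_features, in_features]
w_o = torch.randn(hidden, hidden)             # o_proj weight, [out_features, in_features]
attn = torch.softmax(torch.randn(seq, seq), dim=-1)
s = torch.rand(hidden) + 0.5                  # stands in for o_proj's pre_quant_scale

# Reference: scale o_proj's input explicitly.
v = x @ w_v.T
ref = ((attn @ v) * s) @ w_o.T

# Fused: fold the scale into the output features (rows) of the v weight instead.
w_v_fused = w_v * s.view(-1, 1)
fused = (attn @ (x @ w_v_fused.T)) @ w_o.T

print(torch.allclose(ref, fused, atol=1e-5))  # True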
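
A similar sketch for fuse_prequant_layernorm: scaling the norm's weight and bias by the pre_quant_scale reproduces scaling the norm's output. A stock torch.nn.LayerNorm stands in for the model's norm module, and the toy tensors below are made up for illustration.

import torch

torch.manual_seed(0)
hidden = 8
x = torch.randn(4, hidden)
s = torch.rand(hidden) + 0.5                      # stands in for avg_pre_quant_scale

ln = torch.nn.LayerNorm(hidden)
with torch.no_grad():
    ln.weight.uniform_(0.5, 1.5)
    ln.bias.uniform_(-0.5, 0.5)
ref = ln(x) * s                                   # scale applied after the norm

fused = torch.nn.LayerNorm(hidden)
with torch.no_grad():
    fused.weight.copy_(ln.weight * s)             # fused_weight = weight * avg_pre_quant_scale
    fused.bias.copy_(ln.bias * s)                 # fused_bias = bias * avg_pre_quant_scale

print(torch.allclose(ref, fused(x), atol=1e-6))   # True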