@@ -478,8 +478,6 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames

     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_prequant", False):
-        return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
     )
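A condensed sketch of the detection path that remains after this hunk. The helper name `detect_nvfp4_awq` and the returned string are illustrative stand-ins, not the file's real `_get_quantization_from_layer` or its `QUANTIZATION_NVFP4_AWQ` constant:

```python
# Sketch only: `detect_nvfp4_awq` is a hypothetical name; the string stands
# in for the QUANTIZATION_NVFP4_AWQ constant used in the real file.
def detect_nvfp4_awq(input_quantizer):
    # NVFP4-AWQ is now recognized solely by the presence of a pre-quant
    # scale on the input quantizer; the fused_with_prequant fallback on
    # the layer itself is gone.
    if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
        return "NVFP4_AWQ"
    return None
```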
@@ -962,21 +960,21 @@ def pattern_fuse_prequant(model: torch.nn.Module):
             target_module_list = module_map[0]
             linear_pair = module_map[1]
             if any(module_name in type(module).__name__ for module_name in target_module_list):
-                linear_to = module.get_submodule(linear_pair[0])
-                linear_from = module.get_submodule(linear_pair[1])
-                if hasattr(linear_from, "input_quantizer") and hasattr(
-                    linear_from.input_quantizer, "_pre_quant_scale"
+                linear_fuse_into = module.get_submodule(linear_pair[0])
+                linear_pqs_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
+                    linear_pqs_from.input_quantizer, "_pre_quant_scale"
                 ):
-                    pre_quant_scale = linear_from.input_quantizer._pre_quant_scale
+                    pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale

                     # for GQA/MQA models, we apply averaging to the pre_quant_scale
-                    if pre_quant_scale.numel() != linear_to.weight.shape[0]:
+                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[0]:
                         if "attention" not in type(module).__name__.lower():
                             continue
                         else:
                             config = module.config
                             num_kv_heads = config.num_key_value_heads
-                            kv_head_dim = linear_to.weight.shape[0] // num_kv_heads
+                            kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
                             n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim

                             # Reshape: (num_kv_heads, n_rep, kv_head_dim)
@@ -986,9 +984,9 @@ def pattern_fuse_prequant(model: torch.nn.Module):

                             # To update o_proj, we need to repeat back to original shape
                             repeated_scale = (
-                                averaged_scale.unsqueeze(1)  # (2, 1, 16)
-                                .expand(num_kv_heads, n_rep, kv_head_dim)  # (2, 2, 16)
-                                .reshape(-1)  # (64,)
+                                averaged_scale.unsqueeze(1)
+                                .expand(num_kv_heads, n_rep, kv_head_dim)
+                                .reshape(-1)
                             )

                             def _update_pre_quant_scale(module, new_pre_quant_scale):
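To make the GQA/MQA handling in the two hunks above concrete, here is a self-contained round-trip of the average-then-repeat logic, using the toy sizes from the inline comments removed above (`num_kv_heads=2`, `n_rep=2`, `kv_head_dim=16`). The mean over the repeat axis is an assumption about how `averaged_scale` is computed, since that line falls outside these hunks:

```python
import torch

# Toy sizes matching the (2, 1, 16) / (2, 2, 16) / (64,) shapes from the
# removed comments.
num_kv_heads, n_rep, kv_head_dim = 2, 2, 16
pre_quant_scale = torch.rand(num_kv_heads * n_rep * kv_head_dim)  # o_proj scale, (64,)

# Assumed: average across the n_rep query groups that share each KV head.
averaged_scale = pre_quant_scale.reshape(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)
assert averaged_scale.shape == (num_kv_heads, kv_head_dim)  # (2, 16)

# Repeat back to the original layout so o_proj's own scale stays
# element-for-element consistent with the values fused into v_proj.
repeated_scale = (
    averaged_scale.unsqueeze(1)                # (2, 1, 16)
    .expand(num_kv_heads, n_rep, kv_head_dim)  # (2, 2, 16)
    .reshape(-1)                               # (64,)
)
assert repeated_scale.shape == pre_quant_scale.shape
```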
@@ -1011,22 +1009,22 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
                                 finish_stats_collection(module.weight_quantizer)

                             # Update o_proj's pre_quant_scale
-                            _update_pre_quant_scale(linear_from, repeated_scale)
+                            _update_pre_quant_scale(linear_pqs_from, repeated_scale)

                             # Use averaged scale (flattened) for v_proj fusion
                             pre_quant_scale = averaged_scale.reshape(-1)

-                    # Fuse the pre_quant_scale to v_proj weight (linear_to)
-                    # v_proj.weight shape: (out_features, in_features) = (32, hidden_size)
-                    # We scale the output dimension (first dimension)
-                    linear_to.weight = torch.nn.Parameter(
-                        linear_to.weight * pre_quant_scale.view(-1, 1)
+                    # Fuse the pre_quant_scale to v_proj weight
+                    linear_fuse_into.weight = torch.nn.Parameter(
+                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
                     )
-                    if hasattr(linear_to, "bias") and linear_to.bias is not None:
-                        linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale)
+                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
+                        linear_fuse_into.bias = torch.nn.Parameter(
+                            linear_fuse_into.bias * pre_quant_scale
+                        )

-                    delattr(linear_from.input_quantizer, "_pre_quant_scale")
-                    setattr(linear_from, "fused_with_prequant", True)
+                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_pqs_from, "fused_with_prequant", True)


 def fuse_prequant_layernorm(
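A small numeric check of the identity the v_proj fusion relies on: scaling v_proj's output channels (rows of its weight, plus its bias) is equivalent to applying o_proj's pre-quant scale at o_proj's input, because the attention combination `attn_weights @ v` is linear in `v` per channel. This is an illustrative sketch, not library code:

```python
import torch

torch.manual_seed(0)
seq, hidden = 4, 8
s = torch.rand(hidden) + 0.5  # stand-in for o_proj's _pre_quant_scale

# Row-scaling a Linear's weight (and bias) scales its output channels,
# which is exactly what the weight/bias Parameter updates above do.
v_proj = torch.nn.Linear(hidden, hidden)
fused = torch.nn.Linear(hidden, hidden)
with torch.no_grad():
    fused.weight.copy_(v_proj.weight * s.view(-1, 1))
    fused.bias.copy_(v_proj.bias * s)

h = torch.randn(seq, hidden)
assert torch.allclose(v_proj(h) * s, fused(h), atol=1e-5)

# The scale also commutes with the attention mixing, so pre-scaling v's
# channels is the same as scaling o_proj's input afterwards.
attn = torch.softmax(torch.randn(seq, seq), dim=-1)
v = v_proj(h)
assert torch.allclose((attn @ v) * s, attn @ (v * s), atol=1e-5)
```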