
Commit 6dd1b87

committed: minor
Signed-off-by: weimingc <[email protected]>
1 parent c5d9682 commit 6dd1b87

File tree: 1 file changed, +20 −22 lines

modelopt/torch/export/quant_utils.py

Lines changed: 20 additions & 22 deletions
```diff
@@ -478,8 +478,6 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
 
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_prequant", False):
-        return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
    )
```
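With the removed branch, detection now keys solely on a live `_pre_quant_scale` attribute on the input quantizer; a `fused_with_prequant` flag on the layer is no longer consulted. A minimal sketch of the resulting behavior, using toy stand-ins rather than the modelopt API:

```python
import torch

QUANTIZATION_NVFP4_AWQ = "nvfp4_awq"  # toy constants for illustration
QUANTIZATION_NVFP4 = "nvfp4"

class ToyQuantizer:  # stand-in for a TensorQuantizer
    pass

def detect(layer, input_quantizer):
    # Mirrors the updated check: only the presence of _pre_quant_scale
    # marks the layer as AWQ; the fused_with_prequant flag is ignored.
    if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
        return QUANTIZATION_NVFP4_AWQ
    return QUANTIZATION_NVFP4

q = ToyQuantizer()
q._pre_quant_scale = torch.ones(8)
print(detect(object(), q))  # nvfp4_awq
del q._pre_quant_scale      # what pattern_fuse_prequant does after fusing
print(detect(object(), q))  # nvfp4
```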
```diff
@@ -962,21 +960,21 @@ def pattern_fuse_prequant(model: torch.nn.Module):
             target_module_list = module_map[0]
             linear_pair = module_map[1]
             if any(module_name in type(module).__name__ for module_name in target_module_list):
-                linear_to = module.get_submodule(linear_pair[0])
-                linear_from = module.get_submodule(linear_pair[1])
-                if hasattr(linear_from, "input_quantizer") and hasattr(
-                    linear_from.input_quantizer, "_pre_quant_scale"
+                linear_fuse_into = module.get_submodule(linear_pair[0])
+                linear_pqs_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
+                    linear_pqs_from.input_quantizer, "_pre_quant_scale"
                 ):
-                    pre_quant_scale = linear_from.input_quantizer._pre_quant_scale
+                    pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
 
                     # for GQA/MQA models, we apply averaging to the pre_quant_scale
-                    if pre_quant_scale.numel() != linear_to.weight.shape[0]:
+                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[0]:
                         if "attention" not in type(module).__name__.lower():
                             continue
                         else:
                             config = module.config
                             num_kv_heads = config.num_key_value_heads
-                            kv_head_dim = linear_to.weight.shape[0] // num_kv_heads
+                            kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
                             n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
 
                             # Reshape:(num_kv_heads, n_rep, kv_head_dim)
```
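For GQA/MQA, `o_proj`'s `_pre_quant_scale` has one value per query-head channel while `v_proj` has only `num_kv_heads * kv_head_dim` output channels, so the scale is averaged across the `n_rep` query heads that share each KV head. A small numeric sketch with assumed toy shapes (2 KV heads, head dim 16, `n_rep = 2`); the actual reduction producing `averaged_scale` lives in lines elided between this hunk and the next:

```python
import torch

num_kv_heads, n_rep, kv_head_dim = 2, 2, 16  # assumed toy GQA shapes
pre_quant_scale = torch.rand(num_kv_heads * n_rep * kv_head_dim)  # o_proj scale, 64 values

# Reshape to (num_kv_heads, n_rep, kv_head_dim) and average over the
# n_rep query heads sharing each KV head (presumed reduction).
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)
print(averaged_scale.shape)  # torch.Size([2, 16]) -> matches v_proj's 32 output channels
```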
```diff
@@ -986,9 +984,9 @@ def pattern_fuse_prequant(model: torch.nn.Module):
 
                             # To update o_proj, we need to repeat back to original shape
                             repeated_scale = (
-                                averaged_scale.unsqueeze(1)  # (2, 1, 16)
-                                .expand(num_kv_heads, n_rep, kv_head_dim)  # (2, 2, 16)
-                                .reshape(-1)  # (64,)
+                                averaged_scale.unsqueeze(1)
+                                .expand(num_kv_heads, n_rep, kv_head_dim)
+                                .reshape(-1)
                             )
 
                             def _update_pre_quant_scale(module, new_pre_quant_scale):
```
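To push the averaged scale back into `o_proj`'s quantizer, it is broadcast back to the original per-query-head layout. A self-contained version using the same toy shapes as above (the `(2, 1, 16)`-style comments deleted in this hunk described exactly these shapes):

```python
import torch

num_kv_heads, n_rep, kv_head_dim = 2, 2, 16  # same toy shapes as above
averaged_scale = torch.rand(num_kv_heads, kv_head_dim)

# Repeat each KV head's averaged scale n_rep times so the result lines
# up with o_proj's original num_query_heads * head_dim input layout.
repeated_scale = (
    averaged_scale.unsqueeze(1)                # (2, 1, 16)
    .expand(num_kv_heads, n_rep, kv_head_dim)  # (2, 2, 16)
    .reshape(-1)                               # (64,)
)
print(repeated_scale.shape)  # torch.Size([64])
```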
```diff
@@ -1011,22 +1009,22 @@ def _update_pre_quant_scale(module, new_pre_quant_scale):
                                 finish_stats_collection(module.weight_quantizer)
 
                             # Update o_proj's pre_quant_scale
-                            _update_pre_quant_scale(linear_from, repeated_scale)
+                            _update_pre_quant_scale(linear_pqs_from, repeated_scale)
 
                             # Use averaged scale (flattened) for v_proj fusion
                             pre_quant_scale = averaged_scale.reshape(-1)
 
-                    # Fuse the pre_quant_scale to v_proj weight (linear_to)
-                    # v_proj.weight shape: (out_features, in_features) = (32, hidden_size)
-                    # We scale the output dimension (first dimension)
-                    linear_to.weight = torch.nn.Parameter(
-                        linear_to.weight * pre_quant_scale.view(-1, 1)
+                    # Fuse the pre_quant_scale to v_proj weight
+                    linear_fuse_into.weight = torch.nn.Parameter(
+                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
                     )
-                    if hasattr(linear_to, "bias") and linear_to.bias is not None:
-                        linear_to.bias = torch.nn.Parameter(linear_to.bias * pre_quant_scale)
+                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
+                        linear_fuse_into.bias = torch.nn.Parameter(
+                            linear_fuse_into.bias * pre_quant_scale
+                        )
 
-                    delattr(linear_from.input_quantizer, "_pre_quant_scale")
-                    setattr(linear_from, "fused_with_prequant", True)
+                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_pqs_from, "fused_with_prequant", True)
 
 
 def fuse_prequant_layernorm(
```
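The fusion relies on a per-channel identity: scaling each output row of `v_proj`'s weight (and its bias) by `s` makes `v_proj(x)` equal to `s * v_proj_original(x)`, which is exactly what `o_proj`'s input quantizer would have applied via `_pre_quant_scale`, so the attribute can be deleted afterwards. A self-contained check of that equivalence (toy sizes, plain `nn.Linear` rather than a quantized module):

```python
import torch

torch.manual_seed(0)
v_proj = torch.nn.Linear(48, 32)  # toy in/out features
s = torch.rand(32) + 0.5          # stand-in per-channel pre_quant_scale

x = torch.randn(4, 48)
ref = v_proj(x) * s               # scale applied at o_proj's input, pre-fusion

# Fuse: scale the output dimension (rows) of the weight, and the bias,
# as in the hunk above.
v_proj.weight = torch.nn.Parameter(v_proj.weight * s.view(-1, 1))
v_proj.bias = torch.nn.Parameter(v_proj.bias * s)

print(torch.allclose(ref, v_proj(x), atol=1e-5))  # True: scale is baked into v_proj
```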
