Commit 1853f43

updated condition in get_quant_config
Signed-off-by: Suguna Velury <[email protected]>
1 parent: 231c147


modelopt/torch/export/quant_utils.py

Lines changed: 13 additions & 11 deletions
@@ -1079,28 +1079,30 @@ def get_quant_config(named_modules: nn.Module | dict[str, nn.Module]) -> dict[st
         # Try to get block size from each weight attribute (e.g., gate_up_proj, down_proj)
         block_size = 0
         weight_names = list(weight_attr_names(module))
+        weight_quantizer_enabled = False
 
         for weight_name in weight_names:
             weight_block_size = get_weight_block_size(module, weight_name)
             if weight_block_size > 0:
                 block_size = weight_block_size
+                weight_quantizer_enabled = True
                 break
 
         # Fallback to default weight quantizer if no specific weight quantizer found
         if block_size == 0:
             block_size = get_weight_block_size(module)
+            weight_quantizer = getattr(
+                module, quantizer_attr_names("weight").weight_quantizer, None
+            )
+            # Check if weight_quantizer is enabled
+            weight_quantizer_enabled = block_size > 0 or (
+                weight_quantizer is not None and weight_quantizer.is_enabled
+            )
 
-        # In the case of NVFP4, block_size 0 indicates weight_quantizer is not enabled
-        if block_size == 0 and quantization_format in [
-            QUANTIZATION_NVFP4,
-            QUANTIZATION_NVFP4_AWQ,
-            QUANTIZATION_W4A8_NVFP4_FP8,
-        ]:
-            continue
-
-        # Construct per layer config dictionary
-        layer_config_dict[name + ".quantization"] = quantization_format
-        layer_config_dict[name + ".awq_block_size"] = block_size
+        if weight_quantizer_enabled:
+            # Construct per layer config dictionary
+            layer_config_dict[name + ".quantization"] = quantization_format
+            layer_config_dict[name + ".awq_block_size"] = block_size
 
         # Find kv cache quant format
         if (
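For context, a minimal standalone sketch of the gating behavior this commit introduces. FakeQuantizer, layer_entries, the layer names, and the "nvfp4"/"fp8" format strings are hypothetical stand-ins invented for illustration; only the weight_quantizer_enabled condition mirrors the committed code.

class FakeQuantizer:
    """Hypothetical stand-in for a modelopt weight quantizer; only is_enabled matters here."""

    def __init__(self, enabled: bool):
        self.is_enabled = enabled


def layer_entries(
    name: str,
    quantization_format: str,
    block_size: int,
    weight_quantizer: FakeQuantizer | None,
) -> dict[str, str | int]:
    """Mirror the committed condition: emit per-layer config entries only when
    the weight quantizer is effectively enabled."""
    weight_quantizer_enabled = block_size > 0 or (
        weight_quantizer is not None and weight_quantizer.is_enabled
    )
    if not weight_quantizer_enabled:
        # Skipped for every quantization format, not just the NVFP4 family
        # that the old block_size == 0 special case covered.
        return {}
    return {
        name + ".quantization": quantization_format,
        name + ".awq_block_size": block_size,
    }


# Block-quantized layer: entries emitted as before.
print(layer_entries("model.layers.0.mlp.down_proj", "nvfp4", 16, None))
# No block size and quantizer disabled: now skipped regardless of format.
print(layer_entries("lm_head", "nvfp4", 0, FakeQuantizer(False)))
# Per-tensor quantizer enabled but no block size: still emitted.
print(layer_entries("model.layers.0.mlp.gate_proj", "fp8", 0, FakeQuantizer(True)))

The practical effect: a layer whose weight quantizer is disabled is now skipped for every quantization format, where the old code special-cased only the NVFP4 family via its block_size == 0 check.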
