Commit 906ce5c

fix based on comments
Signed-off-by: Shiyang Chen <[email protected]>
1 parent: ec18006

File tree


modelopt/torch/export/quant_utils.py

Lines changed: 7 additions & 5 deletions
@@ -271,15 +271,17 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
-            # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+            # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
             # This is because the kernel dequantizes weight to fp8, which is in range 448.
-            wsf2 = weight_quantizer._amax.float() / 448.0
+            weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
         else:
-            wsf2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+            weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
+                weight_quantizer
+            )
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
             weight_quantizer.block_sizes[-1],
-            wsf2.to(weight.device),
+            weight_scaling_factor_2.to(weight.device),
         )[0]

     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
@@ -302,7 +304,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
-        # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+        # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
         # This is because the kernel dequantizes weight to fp8, which is in range 448.
         return weight_quantizer._amax.float() / 448.0
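For context on the amax/448 comment, here is a minimal numeric sketch of the two-level NVFP4 scaling, assuming an example per-tensor amax of 2.0. The names amax, block_amax, FP8_MAX, and FP4_MAX are hypothetical; the real NVFP4QTensor.get_weights_scaling_factor additionally reshapes the weight into blocks and quantizes the per-block scales, which this sketch omits.

import torch

# Hypothetical illustration, not code from this commit.
# FP8 (E4M3) max representable value is 448; FP4 (E2M1) max is 6.
FP8_MAX, FP4_MAX = 448.0, 6.0

amax = torch.tensor(2.0)  # example per-tensor weight amax

# W4A8 path from the diff above: the global scale is amax / 448 so that
# weights dequantized to fp8 stay within the fp8 range.
weight_scaling_factor_2 = amax / FP8_MAX

# The per-block scale is then computed relative to weight_scaling_factor_2.
# For a block whose local amax equals the tensor amax, it comes out to
# (amax / 6) / (amax / 448) = 448 / 6, matching "the wsf is in range 448/6".
block_amax = amax
per_block_scale = (block_amax / FP4_MAX) / weight_scaling_factor_2
print(per_block_scale)  # tensor(74.6667), i.e. 448 / 6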
