
Commit decb05e (1 parent: c511477)

[OMNIML-2336] make w4a8_nvfp4_fp8's scale factor in range of 448/6

Signed-off-by: Shiyang Chen <[email protected]>

fix based on comments

Signed-off-by: Shiyang Chen <[email protected]>

1 file changed (+13, -4 lines)

modelopt/torch/export/quant_utils.py

@@ -270,12 +270,18 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+            # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+            # This is because the kernel dequantizes weight to fp8, which is in range 448.
+            weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
+        else:
+            weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
+                weight_quantizer
+            )
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
             weight_quantizer.block_sizes[-1],
-            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                weight.device
-            ),
+            weight_scaling_factor_2.to(weight.device),
         )[0]

     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
@@ -295,9 +301,12 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
-        QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+        # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+        # This is because the kernel dequantizes weight to fp8, which is in range 448.
+        return weight_quantizer._amax.float() / 448.0

     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
