Commit ec18006

[OMNIML-2336] make w4a8_nvfp4_fp8's scale factor in range of 448/6
Signed-off-by: Shiyang Chen <[email protected]>
1 parent 1537885 commit ec18006

File tree

1 file changed: +11 -4 lines changed


modelopt/torch/export/quant_utils.py

Lines changed: 11 additions & 4 deletions
@@ -270,12 +270,16 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+            # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+            # This is because the kernel dequantizes weight to fp8, which is in range 448.
+            wsf2 = weight_quantizer._amax.float() / 448.0
+        else:
+            wsf2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
             weight_quantizer.block_sizes[-1],
-            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                weight.device
-            ),
+            wsf2.to(weight.device),
         )[0]
 
     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:

@@ -295,9 +299,12 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
-        QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+        # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
+        # This is because the kernel dequantizes weight to fp8, which is in range 448.
+        return weight_quantizer._amax.float() / 448.0
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
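
For reference, a minimal sketch (not part of the commit) of why dividing amax by 448 keeps the per-block weight scale within 448/6. With NVFP4's two-level scaling, each block's scale is roughly block_amax / 6 / wsf2. The default wsf2 = amax / (448 * 6) bounds the block scale by 448 (the fp8 e4m3 maximum), while the W4A8 variant's wsf2 = amax / 448 bounds it by 448/6, so a dequantized fp4 value (magnitude at most 6) times its block scale never exceeds 448 and still fits in fp8. The helper per_block_scales and the block_size argument below are illustrative assumptions, not the NVFP4QTensor API.

import torch

FP8_E4M3_MAX = 448.0  # largest magnitude representable in float8_e4m3
FP4_E2M1_MAX = 6.0    # largest magnitude representable in NVFP4 (e2m1)

def per_block_scales(weight: torch.Tensor, wsf2: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Simplified stand-in for the per-block scale: block_amax / 6 / wsf2.
    block_amax = weight.abs().reshape(-1, block_size).amax(dim=-1)
    return block_amax / FP4_E2M1_MAX / wsf2

weight = torch.randn(128, 64)
amax = weight.abs().amax()

# Default NVFP4: wsf2 = amax / (448 * 6) keeps block scales within the fp8 range.
print(per_block_scales(weight, amax / (FP8_E4M3_MAX * FP4_E2M1_MAX)).max())  # <= ~448

# W4A8 NVFP4-FP8 (this commit): wsf2 = amax / 448 keeps block scales within 448/6,
# so fp4_value (<= 6) * block_scale (<= 448/6) <= 448, i.e. dequantized weights fit in fp8.
print(per_block_scales(weight, amax / FP8_E4M3_MAX).max())  # <= ~448 / 6 ≈ 74.67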
