@@ -270,12 +270,16 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
270270         QUANTIZATION_NVFP4_AWQ,
271271         QUANTIZATION_W4A8_NVFP4_FP8,
272272     ]:
273+         if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
274+             # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
275+             # This is because the kernel dequantizes weight to fp8, which is in range 448.
276+             wsf2 = weight_quantizer._amax.float() / 448.0
277+         else:
278+             wsf2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
273279         return NVFP4QTensor.get_weights_scaling_factor(
274280             weight,
275281             weight_quantizer.block_sizes[-1],
276-             NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
277-                 weight.device
278-             ),
282+             wsf2.to(weight.device),
279283         )[0]
280284
281285     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
280284
281285 if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8 , QUANTIZATION_MXFP4 ]:
@@ -295,9 +299,12 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
295299     if get_quantization_format(module) in [
296300         QUANTIZATION_NVFP4,
297301         QUANTIZATION_NVFP4_AWQ,
298-         QUANTIZATION_W4A8_NVFP4_FP8,
299302     ]:
300303         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
304+     elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
305+         # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
306+         # This is because the kernel dequantizes weight to fp8, which is in range 448.
307+         return weight_quantizer._amax.float() / 448.0
301308
302309     # SequentialQuantizer is required
303310     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
0 commit comments