@@ -270,12 +270,18 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
270270 QUANTIZATION_NVFP4_AWQ ,
271271 QUANTIZATION_W4A8_NVFP4_FP8 ,
272272 ]:
273+ if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8 :
274+ # For w4a8, weight_scaling_factor_2 must be amax / 448 so that the weight scaling factor (wsf) stays within 448 / 6.
275+ # This is required because the kernel dequantizes the weight to fp8, whose maximum representable magnitude is 448.
276+ weight_scaling_factor_2 = weight_quantizer ._amax .float () / 448.0
277+ else :
278+ weight_scaling_factor_2 = NVFP4QTensor .get_weights_scaling_factor_2_from_quantizer (
279+ weight_quantizer
280+ )
273281 return NVFP4QTensor .get_weights_scaling_factor (
274282 weight ,
275283 weight_quantizer .block_sizes [- 1 ],
276- NVFP4QTensor .get_weights_scaling_factor_2_from_quantizer (weight_quantizer ).to (
277- weight .device
278- ),
284+ weight_scaling_factor_2 .to (weight .device ),
279285 )[0 ]
280286
281287 if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8 , QUANTIZATION_MXFP4 ]:
@@ -295,9 +301,12 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
295301 if get_quantization_format (module ) in [
296302 QUANTIZATION_NVFP4 ,
297303 QUANTIZATION_NVFP4_AWQ ,
298- QUANTIZATION_W4A8_NVFP4_FP8 ,
299304 ]:
300305 return NVFP4QTensor .get_weights_scaling_factor_2_from_quantizer (weight_quantizer )
306+ elif get_quantization_format (module ) == QUANTIZATION_W4A8_NVFP4_FP8 :
307+ # For w4a8, weight_scaling_factor_2 must be amax / 448 so that the weight scaling factor (wsf) stays within 448 / 6.
308+ # This is required because the kernel dequantizes the weight to fp8, whose maximum representable magnitude is 448.
309+ return weight_quantizer ._amax .float () / 448.0
301310
302311 # SequentialQuantizer is required
303312 if not isinstance (weight_quantizer , SequentialQuantizer ) or not weight_quantizer [- 1 ].is_enabled :
0 commit comments