@@ -271,15 +271,17 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
271271 QUANTIZATION_W4A8_NVFP4_FP8 ,
272272 ]:
273273 if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8 :
274- # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
274+ # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the weight scaling factor is in range 448/6.
275275 # This is because the kernel dequantizes weight to fp8, which is in range 448.
276- wsf2 = weight_quantizer ._amax .float () / 448.0
276+ weight_scaling_factor_2 = weight_quantizer ._amax .float () / 448.0
277277 else :
278- wsf2 = NVFP4QTensor .get_weights_scaling_factor_2_from_quantizer (weight_quantizer )
278+ weight_scaling_factor_2 = NVFP4QTensor .get_weights_scaling_factor_2_from_quantizer (
279+ weight_quantizer
280+ )
279281 return NVFP4QTensor .get_weights_scaling_factor (
280282 weight ,
281283 weight_quantizer .block_sizes [- 1 ],
282- wsf2 .to (weight .device ),
284+ weight_scaling_factor_2 .to (weight .device ),
283285 )[0 ]
284286
285287 if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8 , QUANTIZATION_MXFP4 ]:
@@ -302,7 +304,7 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
302304 ]:
303305 return NVFP4QTensor .get_weights_scaling_factor_2_from_quantizer (weight_quantizer )
304306 elif get_quantization_format (module ) == QUANTIZATION_W4A8_NVFP4_FP8 :
305- # wsf2 for w4a8 needs to be amax/448, so that the wsf is in range 448/6.
307+ # weight_scaling_factor_2 for w4a8 needs to be amax/448, so that the weight scaling factor is in range 448/6.
306308 # This is because the kernel dequantizes weight to fp8, which is in range 448.
307309 return weight_quantizer ._amax .float () / 448.0
308310
0 commit comments