diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 885a12582..73d3c44e6 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -270,12 +270,18 @@ def get_weight_scaling_factor(module: nn.Module, weight_name: str = "weight") ->
         QUANTIZATION_NVFP4_AWQ,
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
+        if quantization_format == QUANTIZATION_W4A8_NVFP4_FP8:
+            # weight_scaling_factor_2 for w4a8 needs to be amax/448 so that the wsf stays within 448/6.
+            # This is because the kernel dequantizes the weight to fp8, whose representable range is 448.
+            weight_scaling_factor_2 = weight_quantizer._amax.float() / 448.0
+        else:
+            weight_scaling_factor_2 = NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(
+                weight_quantizer
+            )
         return NVFP4QTensor.get_weights_scaling_factor(
             weight,
             weight_quantizer.block_sizes[-1],
-            NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer).to(
-                weight.device
-            ),
+            weight_scaling_factor_2.to(weight.device),
         )[0]
 
     if quantization_format in [QUANTIZATION_W4A8_MXFP4_FP8, QUANTIZATION_MXFP4]:
@@ -295,9 +301,12 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
-        QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+        # weight_scaling_factor_2 for w4a8 needs to be amax/448 so that the wsf stays within 448/6.
+        # This is because the kernel dequantizes the weight to fp8, whose representable range is 448.
+        return weight_quantizer._amax.float() / 448.0
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
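
For reviewers, a minimal standalone sketch (not part of the patch, and not the modelopt API) of the scaling arithmetic the added comments describe, assuming the usual NVFP4 two-level scheme where each per-block scaling factor is block_amax / (6 * weight_scaling_factor_2): choosing weight_scaling_factor_2 = amax / 448 bounds the per-block scale by 448 / 6, so a weight dequantized as scale * fp4_value with |fp4_value| <= 6 stays within the FP8 E4M3 limit of 448. All names below are illustrative.

    import torch

    # Illustrative-only check of the two-level scaling bound; the helper name and
    # shapes are hypothetical and do not mirror the modelopt implementation.
    def per_block_scales(blocks: torch.Tensor, sf2: torch.Tensor) -> torch.Tensor:
        # NVFP4 values span [-6, 6], so a block's raw scale is block_amax / 6,
        # then normalized by the global second-level scale sf2.
        block_amax = blocks.abs().amax(dim=-1)
        return block_amax / (6.0 * sf2)

    weight = torch.randn(8, 64)
    blocks = weight.reshape(-1, 16)      # block size 16, as in NVFP4
    amax = weight.abs().amax()           # global per-tensor amax

    sf2_nvfp4 = amax / (448.0 * 6.0)     # typical NVFP4 choice: per-block scales <= 448
    sf2_w4a8 = amax / 448.0              # this patch's w4a8 choice: per-block scales <= 448 / 6

    # Dequantizing a weight element is roughly scale * fp4_value with |fp4_value| <= 6,
    # so bounding the w4a8 scales by 448 / 6 keeps the fp8-dequantized weight within +-448.
    print(per_block_scales(blocks, sf2_nvfp4).max())   # <= 448
    print(per_block_scales(blocks, sf2_w4a8).max())    # <= 448 / 6 ~= 74.7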