Commit d4acc9d

[OMNIML-2336] make w4a8_nvfp4_fp8's scale factor in range of 448/6
Signed-off-by: Shiyang Chen <[email protected]>
1 parent: 1537885

1 file changed: +2 -0 lines changed

modelopt/torch/export/quant_utils.py

Lines changed: 2 additions & 0 deletions
@@ -298,6 +298,8 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
         QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+        return weight_quantizer._amax.float() / 448.0
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
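
Why this keeps the per-block scale factors within 448/6: in NVFP4's two-level scheme, each 16-element block stores an FP8 (E4M3) scale on top of a per-tensor scale, where 448 is the largest finite E4M3 value and 6 the largest FP4 (E2M1) magnitude. Below is a minimal sketch, assuming the per-block scale is computed as block_amax / 6 / per_tensor_scale; the helper name and block layout are illustrative, not part of the commit or the modelopt API.

import torch

def per_block_scales(weight: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Hypothetical helper (not modelopt code): with the per-tensor scale set
    # to amax / 448, as this commit returns, every per-block FP8 scale lands
    # in [0, 448/6].
    global_amax = weight.abs().amax().float()
    scale_2 = global_amax / 448.0  # mirrors weight_quantizer._amax.float() / 448.0
    # Assumes weight.numel() is divisible by block_size (16 for NVFP4).
    block_amax = weight.abs().reshape(-1, block_size).amax(dim=-1).float()
    block_scales = block_amax / 6.0 / scale_2
    # Since block_amax <= global_amax, block_scales <= 448 / 6 (about 74.67),
    # well inside what FP8 E4M3 can represent without clipping.
    return block_scales

For comparison, a per-tensor scale of amax / (448 * 6), which is likely what the NVFP4 branch above obtains via NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer, would let the same block scales range up to 448, the very edge of E4M3; dividing by 448 alone caps them at 448/6 instead, matching the commit title.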
