Commit 5d3b2e8

[OMNIML-2336] make w4a8_nvfp4_fp8's scale factor in range of 448/6
Signed-off-by: Shiyang Chen <[email protected]>
1 parent 1537885 commit 5d3b2e8

1 file changed: +2 -1 lines changed

modelopt/torch/export/quant_utils.py

Lines changed: 2 additions & 1 deletion
@@ -295,9 +295,10 @@ def get_weight_scaling_factor_2(module: nn.Module, weight_name: str = "weight")
     if get_quantization_format(module) in [
         QUANTIZATION_NVFP4,
         QUANTIZATION_NVFP4_AWQ,
-        QUANTIZATION_W4A8_NVFP4_FP8,
     ]:
         return NVFP4QTensor.get_weights_scaling_factor_2_from_quantizer(weight_quantizer)
+    elif get_quantization_format(module) == QUANTIZATION_W4A8_NVFP4_FP8:
+        return weight_quantizer._amax.float() / 448.0
 
     # SequentialQuantizer is required
     if not isinstance(weight_quantizer, SequentialQuantizer) or not weight_quantizer[-1].is_enabled:
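
For context on the new branch: in a two-level NVFP4 scheme, each block's scale (block_amax / 6, since 6 is the FP4 e2m1 maximum) is divided by a global scale before being stored in FP8 e4m3. With the global scale set to amax / 448 as this commit does for W4A8_NVFP4_FP8, every encoded per-block scale is bounded by 448/6, the range named in the commit title, because block_amax <= amax. A minimal sketch of that bound, assuming 16-element blocks; the helper name w4a8_nvfp4_fp8_scales and the standalone scale math are illustrative, not modelopt's API:

import torch

FP8_E4M3_MAX = 448.0  # max representable magnitude in float8 e4m3
FP4_E2M1_MAX = 6.0    # max representable magnitude in NVFP4 (e2m1)

def w4a8_nvfp4_fp8_scales(weight: torch.Tensor, block_size: int = 16):
    """Hypothetical sketch of two-level NVFP4 block scaling for the w4a8 path."""
    amax = weight.abs().amax()
    # Global (second-level) scale, matching what this commit returns.
    scale_2 = amax.float() / FP8_E4M3_MAX

    blocks = weight.reshape(-1, block_size)
    block_amax = blocks.abs().amax(dim=-1).float()
    # Per-block scale maps each block's amax onto the FP4 maximum of 6.
    block_scale = block_amax / FP4_E2M1_MAX
    # FP8-stored per-block scale: since block_amax <= amax, this is
    # bounded by 448 / 6, well inside the e4m3 range.
    encoded_block_scale = block_scale / scale_2
    assert encoded_block_scale.max() <= FP8_E4M3_MAX / FP4_E2M1_MAX + 1e-3
    return scale_2, encoded_block_scale

For comparison, the pure NVFP4 path presumably keeps amax / (448 * 6) as the global scale, which lets the encoded block scales span the full e4m3 range up to 448.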
