
Commit 23421f7

[OMNIML-2336][feat] w4a8 nvfp4 fp8 exports scale factor properly
Now, ModelOpt will export its scale factor in the [0, 448/6] range.

Signed-off-by: Shiyang Chen <[email protected]>
1 parent 7facac0 commit 23421f7
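
For reference, the 448/6 bound comes from the two number formats involved: FP4 (E2M1) magnitudes go up to 6 and FP8 (E4M3) magnitudes go up to 448. The W4A8 kernel converts each NVFP4 weight to E4M3 by multiplying it with its per-block scale before the matmul, so the scaled value must stay inside E4M3's range, which caps the block scale at 448/6. A minimal sketch of that arithmetic (the constant names are illustrative, not taken from the codebase):

import math

# Illustrative constants; names are not from TensorRT-LLM or ModelOpt.
FP8_E4M3_MAX = 448.0  # largest finite E4M3 magnitude
FP4_E2M1_MAX = 6.0    # largest E2M1 magnitude

# The kernel dequantizes w_e4m3 = w_fp4 * block_scale. With |w_fp4| <= 6,
# the product only stays representable in E4M3 when block_scale <= 448 / 6.
MAX_BLOCK_SCALE = FP8_E4M3_MAX / FP4_E2M1_MAX  # ~74.67

assert math.isclose(FP4_E2M1_MAX * MAX_BLOCK_SCALE, FP8_E4M3_MAX)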

File tree

  tensorrt_llm/_torch/modules/fused_moe/quantization.py
  tensorrt_llm/_torch/modules/linear.py

2 files changed: +3 -13 lines

tensorrt_llm/_torch/modules/fused_moe/quantization.py (0 additions, 8 deletions)

@@ -2047,14 +2047,6 @@ def load_all_fp4_weight_scales_and_alphas(
         super().load_all_fp4_weight_scales_and_alphas(
             module, weights, load_expert_ids, dst_w3_w1_weight_scale,
             dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
-        # The kernel we use will convert nvfp4 to e4m3 before matmul,
-        # so the range of the scale factor can only be [0,448/6].
-        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
-                                      6.0).to(torch.float8_e4m3fn))
-        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
-                                   6.0).to(torch.float8_e4m3fn))
-        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
-        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)


 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
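
The block deleted above was a load-time workaround for checkpoints whose E4M3 block scales used the full [0, 448] range: it divided the block scales by 6 to bring them into [0, 448/6] and multiplied the corresponding alphas by 6 so the effective dequantization factor (block scale * alpha) stayed the same. A minimal sketch of that compensation, using a hypothetical helper name:

import torch

def legacy_rescale(weight_scale: torch.Tensor, alpha: torch.Tensor):
    # Hypothetical stand-in for the deleted post-processing above.
    # Squeeze the E4M3 block scales into [0, 448/6] so the kernel's
    # FP4 -> E4M3 conversion cannot overflow ...
    rescaled = (weight_scale.to(torch.float32) / 6.0).to(torch.float8_e4m3fn)
    # ... and fold the factor of 6 into the global alpha so the effective
    # factor (block scale * alpha) applied to each weight is unchanged.
    return rescaled, alpha * 6.0

With ModelOpt now exporting scale factors already in the [0, 448/6] range, this compensation is no longer needed at load time and would apply the factor of 6 twice, which is why the commit removes it.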

tensorrt_llm/_torch/modules/linear.py (3 additions, 5 deletions)

@@ -1007,15 +1007,13 @@ def load_weight_scales(
                     tp_mode,
                     device=device).contiguous()
                 assert ws.dtype == torch.float8_e4m3fn
-                # The kernel we use will convert nvfp4 to e4m3 before matmul,
-                # so the range of the scale factor can only be [0,448/6].
-                ws = (ws.to(torch.float32) / 6.0).to(torch.float8_e4m3fn)
+                ws = ws.to(torch.float8_e4m3fn)
                 weight_scale.append(ws.view(dtype=fp4_utils.float4_sf_dtype))
             if "weight_scale_2" in w:
                 if weight_scale_2 is None:
-                    weight_scale_2 = w["weight_scale_2"][...] * 6.0
+                    weight_scale_2 = w["weight_scale_2"][...]
                 else:
-                    assert weight_scale_2 == w["weight_scale_2"][...] * 6.0, (
+                    assert weight_scale_2 == w["weight_scale_2"][...], (
                         f"The weight_scale_2 should be same for all the weights: {weight_scale_2} vs. {w['weight_scale_2']}*6"
                     )

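Similarly, load_weight_scales in linear.py now casts the per-block scales to E4M3 as-is and takes weight_scale_2 directly from the checkpoint, without the former multiply-by-6 / divide-by-6 compensation. A hypothetical sanity check (not part of the commit) that a checkpoint follows the new export convention could look like this:

import torch

# Upper bound for NVFP4 block scales under the new ModelOpt convention.
NVFP4_SCALE_LIMIT = 448.0 / 6.0

def check_nvfp4_block_scales(weight_scale: torch.Tensor) -> None:
    # Hypothetical helper: verify E4M3 block scales lie in [0, 448/6].
    ws = weight_scale.to(torch.float32)
    if ws.min().item() < 0 or ws.max().item() > NVFP4_SCALE_LIMIT:
        raise ValueError(
            f"Block scales fall outside [0, {NVFP4_SCALE_LIMIT:.2f}]; the "
            "checkpoint likely uses the old convention and would need the "
            "removed divide-by-6 rescale.")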