
Commit 966e785 (parent: 7facac0)

[OMNIML-2336][feat] w4a8 nvfp4 fp8 exports scale factor properly

ModelOpt now exports its NVFP4 weight scale factors already rescaled into the [0, 448/6] range, so TensorRT-LLM no longer rescales them at load time.

Signed-off-by: Shiyang Chen <[email protected]>
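Why [0, 448/6]? NVFP4 (e2m1) values have magnitude at most 6, and the W4A8 kernel up-converts nvfp4 weights to e4m3 (max finite value 448) before the matmul, so a per-block scale times the largest nvfp4 value must stay within e4m3 range: 6 × (448/6) = 448. Below is a minimal sketch of the rescaling this commit moves from TensorRT-LLM's load path into the ModelOpt checkpoint export; the function names are illustrative, not ModelOpt's real export API.

```python
import torch

# Illustrative sketch of the export-side rescaling (hypothetical names,
# not the real ModelOpt API).

FP4_MAX = 6.0     # largest NVFP4 (e2m1) magnitude
E4M3_MAX = 448.0  # largest finite torch.float8_e4m3fn value

def export_block_scale(raw_scale: torch.Tensor) -> torch.Tensor:
    # Divide by 6 so every per-block scale lands in [0, 448/6]; the kernel
    # up-converts nvfp4 weights to e4m3 before the matmul, and the largest
    # nvfp4 value (6) times the scale must not exceed the e4m3 max (448).
    return (raw_scale.to(torch.float32) / FP4_MAX).to(torch.float8_e4m3fn)

def export_global_scale(raw_alpha: torch.Tensor) -> torch.Tensor:
    # Fold the compensating factor of 6 into the global scale (alpha /
    # weight_scale_2) so the overall dequantization is unchanged:
    # (scale / 6) * (alpha * 6) == scale * alpha.
    return raw_alpha * FP4_MAX
```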

File tree: 2 files changed (+3 -22 lines)

tensorrt_llm/_torch/modules/fused_moe/quantization.py
tensorrt_llm/_torch/modules/linear.py

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 0 additions & 17 deletions
@@ -2039,23 +2039,6 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
         return super().load_expert_w2_weight_scale_nvfp4(
             module, w2_weight_scale, dst_w2_weight_scale, 32)
 
-    def load_all_fp4_weight_scales_and_alphas(
-            self, module: torch.nn.Module, weights: Dict,
-            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
-            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
-            dst_fc2_alpha: torch.Tensor):
-        super().load_all_fp4_weight_scales_and_alphas(
-            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
-            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
-        # The kernel we use will convert nvfp4 to e4m3 before matmul,
-        # so the range of the scale factor can only be [0,448/6].
-        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
-                                      6.0).to(torch.float8_e4m3fn))
-        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
-                                   6.0).to(torch.float8_e4m3fn))
-        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
-        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
-
 
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
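The deleted override existed only to apply the rescaling per expert at load time: divide the w3/w1 and w2 block scales by 6 and multiply the fc31/fc2 alphas by 6. With ModelOpt exporting pre-rescaled scales, the base-class loader is already correct. The removed rescaling was value-preserving, which a quick standalone check (illustrative values, not repo code) makes concrete:

```python
import torch

q = torch.tensor([6.0])        # largest nvfp4 magnitude
scale = torch.tensor([300.0])  # a raw per-block scale in the old [0, 448] range
alpha = torch.tensor([0.01])   # global scale (fc31/fc2 alpha)

# Dequantization before the commit vs. after: the /6 on the block scale and
# the *6 on alpha cancel, so the reconstructed weight is identical.
old = q * scale * alpha
new = q * (scale / 6.0) * (alpha * 6.0)
assert torch.allclose(old, new)
```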

tensorrt_llm/_torch/modules/linear.py

Lines changed: 3 additions & 5 deletions
@@ -1007,15 +1007,13 @@ def load_weight_scales(
                 tp_mode,
                 device=device).contiguous()
             assert ws.dtype == torch.float8_e4m3fn
-            # The kernel we use will convert nvfp4 to e4m3 before matmul,
-            # so the range of the scale factor can only be [0,448/6].
-            ws = (ws.to(torch.float32) / 6.0).to(torch.float8_e4m3fn)
+            ws = ws.to(torch.float8_e4m3fn)
             weight_scale.append(ws.view(dtype=fp4_utils.float4_sf_dtype))
         if "weight_scale_2" in w:
             if weight_scale_2 is None:
-                weight_scale_2 = w["weight_scale_2"][...] * 6.0
+                weight_scale_2 = w["weight_scale_2"][...]
             else:
-                assert weight_scale_2 == w["weight_scale_2"][...] * 6.0, (
+                assert weight_scale_2 == w["weight_scale_2"][...], (
                     f"The weight_scale_2 should be same for all the weights: {weight_scale_2} vs. {w['weight_scale_2']}*6"
                 )
 
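On the dense-linear side, load_weight_scales now takes the checkpoint's e4m3 block scales as-is and leaves weight_scale_2 untouched, which silently assumes the new export convention. A sketch of a guard one could run against a checkpoint (hypothetical helper, not part of the repo):

```python
import torch

def check_new_scale_convention(ws: torch.Tensor) -> None:
    # Block scales exported by the updated ModelOpt must already sit in
    # [0, 448/6]; anything larger suggests an old-style [0, 448] checkpoint
    # that would still need the removed /6 rescaling at load time.
    assert ws.dtype == torch.float8_e4m3fn
    limit = 448.0 / 6.0
    assert float(ws.to(torch.float32).abs().max()) <= limit
```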
