
Commit 6a6124d

Authored by sychen52 (Shiyang Chen)
[OMNIML-2336][feat] w4a8 nvfp4 fp8 exports scale factor properly (#8180)
Signed-off-by: Shiyang Chen <[email protected]>
Co-authored-by: Shiyang Chen <[email protected]>
1 parent f4e7738 commit 6a6124d

File tree

3 files changed: +2 -23 lines changed


tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 0 additions & 17 deletions
@@ -2039,23 +2039,6 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
         return super().load_expert_w2_weight_scale_nvfp4(
             module, w2_weight_scale, dst_w2_weight_scale, 32)
 
-    def load_all_fp4_weight_scales_and_alphas(
-            self, module: torch.nn.Module, weights: Dict,
-            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
-            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
-            dst_fc2_alpha: torch.Tensor):
-        super().load_all_fp4_weight_scales_and_alphas(
-            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
-            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
-        # The kernel we use will convert nvfp4 to e4m3 before matmul,
-        # so the range of the scale factor can only be [0,448/6].
-        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
-                                      6.0).to(torch.float8_e4m3fn))
-        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
-                                   6.0).to(torch.float8_e4m3fn))
-        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
-        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
-
 
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
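
The factor-of-6 juggling deleted above existed because the w4a8 kernel converts nvfp4 weights to e4m3 before the matmul: fp4 magnitudes reach 6 and e4m3 saturates at 448, so a block scale can be at most 448/6 if the scaled values are to stay representable. The override therefore divided every block scale by 6 at load time and folded the 6 back into the global alpha, leaving the overall dequantized product unchanged. With checkpoints now exporting scale factors in the proper range (the point of this commit), the load-time rebalancing is redundant. A minimal sketch of the invariant it relied on, with illustrative values only, not code from this commit:

import torch

# Hypothetical per-block scales and global alpha; values chosen so the
# divided scales (32.0, 64.0) are exactly representable in e4m3.
block_scale = torch.tensor([192.0, 384.0])
alpha = torch.tensor(0.5)

before = block_scale * alpha  # effective dequantization factor

# The rebalancing the deleted override applied: /6 on the block scale
# (re-stored as e4m3), *6 on the global alpha.
rebalanced_scale = (block_scale / 6.0).to(torch.float8_e4m3fn)
rebalanced_alpha = alpha * 6.0
after = rebalanced_scale.to(torch.float32) * rebalanced_alpha

assert torch.equal(before, after)  # the product is preserved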

tensorrt_llm/_torch/modules/linear.py

Lines changed: 2 additions & 5 deletions
@@ -1027,15 +1027,12 @@ def load_weight_scales(
                 tp_mode,
                 device=device).contiguous()
             assert ws.dtype == torch.float8_e4m3fn
-            # The kernel we use will convert nvfp4 to e4m3 before matmul,
-            # so the range of the scale factor can only be [0,448/6].
-            ws = (ws.to(torch.float32) / 6.0).to(torch.float8_e4m3fn)
             weight_scale.append(ws.view(dtype=fp4_utils.float4_sf_dtype))
         if "weight_scale_2" in w:
             if weight_scale_2 is None:
-                weight_scale_2 = w["weight_scale_2"][...] * 6.0
+                weight_scale_2 = w["weight_scale_2"][...]
             else:
-                assert weight_scale_2 == w["weight_scale_2"][...] * 6.0, (
+                assert weight_scale_2 == w["weight_scale_2"][...], (
                     f"The weight_scale_2 should be same for all the weights: {weight_scale_2} vs. {w['weight_scale_2']}*6"
                 )
 

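On the linear path, load_weight_scales now consumes both the per-block scales and weight_scale_2 exactly as exported, with no factor-of-6 adjustment. For orientation, a minimal sketch of the two-level NVFP4 dequantization these scales feed, assuming the usual 16-element blocks with e4m3 block scales; the function and argument names are hypothetical, not this module's API:

import torch

def dequantize_nvfp4(codes: torch.Tensor, block_scale: torch.Tensor,
                     weight_scale_2: torch.Tensor) -> torch.Tensor:
    # codes: fp4 values already decoded to float, each in [-6, 6]
    # block_scale: e4m3 block scales, assumed already expanded to codes' shape
    # weight_scale_2: single global scale, used as stored in the checkpoint
    return codes * block_scale.to(torch.float32) * weight_scale_2
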
tests/unittest/_torch/thop/parallel/test_moe.py

Lines changed: 0 additions & 1 deletion
@@ -991,7 +991,6 @@ class TestMoeFp4:
     the default tactic selection works. This reduces unnecessary test runs for CI
     """
 
-    @pytest.mark.skip(reason="https://nvbugs/5550249")
    @pytest.mark.parametrize("num_tokens", [1, 1024])
     @pytest.mark.parametrize("hidden_size", [1024])
     @pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 192])
