17 changes: 0 additions & 17 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -2039,23 +2039,6 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
         return super().load_expert_w2_weight_scale_nvfp4(
             module, w2_weight_scale, dst_w2_weight_scale, 32)
 
-    def load_all_fp4_weight_scales_and_alphas(
-            self, module: torch.nn.Module, weights: Dict,
-            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
-            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
-            dst_fc2_alpha: torch.Tensor):
-        super().load_all_fp4_weight_scales_and_alphas(
-            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
-            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
-        # The kernel we use will convert nvfp4 to e4m3 before matmul,
-        # so the range of the scale factor can only be [0,448/6].
-        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
-                                      6.0).to(torch.float8_e4m3fn))
-        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
-                                   6.0).to(torch.float8_e4m3fn))
-        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
-        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
-
 
 
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
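For reference, the rescaling deleted above was a compensated pair: each per-block e4m3 weight scale was divided by 6 while the matching global alpha was multiplied by 6, so the effective dequantization factor (block scale times alpha) was unchanged; the division only forced the stored scales into the [0,448/6] range the kernel expected. A minimal standalone sketch of that invariant (illustrative PyTorch only, not the module code; the tensor values are made up):

    import torch

    # Per-block scales and a global alpha, kept in fp32 for illustration.
    block_scale = torch.tensor([1.5, 3.0, 36.0])
    alpha = torch.tensor(0.25)

    # The effective dequantization factor is the product, so dividing one
    # side by 6 and multiplying the other by 6 cancels exactly in fp32.
    before = block_scale * alpha
    after = (block_scale / 6.0) * (alpha * 6.0)
    assert torch.allclose(before, after)

    # The removed code also re-stored the divided scale as float8_e4m3fn,
    # which rounds; the pairing keeps values within e4m3's max of 448.
    stored = (block_scale / 6.0).to(torch.float8_e4m3fn)
    print(stored.to(torch.float32) * (alpha * 6.0))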
7 changes: 2 additions & 5 deletions tensorrt_llm/_torch/modules/linear.py
@@ -1007,15 +1007,12 @@ def load_weight_scales(
                 tp_mode,
                 device=device).contiguous()
             assert ws.dtype == torch.float8_e4m3fn
-            # The kernel we use will convert nvfp4 to e4m3 before matmul,
-            # so the range of the scale factor can only be [0,448/6].
-            ws = (ws.to(torch.float32) / 6.0).to(torch.float8_e4m3fn)
             weight_scale.append(ws.view(dtype=fp4_utils.float4_sf_dtype))
             if "weight_scale_2" in w:
                 if weight_scale_2 is None:
-                    weight_scale_2 = w["weight_scale_2"][...] * 6.0
+                    weight_scale_2 = w["weight_scale_2"][...]
                 else:
-                    assert weight_scale_2 == w["weight_scale_2"][...] * 6.0, (
+                    assert weight_scale_2 == w["weight_scale_2"][...], (
                         f"The weight_scale_2 should be the same for all the weights: {weight_scale_2} vs. {w['weight_scale_2']}"
                     )
 
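The [0,448/6] bound cited in the deleted comments follows directly from the format maxima: an nvfp4 (E2M1) value has magnitude at most 6, and if the kernel multiplies it by the per-block scale before storing the product as e4m3 (max magnitude 448), the scale must satisfy scale * 6 <= 448. A back-of-envelope check (plain Python, constants only):

    # E2M1 (nvfp4) maximum magnitude and E4M3 maximum magnitude.
    E2M1_MAX = 6.0
    E4M3_MAX = 448.0

    # For scale * value to stay representable in e4m3 for every nvfp4
    # value, the per-block scale can be at most E4M3_MAX / E2M1_MAX.
    max_scale = E4M3_MAX / E2M1_MAX
    print(f"max per-block scale: {max_scale:.2f}")  # ~74.67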
1 change: 0 additions & 1 deletion tests/unittest/_torch/thop/parallel/test_moe.py
@@ -991,7 +991,6 @@ class TestMoeFp4:
     the default tactic selection works. This reduces unnecessary test runs for CI
     """
 
-    @pytest.mark.skip(reason="https://nvbugs/5550249")
     @pytest.mark.parametrize("num_tokens", [1, 1024])
     @pytest.mark.parametrize("hidden_size", [1024])
     @pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 192])
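With the skip marker removed, the FP4 MoE parametrizations run again in CI. A hypothetical way to exercise just this class locally, via pytest's Python API (the -k filter is an assumption about how you want to select tests):

    # Run only the TestMoeFp4 parametrizations from this file.
    import pytest

    pytest.main([
        "tests/unittest/_torch/thop/parallel/test_moe.py",
        "-k", "TestMoeFp4",
        "-v",
    ])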