
Commit f040ef9

[https://nvbugs/5467531][fix] Fix moe test and wide ep fake impl (#8883)
Signed-off-by: Jin Li <[email protected]>
1 parent: c2fe686

File tree: 2 files changed (+61, -35 lines)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py (27 additions, 0 deletions)
@@ -1041,3 +1041,30 @@ def load_weights(self, weights: List[Dict]):
 
     def post_load_weights(self):
         self.quant_method.post_load_weights(self)
+
+    def forward_fake(
+            self,
+            x: Union[torch.Tensor, Fp4QuantizedTensor],
+            router_logits: torch.Tensor,
+            *,
+            do_finalize: bool = True,
+            output_dtype: Optional[torch.dtype] = None,
+            all_rank_num_tokens: Optional[List[int]] = None,
+            use_dp_padding: Optional[bool] = None,
+            **kwargs,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        moe_output = super().forward_fake(
+            x,
+            router_logits,
+            do_finalize=do_finalize,
+            output_dtype=torch.bfloat16,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=use_dp_padding,
+            **kwargs)
+        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+            shape = moe_output.shape
+            top_k = self.routing_method.experts_per_token
+            new_shape = [shape[0], top_k, shape[1]]
+            return moe_output.new_empty(new_shape)
+        else:
+            return moe_output
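The override exists because, on the MNNVL alltoall path, the real forward returns an unfinalized MoE output with one row per (token, selected expert), so the fake implementation has to report a [num_tokens, top_k, hidden] shape for shape inference to match. Below is a minimal, self-contained sketch of that fake-tensor pattern; the function names and sizes are illustrative stand-ins, not TensorRT-LLM APIs:

import torch

def forward_real(x: torch.Tensor, top_k: int) -> torch.Tensor:
    # Stand-in for an unfinalized MoE output: one row per
    # (token, selected expert) pair -> [num_tokens, top_k, hidden].
    return x.unsqueeze(1).expand(-1, top_k, -1).contiguous()

def forward_fake(x: torch.Tensor, top_k: int) -> torch.Tensor:
    # Mirror only the metadata: new_empty inherits dtype and device
    # from x; no kernel runs and no values are defined.
    num_tokens, hidden = x.shape
    return x.new_empty((num_tokens, top_k, hidden))

x = torch.randn(8, 16, dtype=torch.bfloat16)
real = forward_real(x, top_k=6)
fake = forward_fake(x, top_k=6)
assert real.shape == fake.shape and real.dtype == fake.dtype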

tests/unittest/_torch/modules/test_fused_moe.py (34 additions, 35 deletions)
@@ -292,7 +292,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
         assert r is None
 
 
-@pytest.mark.skip(reason="https://nvbugs/5467531")
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
@@ -304,7 +303,7 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
 
     world_size = 4
     dtype = torch.bfloat16
-    HIDDEN_SIZE = 2560
+    HIDDEN_SIZE = 4096
     INTERMEDIATE_SIZE = 1536
     NUM_EXPERTS = 72
     TOP_K = 6
@@ -320,8 +319,8 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
         x_list = []
         m = MAX_NUM_TOKENS
         while m >= 1:
-            x = torch.randn((m, HIDDEN_SIZE), dtype=dtype, device="cuda")
-            x_list.append(x.cuda(i))
+            x = torch.randn((m, HIDDEN_SIZE), dtype=dtype)
+            x_list.append(x)
             m //= 2
 
         x_abs_max = torch.cat([x.flatten() for x in x_list]).abs().max().float()
@@ -366,49 +365,37 @@ def test_fused_moe_alltoall_fp4(alltoall_method_type):
             w3_sf_block_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
                 w3_sf_block.cpu().view(INTERMEDIATE_SIZE, -1))
 
-            w1_input_scale = x_sf_global.cuda(i)
-            w2_input_scale = x_sf_global.cuda(i)
-            w3_input_scale = x_sf_global.cuda(i)
+            weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4.cpu()
+            weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4.cpu()
+            weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4.cpu()
+            weights[f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled
+            weights[f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled
+            weights[f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled
 
-            weights[f"{expert_id}.w1.weight"] = w1_weight_nvfp4.cuda(i)
-            weights[f"{expert_id}.w2.weight"] = w2_weight_nvfp4.cuda(i)
-            weights[f"{expert_id}.w3.weight"] = w3_weight_nvfp4.cuda(i)
-            weights[
-                f"{expert_id}.w1.weight_scale"] = w1_sf_block_unswizzled.cuda(i)
-            weights[
-                f"{expert_id}.w2.weight_scale"] = w2_sf_block_unswizzled.cuda(i)
-            weights[
-                f"{expert_id}.w3.weight_scale"] = w3_sf_block_unswizzled.cuda(i)
-
-            weights[f"{expert_id}.w1.input_scale"] = 1.0 / w1_input_scale.cuda(
-                i)
-            weights[f"{expert_id}.w2.input_scale"] = 1.0 / w2_input_scale.cuda(
-                i)
-            weights[f"{expert_id}.w3.input_scale"] = 1.0 / w3_input_scale.cuda(
-                i)
-            weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global.cuda(
-                i)
-            weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global.cuda(
-                i)
-            weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global.cuda(
-                i)
+            weights[f"{expert_id}.w1.input_scale"] = 1.0 / x_sf_global
+            weights[f"{expert_id}.w2.input_scale"] = 1.0 / x_sf_global
+            weights[f"{expert_id}.w3.input_scale"] = 1.0 / x_sf_global
+            weights[f"{expert_id}.w1.weight_scale_2"] = 1.0 / w3_w1_global.cpu()
+            weights[f"{expert_id}.w2.weight_scale_2"] = 1.0 / w2_sf_global.cpu()
+            weights[f"{expert_id}.w3.weight_scale_2"] = 1.0 / w3_w1_global.cpu()
 
         x_list_world.append(x_list)
         weights_world.append(weights)
+    torch.cuda.synchronize()
 
-    def per_rank_test_fused_moe_alltoall(job_id):
+    def per_rank_test_fused_moe_alltoall(job_id, weights, x_list):
         routing_method = DefaultMoeRoutingMethod(top_k=TOP_K)
         mapping = Mapping(world_size=world_size,
-                          rank=mpi_rank(),
+                          rank=job_id,
                           tp_size=world_size,
                           moe_ep_size=world_size,
                           moe_tp_size=1,
                           enable_attention_dp=True)
         torch.cuda.set_device(mapping.rank)
         torch.manual_seed(mapping.rank)
 
-        x_list = x_list_world[mapping.rank]
-        weights = weights_world[mapping.rank]
+        weights = {k: v.cuda() for k, v in weights.items()}
+        x_list = [x.cuda() for x in x_list]
 
         quant_config = QuantConfig(quant_algo=QuantAlgo.NVFP4)
         with mock.patch.object(WideEPMoE,
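Note the pattern behind these changes: the driver now builds every tensor on the CPU, ships them to the pool workers as plain arguments, and each worker does its own host-to-device copy after binding to its GPU. A minimal sketch of that staging pattern follows, with made-up names and toy sizes:

import torch

def make_weights_cpu() -> dict:
    # Built once in the driver process; CPU tensors are safe to hand
    # to pool workers, unlike tensors already bound to a device.
    return {"w1.weight": torch.randn(32, 16), "w2.weight": torch.randn(16, 32)}

def per_rank_worker(rank: int, weights: dict) -> None:
    torch.cuda.set_device(rank)  # bind this worker to its own GPU
    # One host-to-device copy per rank, landing on the device set above.
    weights = {k: v.cuda() for k, v in weights.items()}
    assert all(v.device.index == rank for v in weights.values())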
@@ -459,6 +446,16 @@ def per_rank_test_fused_moe_alltoall(job_id):
                 router_logits,
                 all_rank_num_tokens=all_rank_num_tokens,
                 use_dp_padding=False)
+            # Verify the fake impl is correct.
+            output_fake = alltoall_model.forward_fake(
+                x,
+                router_logits,
+                all_rank_num_tokens=all_rank_num_tokens,
+                use_dp_padding=False)
+            assert output_fake.shape == output.shape
+            assert output_fake.dtype == output.dtype
+            if len(output.shape) == 3:
+                output = torch.sum(output, dim=1, keepdim=False)
             ref_output = ref_model.forward(
                 x,
                 router_logits,
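The new assertions check that forward_fake reproduces the real output's shape and dtype; the 3-D case covers the unfinalized MNNVL output, which is reduced over the per-expert dimension before the reference comparison. A toy illustration of that reduction (sizes made up):

import torch

num_tokens, top_k, hidden = 8, 6, 16
# Unfinalized output: one row per (token, selected expert).
unfinalized = torch.randn(num_tokens, top_k, hidden)
# The test combines the rows with a plain sum over dim=1.
finalized = torch.sum(unfinalized, dim=1, keepdim=False)
assert finalized.shape == (num_tokens, hidden)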
@@ -470,8 +467,10 @@ def per_rank_test_fused_moe_alltoall(job_id):
             m //= 2
 
     with MPIPoolExecutor(max_workers=world_size) as executor:
-        results = executor.map(per_rank_test_fused_moe_alltoall,
-                               range(world_size))
+        results = executor.map(
+            per_rank_test_fused_moe_alltoall,
+            *zip(*[(i, weights_world[i], x_list_world[i])
+                   for i in range(world_size)]))
         for r in results:
             assert r is None
 
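One idiom worth unpacking: executor.map expects one iterable per positional parameter of the function, so the per-rank (job_id, weights, x_list) tuples are transposed with *zip(*rows). A small stand-alone illustration, using ThreadPoolExecutor in place of MPIPoolExecutor:

from concurrent.futures import ThreadPoolExecutor

def work(rank, weights, xs):
    return rank

rows = [(i, f"weights_{i}", f"x_{i}") for i in range(4)]
# *zip(*rows) expands to (0, 1, 2, 3), ("weights_0", ...), ("x_0", ...):
# one iterable per parameter of work().
with ThreadPoolExecutor(max_workers=4) as executor:
    assert list(executor.map(work, *zip(*rows))) == [0, 1, 2, 3]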
