
Commit b0cb9ca

kaiyux and zongfeijing authored
[None] [test] Add MNNVL AlltoAll tests to pre-merge (#7466)
Signed-off-by: Kaiyu Xie <[email protected]>
Signed-off-by: Zongfei Jing <[email protected]>
Co-authored-by: Zongfei Jing <[email protected]>
1 parent dcfd3ef commit b0cb9ca

File tree

6 files changed: 41 additions, 20 deletions


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 4 additions & 0 deletions
@@ -59,6 +59,7 @@
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod,
                                  MoEWeightLoadingMode, create_moe)
+from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -849,6 +850,9 @@ def compute_routed_output(self, hidden_states, hidden_states_fp4,
             output_dtype=hidden_states.dtype,
             all_rank_num_tokens=all_rank_num_tokens,
             use_dp_padding=use_dp_padding,
+            **({
+                "alltoall_result_do_sum": False
+            } if isinstance(self.experts, WideEPMoE) else {}),
         )

         return routed_output
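The second hunk uses a conditional-kwargs idiom: the new keyword argument is injected only when the experts module is a WideEPMoE, so other MoE backends are never handed a kwarg their forward() may not expect. A minimal sketch of the idiom (the class bodies below are hypothetical stand-ins; only the call pattern mirrors the diff):

class WideEPMoE:
    # Hypothetical stand-in: only this backend understands the extra flag.
    def forward(self, x, alltoall_result_do_sum=True):
        return ("wide_ep", alltoall_result_do_sum, x)

class OtherMoE:
    # Hypothetical stand-in for a backend without the flag.
    def forward(self, x):
        return ("other", x)

def compute_routed_output(experts, x):
    # Conditionally splice the kwarg in, exactly like the hunk above.
    return experts.forward(
        x,
        **({"alltoall_result_do_sum": False} if isinstance(experts, WideEPMoE) else {}),
    )

assert compute_routed_output(WideEPMoE(), 1) == ("wide_ep", False, 1)
assert compute_routed_output(OtherMoE(), 1) == ("other", 1)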

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 24 additions & 16 deletions
@@ -430,14 +430,15 @@ def is_post_quant_all2all_supported(self):
         return False

     def forward_chunk(
-        self,
-        x: Union[torch.Tensor, Fp4QuantizedTensor],
-        router_logits: torch.Tensor,
-        use_all_to_all: bool,
-        output_dtype: Optional[torch.dtype] = None,
-        all_rank_num_tokens: Optional[List[int]] = None,
-        use_dp_padding: Optional[bool] = None,
-        repeating_info: Tuple = (True, True),
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        use_all_to_all: bool,
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        use_dp_padding: Optional[bool] = None,
+        repeating_info: Tuple = (True, True),
+        alltoall_result_do_sum: bool = True,
     ) -> torch.Tensor:
         all_rank_max_num_tokens = max(all_rank_num_tokens)
         if isinstance(x, Fp4QuantizedTensor):
@@ -452,7 +453,7 @@ def forward_chunk(
             self.layer_load_balancer.start_wait_gpu_stage()

         if not use_all_to_all or self.alltoall_method_type != AlltoallMethodType.MNNVL:
-            pass
+            alltoall_result_do_sum = True

         weight_dtype = self.w3_w1_weight.dtype

@@ -719,7 +720,8 @@ def forward_chunk(
             if self.enable_dummy_allreduce:
                 self.dummy_allreduce()
             final_hidden_states = self.alltoall_combine(
-                final_hidden_states, alltoall_info, token_count)
+                final_hidden_states, alltoall_info, token_count,
+                alltoall_result_do_sum)
         elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
             final_hidden_states = self.unpad_tensors(
                 padded, final_hidden_states)
@@ -764,6 +766,7 @@ def forward_impl(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        alltoall_result_do_sum: bool = True,
         **kwargs,
     ) -> torch.Tensor:
         assert all_rank_num_tokens is not None
@@ -791,7 +794,8 @@ def forward_impl(
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
                 use_dp_padding=use_dp_padding,
-                repeating_info=(is_first_call, is_last_call))
+                repeating_info=(is_first_call, is_last_call),
+                alltoall_result_do_sum=alltoall_result_do_sum)
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 use_all_to_all,
@@ -849,7 +853,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                        all_rank_num_tokens=all_rank_num_tokens_list[
                            idx_chunk],
                        use_dp_padding=use_dp_padding,
-                       repeating_info=(is_first_call, is_last_call))
+                       repeating_info=(is_first_call, is_last_call),
+                       alltoall_result_do_sum=alltoall_result_do_sum)
                    if idx_chunk > 0:
                        outputs_list[-1] = self.reducescatter_or_allreduce(
                            outputs_list[-1],
@@ -865,7 +870,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                        all_rank_num_tokens=all_rank_num_tokens_list[
                            idx_chunk],
                        use_dp_padding=use_dp_padding,
-                       repeating_info=(is_first_call, is_last_call))
+                       repeating_info=(is_first_call, is_last_call),
+                       alltoall_result_do_sum=alltoall_result_do_sum)
                    with torch.cuda.stream(self.aux_stream):
                        outputs_list[-1] = self.reducescatter_or_allreduce(
                            outputs_list[-1],
@@ -879,7 +885,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
                    router_logits,
                    use_all_to_all,
                    all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk],
-                   repeating_info=(is_first_call, is_last_call))
+                   repeating_info=(is_first_call, is_last_call),
+                   alltoall_result_do_sum=alltoall_result_do_sum)

                outputs_list.append(outputs)
            if not use_all_to_all:
@@ -935,7 +942,8 @@ def alltoall_dispatch(self, x: torch.Tensor, x_sf: Optional[torch.Tensor],
        return x, x_sf, token_selected_slots, token_final_scales

    def alltoall_combine(self, final_hidden_states: torch.Tensor,
-                         alltoall_info: MoEAlltoallInfo, token_count: int):
+                         alltoall_info: MoEAlltoallInfo, token_count: int,
+                         alltoall_result_do_sum: bool):
        top_k = self.routing_method.experts_per_token
        if isinstance(final_hidden_states, list):
            final_hidden_states = final_hidden_states[0]
@@ -948,7 +956,7 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
            top_k=top_k,
            token_count=token_count,
            use_low_precision_combine=self.use_low_precision_combine,
-           do_reduce=False)
+           do_reduce=alltoall_result_do_sum)

        return final_hidden_states

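Taken together, the changes in this file thread a single boolean from the public entry point down to the combine step: forward_impl accepts alltoall_result_do_sum (default True), forwards it per chunk to forward_chunk, which forces it back to True whenever the MNNVL all-to-all path is not taken, and alltoall_combine finally hands it to the combine op as do_reduce. A simplified sketch of that control flow (names follow the diff; the behavior shown is inferred from it, and all kernel details are omitted):

from enum import Enum, auto

class AlltoallMethodType(Enum):
    MNNVL = auto()
    DeepEP = auto()

def forward_chunk(use_all_to_all: bool,
                  method: AlltoallMethodType,
                  alltoall_result_do_sum: bool = True) -> bool:
    # Outside the MNNVL all-to-all path the deferred-reduction option does not
    # apply, so fall back to reducing inside the combine (the safe default).
    if not use_all_to_all or method != AlltoallMethodType.MNNVL:
        alltoall_result_do_sum = True
    # In the real module this value is passed to alltoall_combine(...) and ends
    # up as the do_reduce argument of the combine op.
    return alltoall_result_do_sum

# The flag can only stay False on the MNNVL all-to-all path:
assert forward_chunk(True, AlltoallMethodType.MNNVL, False) is False
assert forward_chunk(True, AlltoallMethodType.DeepEP, False) is True
assert forward_chunk(False, AlltoallMethodType.MNNVL, False) is True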

tensorrt_llm/_torch/modules/fused_moe/interface.py

Lines changed: 2 additions & 0 deletions
@@ -242,6 +242,7 @@ def forward(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         if self.register_to_config and is_torch_compiling():
             hidden_states = x.fp4_tensor if isinstance(
@@ -274,6 +275,7 @@ def forward(
                 output_dtype=output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens,
                 use_dp_padding=use_dp_padding,
+                **kwargs,
             )

     @property
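The two **kwargs lines keep the shared MoE interface agnostic to backend-specific arguments: forward() accepts anything extra and forwards it verbatim to the implementation, which is how alltoall_result_do_sum reaches WideEPMoE without the base signature growing. A minimal sketch of the pass-through (hypothetical classes; only the forwarding pattern reflects the diff):

import torch

class MoE:
    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # The base class does not need to know about backend-specific flags.
        return self.forward_impl(x, **kwargs)

    def forward_impl(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

class WideEPMoE(MoE):
    def forward_impl(self, x: torch.Tensor, *,
                     alltoall_result_do_sum: bool = True, **kwargs) -> torch.Tensor:
        # The flag arrives here untouched by MoE.forward.
        self.last_do_sum = alltoall_result_do_sum
        return x

moe = WideEPMoE()
moe.forward(torch.ones(2), alltoall_result_do_sum=False)
assert moe.last_do_sum is False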

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ l0_dgx_b200:
   tests:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4[MNNVL]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 1 addition & 0 deletions
@@ -104,6 +104,7 @@ l0_dgx_h100:
   - unittest/_torch/multi_gpu_modeling/test_deepseek.py::test_deepseek_streaming[tp4-bf16-trtllm-deepseekv3_lite]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEP]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[DeepEPLowLatency]
+  - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall[MNNVL]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype0]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.VANILLA-dtype1]
   - unittest/_torch/modules/test_fused_moe.py::test_fused_moe_w4afp8[MoEWeightLoadingMode.W4A8_CUSTOM-dtype0]
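The entries added to both test-db lists are ordinary pytest node ids, so the newly enabled MNNVL variants can also be run directly on a suitable machine. A minimal sketch of a local invocation (assumes at least 4 GPUs and the repository's unittest dependencies; the exact command is illustrative and not part of the commit):

import pytest

# Run only the MNNVL parametrizations of the two fused-MoE all-to-all tests.
pytest.main([
    "tests/unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall",
    "tests/unittest/_torch/modules/test_fused_moe.py::test_fused_moe_alltoall_fp4",
    "-k", "MNNVL",
    "-q",
])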

tests/unittest/_torch/modules/test_fused_moe.py

Lines changed: 9 additions & 4 deletions
@@ -212,11 +212,14 @@ def per_rank_test_fused_moe_alltoall(job_id):
     weights = {}
     for expert_id in range(NUM_EXPERTS):
         w1_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w2_weight = torch.empty((HIDDEN_SIZE, INTERMEDIATE_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         w3_weight = torch.empty((INTERMEDIATE_SIZE, HIDDEN_SIZE),
-                                dtype=dtype)
+                                dtype=dtype,
+                                device="cuda")
         torch.nn.init.xavier_uniform_(w1_weight)
         torch.nn.init.xavier_uniform_(w2_weight)
         torch.nn.init.xavier_uniform_(w3_weight)
@@ -292,7 +295,6 @@ def per_rank_test_fused_moe_alltoall(job_id):
        assert r is None


-@pytest.mark.skip(reason="https://nvbugs/5467531")
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="needs 4 GPUs to run this test")
 @pytest.mark.parametrize("alltoall_method_type", [
@@ -302,6 +304,9 @@ def per_rank_test_fused_moe_alltoall(job_id):
                         ids=lambda s: s.name)
 def test_fused_moe_alltoall_fp4(alltoall_method_type):

+    if alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+        pytest.skip("Skipped due to https://nvbugs/5467531")
+
     world_size = 4
     dtype = torch.bfloat16
     HIDDEN_SIZE = 2560
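On the test side, the commit replaces a test-wide @pytest.mark.skip with a runtime pytest.skip inside the body, so only the DeepEPLowLatency parametrization remains blocked by the referenced bug while the MNNVL and DeepEP variants now run in pre-merge. A stripped-down sketch of the pattern (the real parameters and assertions are in the file above):

import pytest

METHODS = ["MNNVL", "DeepEP", "DeepEPLowLatency"]

@pytest.mark.parametrize("alltoall_method_type", METHODS)
def test_fused_moe_alltoall_fp4(alltoall_method_type):
    if alltoall_method_type == "DeepEPLowLatency":
        # Only this variant stays skipped, tracked by https://nvbugs/5467531.
        pytest.skip("Skipped due to https://nvbugs/5467531")
    # Placeholder for the real multi-GPU checks.
    assert alltoall_method_type in METHODS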
