
Commit d3346f6

bugfix: fix synchronization logic error in tests/comm/test_trtllm_alltoall.py (#1841)
## 📌 Description

Fix the incorrect synchronization logic in the unit test that caused it to hang.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
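For context on the fix: as the diff below shows, the old code only made the side stream wait on the current stream *after* the work had already been enqueued, and nothing made the current stream wait for the side stream in return. The fix orders the side stream behind the current stream before the work, then makes the current stream wait on the side stream afterwards. A minimal sketch of that pattern (the `work()` callable and variable names are placeholders, not the actual test code):

```python
import torch

def run_on_side_stream(work):
    """Sketch of the corrected side-stream pattern (placeholder names)."""
    side = torch.cuda.Stream()
    cur = torch.cuda.current_stream()

    # Order the side stream behind everything already queued on the
    # current stream, so inputs produced there are ready to be read.
    side.wait_stream(cur)
    with torch.cuda.stream(side):
        out = work()

    # Then make the current stream wait for the side stream, so later
    # work (and any reads of `out`) is ordered after it completes.
    cur.wait_stream(side)
    return out
```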
1 parent 5f78377 commit d3346f6

File tree

1 file changed: +7 −2 lines changed


tests/comm/test_trtllm_alltoall.py

Lines changed: 7 additions & 2 deletions
@@ -702,6 +702,8 @@ def generate_references():
     )

     stream = torch.cuda.Stream()
+    cur_stream = torch.cuda.current_stream()
+    stream.wait_stream(cur_stream)
     with torch.cuda.stream(stream):
         tllm_alltoall.moe_prepare(
             expert_ids_all_ranks[0],
@@ -715,7 +717,7 @@ def generate_references():
             slot_count,
             top_k,
         )
-    stream.wait_stream(torch.cuda.current_stream())
+    cur_stream.wait_stream(stream)

     # Make torch alloc tensor to avoid cuda sync
     prepared_local_experts = []
@@ -776,8 +778,11 @@ def generate_references():
         )

     # do prepare in parallel
+    cur_stream = torch.cuda.current_stream()
     for rank in range(ep_size):
-        with torch.cuda.stream(cuda_streams_all_ranks[rank]):
+        s = cuda_streams_all_ranks[rank]
+        s.wait_stream(cur_stream)
+        with torch.cuda.stream(s):
             if rank == ep_rank:
                 (
                     prepared_local_experts,
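The last hunk applies the same rule to the per-rank streams used for the parallel prepare: each worker stream first waits on the current stream before any work is enqueued on it. A rough sketch of that shape (the stream list, `ep_size`, and `do_prepare` are illustrative placeholders, not the test's fixture code):

```python
import torch

def prepare_in_parallel(ep_size, do_prepare):
    """Sketch of the per-rank stream pattern (placeholder names)."""
    streams = [torch.cuda.Stream() for _ in range(ep_size)]
    cur_stream = torch.cuda.current_stream()

    results = []
    for rank in range(ep_size):
        s = streams[rank]
        # Each worker stream waits on the current stream so that inputs
        # allocated or written there are visible before its work runs.
        s.wait_stream(cur_stream)
        with torch.cuda.stream(s):
            results.append(do_prepare(rank))

    # Before consuming the results, the current stream should in turn
    # wait for every worker stream.
    for s in streams:
        cur_stream.wait_stream(s)
    return results
```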
