
Commit 2a61472

feat: auto deduce use_oneshot from token_num in all-reduce (#1365)
## 🚀 Pull Request Checklist

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent b9f218f commit 2a61472

File tree: 2 files changed, +11 −9 lines


flashinfer/comm/trtllm_ar.py (9 additions, 7 deletions)

```diff
@@ -775,10 +775,10 @@ def trtllm_allreduce_fusion(
     hidden_dim: int,
     workspace_ptrs: torch.Tensor,
     launch_with_pdl: bool,
-    use_oneshot: bool,
     trigger_completion_at_end: bool,
     fp32_acc: bool,
     pattern_code: AllReduceFusionPattern,
+    use_oneshot: Optional[bool],
     allreduce_out: Optional[torch.Tensor],
     residual_in: Optional[torch.Tensor],
     residual_out: Optional[torch.Tensor],
@@ -815,14 +815,16 @@ def trtllm_allreduce_fusion(
     - layout_code: the layout code.

     Note:
-        Regarding the `use_oneshot` parameter:
+        Regarding the `use_oneshot` parameter, you could force to use the one-shot strategy based on your use case.
+        Otherwise, it would be enabled if token_num is less than the one-shot max token number (currently 128) for min-latency mode.
+    """

-        It should only be enabled when:
-        (1) Force to use the one-shot strategy based on your use case.
-        (2) In min-latency mode, the sequence length is less than the one-shot max token number (currently 128).
+    if use_oneshot is None:
+        logging.warning(
+            f"use_oneshot is not specified. It would be enabled if token_num is less than the one-shot max token number (currently 128) for min-latency mode."
+        )
+        use_oneshot = token_num <= 128

-        Otherwise, it should be disabled (as False).
-    """
     if not use_oneshot:
         assert token_num > world_size, "sequence length should be larger than tp_size"
```
tests/test_trtllm_allreduce_fusion.py (2 additions, 2 deletions)

```diff
@@ -49,7 +49,7 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini
         comm.FP4QuantizationSFLayout.SWIZZLED,
     ]
     launch_with_pdls = [True, False]
-    use_oneshots = [True, False]
+    use_oneshots = [True, False, None]
     trigger_completion_at_ends = [True, False]
     fp32_accs = [True, False]

@@ -315,7 +315,7 @@ def multi_process_parallel(
     ), f"Process {i} failed with exit code {procs[i].exitcode}"


-@pytest.mark.parametrize("world_size", [2, 4])
+@pytest.mark.parametrize("world_size", [2, 4, 8])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize("hidden_dim", [1024, 2048, 4096, 7168, 8192])
 def test_trtllm_allreduce_fusion(world_size, dtype, hidden_dim):
```
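Adding `None` to `use_oneshots` grows the per-test sweep. Assuming the worker iterates the cartesian product of the option lists shown in the diff (the exact loop structure inside `_run_correctness_worker` is not visible here), the grid per `(world_size, dtype, hidden_dim)` combination can be enumerated like this:

```python
import itertools

# Option lists taken from the updated test file.
use_oneshots = [True, False, None]  # None now exercises auto-deduction
launch_with_pdls = [True, False]
trigger_completion_at_ends = [True, False]
fp32_accs = [True, False]

# 3 * 2 * 2 * 2 = 24 combinations per parametrized test case.
grid = list(
    itertools.product(
        use_oneshots, launch_with_pdls, trigger_completion_at_ends, fp32_accs
    )
)
```

With the new `world_size=8` axis, the pytest parametrization itself also grows from 2×2×5 to 3×2×5 outer cases.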
