
Commit 6fb5105

update allreduce to match trtllm (#1507)
## 📌 Description

Updated the allreduce launch config logic to match trtllm (TensorRT-LLM). On llama3 with concurrency=128 and tp2, in the gen-only phase, the kernel time dropped from ~26.8 us to ~9.8 us.

## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes
1 parent 6bfb43a commit 6fb5105

1 file changed: +5 -1 lines changed

include/flashinfer/comm/trtllm_allreduce_fusion.cuh

Lines changed: 5 additions & 1 deletion
@@ -1364,12 +1364,16 @@ cudaError_t allreduce_fusion_kernel_launcher(AllReduceFusionParams<T> const& par
     threads_per_block *= 2;
     cluster_size /= 2;
   }
+  int sm_count = get_sm_count();
+  while (cluster_num * cluster_size > sm_count && cluster_size > 1 && threads_per_block <= 512) {
+    threads_per_block *= 2;
+    cluster_size /= 2;
+  }
   FLASHINFER_CHECK(oneshot || threads_per_block >= params.nranks,
                    "not oneshot, or threads_per_block < nranks");
   int block_size = threads_per_block;
   FLASHINFER_CHECK(block_size <= 1024 && cluster_size > 0,
                    "block_size > 1024 or cluster_size <= 0");
-  int sm_count = get_sm_count();
   int grid_size = (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
   cudaLaunchConfig_t cfg;
   cudaLaunchAttribute attribute[2];
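
To see what the added loop does to the launch shape, here is a minimal host-only sketch that isolates it, together with the `grid_size` expression that follows it in the diff. The `LaunchShape` struct and the `fit_clusters_to_sms` function are hypothetical names introduced for illustration and are not part of FlashInfer. The SM count is passed in as a plain integer rather than queried through `get_sm_count()`, and the values used in the note below are assumed inputs, not measurements from this commit.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical container for the three launch values the diff manipulates.
struct LaunchShape {
  int threads_per_block;
  int cluster_size;
  int grid_size;
};

// Illustrative helper (not FlashInfer API): applies the loop added in this
// commit, then the grid_size expression that follows it in the launcher.
inline LaunchShape fit_clusters_to_sms(int cluster_num, int cluster_size,
                                       int threads_per_block, int sm_count) {
  assert(cluster_num > 0 && cluster_size > 0);
  assert(threads_per_block > 0 && sm_count > 0);

  // While the requested block count exceeds the SM count, halve the cluster
  // and double the block. The product threads_per_block * cluster_size is
  // unchanged by each step. The `<= 512` guard keeps the doubled block size
  // at or below 1024, which the later FLASHINFER_CHECK requires.
  while (cluster_num * cluster_size > sm_count && cluster_size > 1 &&
         threads_per_block <= 512) {
    threads_per_block *= 2;
    cluster_size /= 2;
  }

  // Same expression as in the diff: cap at sm_count, then round down to a
  // whole number of clusters.
  int grid_size =
      (std::min(sm_count, cluster_num * cluster_size) / cluster_size) * cluster_size;
  return {threads_per_block, cluster_size, grid_size};
}
```

With assumed inputs of `cluster_num = 128`, `cluster_size = 8`, `threads_per_block = 128`, and `sm_count = 132`, the loop runs three times and returns 1024 threads per block, a cluster size of 1, and a grid of 128 blocks. Without the loop, the same inputs give a grid of 128 blocks grouped into only 16 clusters of 8. The sketch says nothing about why the reshaped launch is faster; it only shows how the shape changes.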
