
Commit 1a85c43

bugfix: fix the register overflow issue for topk renorm kernels on blackwell (#1597)
<!-- .github/pull_request_template.md -->

## 📌 Description

On Blackwell, launching `TopKRenormProbKernel` fails with the error `too many resources requested for launch`, which here means the kernel needs more registers than the launch allows (another issue that only occurs on Blackwell). This PR fixes the failure by lowering the loop unroll factor, which reduces the number of registers the kernel uses.

This PR builds on #1596; together, the two PRs fix the failing unit tests on Blackwell.

## 🔍 Related Issues

<!-- Link any related issues here -->

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc @bkryu
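The launch failure described above comes down to simple arithmetic: a block needs roughly registers-per-thread times threads-per-block registers, and the launch is rejected when that product exceeds the per-block limit. The host-side sketch below illustrates this check. The function names are hypothetical, and the numbers (1024 threads per block, a 65536-register limit, 64 versus 72 registers per thread) are illustrative assumptions; the PR does not state the actual register counts.

```cpp
#include <cstdint>

// Simplified launch check (hypothetical helper, not part of FlashInfer):
// a block needs about regs_per_thread * block_threads registers, and the
// launch fails if that exceeds the per-block register limit.
// Real hardware allocates registers per warp in rounded-up chunks,
// which this sketch ignores.
constexpr bool fits_register_budget(uint32_t regs_per_thread,
                                    uint32_t block_threads,
                                    uint32_t regs_per_block) {
  return static_cast<uint64_t>(regs_per_thread) * block_threads <= regs_per_block;
}

// Largest per-thread register count that still launches under the same
// simplified model.
constexpr uint32_t max_regs_per_thread(uint32_t block_threads,
                                       uint32_t regs_per_block) {
  return regs_per_block / block_threads;
}

// Illustrative numbers only: with 1024 threads and a 65536-register limit,
// 64 registers per thread fit, while 72 do not.
static_assert(fits_register_budget(64, 1024, 65536), "64 regs/thread should fit");
static_assert(!fits_register_budget(72, 1024, 65536), "72 regs/thread should not fit");
```

On a real device, the inputs could be read rather than assumed: `cudaFuncGetAttributes` reports a kernel's `numRegs`, `cudaGetDeviceProperties` reports `regsPerBlock`, and compiling with `-Xptxas -v` prints per-kernel register usage.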
1 parent 52f850c commit 1a85c43

1 file changed: +1 -1 lines changed

include/flashinfer/sampling.cuh

Lines changed: 1 addition & 1 deletion
@@ -1891,7 +1891,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
       ValueCount<float> aggregate_gt_pivot_0{0, 0}, aggregate_gt_pivot_1{0, 0};
       min_gt_low = high;
       max_le_high = low;
-#pragma unroll 2
+#pragma unroll 1
       for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
         probs_vec.fill(0);
         if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
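The unroll factor only changes how many iterations the compiler schedules together (and therefore how many values are typically live in registers at once); it does not change which elements the loop touches. The host-side sketch below emulates the block-strided sweep from the hunk to show that each of the `d` elements is visited exactly once. `count_visits` is a hypothetical helper, `ceil_div` is assumed to be a round-up integer division, and the inner `base + j < d` clamp is added for the sketch rather than taken from the kernel.

```cpp
#include <cstdint>
#include <vector>

// Round-up integer division; assumed to match the kernel's ceil_div helper.
constexpr uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Hypothetical host-side emulation of the loop in the hunk: returns how many
// times each of the d elements is visited by the block.
std::vector<uint32_t> count_visits(uint32_t d, uint32_t block_threads,
                                   uint32_t vec_size) {
  std::vector<uint32_t> visits(d, 0);
  for (uint32_t i = 0; i < ceil_div(d, block_threads * vec_size); ++i) {
    // On the GPU these threads run in parallel; here they are serialized.
    for (uint32_t tx = 0; tx < block_threads; ++tx) {
      const uint32_t base = (i * block_threads + tx) * vec_size;
      if (base < d) {  // same guard as in the kernel
        // Each thread handles vec_size consecutive elements per iteration.
        // The base + j < d clamp is an addition for this sketch.
        for (uint32_t j = 0; j < vec_size && base + j < d; ++j) {
          ++visits[base + j];
        }
      }
    }
  }
  return visits;
}
```

For example, `count_visits(1000, 32, 4)` runs `ceil_div(1000, 128) = 8` outer iterations, and the guard keeps threads past the end of the array idle in the last one.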
