
Commit cba02c4

perf: add fast path to TopPRenormProbKernel for top_p >= 1.0, significantly boosting SGLang workloads (#1483)
## 📌 Description

### Summary

When deploying **DeepSeek-R1** with **DP+EP+MTP** in SGLang and sending requests with `temperature > 0.0`, we observed a **~15% performance degradation**. Profiling revealed that `TopPRenormProbKernel` took about **13 ms**, whereas disabling MTP switches to `TopKTopPSamplingFromProbKernel`, which costs only ~**100 μs**. By default, SGLang sets `top_p = 1.0`.

---

### Root Cause

Code analysis shows:

* **`TopKTopPSamplingFromProbKernel`** has an early stop when `top_k >= vocab_size` and `top_p >= 1.0`:

  ```cu
  if (aggregate_gt_pivot_0.count < k && aggregate_gt_pivot_0.value < p) {
    // case 1: pivot_0 accepted
    break;
  }
  ```

* **`TopKRenormProbKernel`** has a fast path when `top_k >= vocab_size`:

  ```cu
  if (k < d) {
    // do search
  }
  // Write probs directly to output.
  ```

* **`TopPRenormProbKernel`** performs a ternary search until `min_gt_low == max_le_high`, even when `top_p >= 1.0`. This makes the search both **time-consuming** and **unnecessary** in such cases.

---

### Optimization

When `top_p >= 1.0`, the ternary search can be skipped entirely; only normalization is needed (to handle the rare case where `sum(probs) != 1.0`). This matches the fast-path logic in the other kernels.
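For reference, the fast path's intended semantics (and the general top-p renormalization it short-circuits) can be sketched in NumPy. This is an illustrative host-side reference; the function name and structure are ours, not the kernel's:

```python
import numpy as np

def top_p_renorm_probs_reference(probs: np.ndarray, p: float) -> np.ndarray:
    """Host-side sketch of top-p renormalization (illustrative, not the CUDA kernel)."""
    if p >= 1.0:
        # Fast path: no search needed -- every token is kept, so we only
        # renormalize to cover the rare case where sum(probs) != 1.0.
        return probs / np.maximum(probs.sum(axis=-1, keepdims=True), 1e-8)
    # General path: keep the smallest prefix of tokens (by descending prob)
    # whose cumulative mass reaches p, zero out the rest, then renormalize.
    order = np.argsort(-probs, axis=-1)
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cumsum = np.cumsum(sorted_probs, axis=-1)
    keep_sorted = (cumsum - sorted_probs) < p  # always keeps the top token
    keep = np.zeros_like(keep_sorted)
    np.put_along_axis(keep, order, keep_sorted, axis=-1)
    masked = probs * keep
    return masked / masked.sum(axis=-1, keepdims=True)
```

With `p = 1.0` the result reduces to `probs / probs.sum(...)`, which is exactly the normalization the added kernel fast path computes.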
---

### Performance Results

**Before Optimization**

```
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), p: 1.0, duration: 1642.94 us, effective bandwidth: 0.63 GB/s
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=5), p: 1.0, duration: 4790.69 us, effective bandwidth: 0.21 GB/s
vocab_size: 128512, batch_size: 1, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 5107.44 us, effective bandwidth: 0.20 GB/s
vocab_size: 128512, batch_size: 1, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 2114.08 us, effective bandwidth: 0.49 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), p: 1.0, duration: 1960.58 us, effective bandwidth: 8.39 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=5), p: 1.0, duration: 6848.14 us, effective bandwidth: 2.40 GB/s
vocab_size: 128512, batch_size: 16, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 12505.41 us, effective bandwidth: 1.32 GB/s
vocab_size: 128512, batch_size: 16, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 2744.32 us, effective bandwidth: 5.99 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), p: 1.0, duration: 2116.48 us, effective bandwidth: 15.54 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=5), p: 1.0, duration: 6851.92 us, effective bandwidth: 4.80 GB/s
vocab_size: 128512, batch_size: 32, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 12682.69 us, effective bandwidth: 2.59 GB/s
vocab_size: 128512, batch_size: 32, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 2905.60 us, effective bandwidth: 11.32 GB/s
vocab_size: 128512, batch_size: 64, distrib: normal_distribution(std=1), p: 1.0, duration: 2905.23 us, effective bandwidth: 22.65 GB/s
vocab_size: 128512, batch_size: 64, distrib: normal_distribution(std=5), p: 1.0, duration: 7010.11 us, effective bandwidth: 9.39 GB/s
vocab_size: 128512, batch_size: 64, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 12832.13 us, effective bandwidth: 5.13 GB/s
vocab_size: 128512, batch_size: 64, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 3066.14 us, effective bandwidth: 21.46 GB/s
vocab_size: 128512, batch_size: 128, distrib: normal_distribution(std=1), p: 1.0, duration: 3924.93 us, effective bandwidth: 33.53 GB/s
vocab_size: 128512, batch_size: 128, distrib: normal_distribution(std=5), p: 1.0, duration: 11966.38 us, effective bandwidth: 11.00 GB/s
vocab_size: 128512, batch_size: 128, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 17618.42 us, effective bandwidth: 7.47 GB/s
vocab_size: 128512, batch_size: 128, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 5036.77 us, effective bandwidth: 26.13 GB/s
vocab_size: 128512, batch_size: 256, distrib: normal_distribution(std=1), p: 1.0, duration: 7226.40 us, effective bandwidth: 36.42 GB/s
vocab_size: 128512, batch_size: 256, distrib: normal_distribution(std=5), p: 1.0, duration: 21546.50 us, effective bandwidth: 12.22 GB/s
vocab_size: 128512, batch_size: 256, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 29879.33 us, effective bandwidth: 8.81 GB/s
vocab_size: 128512, batch_size: 256, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 9606.72 us, effective bandwidth: 27.40 GB/s
vocab_size: 128512, batch_size: 512, distrib: normal_distribution(std=1), p: 1.0, duration: 12952.24 us, effective bandwidth: 40.64 GB/s
vocab_size: 128512, batch_size: 512, distrib: normal_distribution(std=5), p: 1.0, duration: 37822.32 us, effective bandwidth: 13.92 GB/s
vocab_size: 128512, batch_size: 512, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 49733.36 us, effective bandwidth: 10.58 GB/s
vocab_size: 128512, batch_size: 512, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 17081.22 us, effective bandwidth: 30.82 GB/s
```

**After Optimization**

```
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=1), p: 1.0, duration: 21.60 us, effective bandwidth: 47.60 GB/s
vocab_size: 128512, batch_size: 1, distrib: normal_distribution(std=5), p: 1.0, duration: 21.70 us, effective bandwidth: 47.39 GB/s
vocab_size: 128512, batch_size: 1, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 21.86 us, effective bandwidth: 47.04 GB/s
vocab_size: 128512, batch_size: 1, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 21.70 us, effective bandwidth: 47.39 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=1), p: 1.0, duration: 23.39 us, effective bandwidth: 703.21 GB/s
vocab_size: 128512, batch_size: 16, distrib: normal_distribution(std=5), p: 1.0, duration: 23.14 us, effective bandwidth: 710.99 GB/s
vocab_size: 128512, batch_size: 16, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 23.39 us, effective bandwidth: 703.21 GB/s
vocab_size: 128512, batch_size: 16, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 23.14 us, effective bandwidth: 710.99 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=1), p: 1.0, duration: 25.54 us, effective bandwidth: 1288.34 GB/s
vocab_size: 128512, batch_size: 32, distrib: normal_distribution(std=5), p: 1.0, duration: 25.50 us, effective bandwidth: 1289.96 GB/s
vocab_size: 128512, batch_size: 32, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 25.73 us, effective bandwidth: 1278.73 GB/s
vocab_size: 128512, batch_size: 32, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 25.47 us, effective bandwidth: 1291.58 GB/s
vocab_size: 128512, batch_size: 64, distrib: normal_distribution(std=1), p: 1.0, duration: 42.53 us, effective bandwidth: 1547.17 GB/s
vocab_size: 128512, batch_size: 64, distrib: normal_distribution(std=5), p: 1.0, duration: 42.30 us, effective bandwidth: 1555.36 GB/s
vocab_size: 128512, batch_size: 64, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 42.75 us, effective bandwidth: 1539.07 GB/s
vocab_size: 128512, batch_size: 64, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 42.43 us, effective bandwidth: 1550.67 GB/s
vocab_size: 128512, batch_size: 128, distrib: normal_distribution(std=1), p: 1.0, duration: 73.92 us, effective bandwidth: 1780.25 GB/s
vocab_size: 128512, batch_size: 128, distrib: normal_distribution(std=5), p: 1.0, duration: 74.27 us, effective bandwidth: 1771.82 GB/s
vocab_size: 128512, batch_size: 128, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 74.37 us, effective bandwidth: 1769.53 GB/s
vocab_size: 128512, batch_size: 128, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 73.76 us, effective bandwidth: 1784.11 GB/s
vocab_size: 128512, batch_size: 256, distrib: normal_distribution(std=1), p: 1.0, duration: 141.92 us, effective bandwidth: 1854.51 GB/s
vocab_size: 128512, batch_size: 256, distrib: normal_distribution(std=5), p: 1.0, duration: 141.66 us, effective bandwidth: 1857.86 GB/s
vocab_size: 128512, batch_size: 256, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 142.05 us, effective bandwidth: 1852.84 GB/s
vocab_size: 128512, batch_size: 256, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 141.82 us, effective bandwidth: 1855.77 GB/s
vocab_size: 128512, batch_size: 512, distrib: normal_distribution(std=1), p: 1.0, duration: 259.39 us, effective bandwidth: 2029.30 GB/s
vocab_size: 128512, batch_size: 512, distrib: normal_distribution(std=5), p: 1.0, duration: 257.34 us, effective bandwidth: 2045.45 GB/s
vocab_size: 128512, batch_size: 512, distrib: gumbel_distribution(beta=0.1), p: 1.0, duration: 259.49 us, effective bandwidth: 2028.55 GB/s
vocab_size: 128512, batch_size: 512, distrib: gumbel_distribution(beta=1), p: 1.0, duration: 258.24 us, effective bandwidth: 2038.36 GB/s
```

---

### Impact

This change brings a **~200× speedup** for `top_p >= 1.0` in `TopPRenormProbKernel`, eliminating the unnecessary ternary search and aligning its behavior with the other sampling kernels.
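The headline number is an average; per-configuration speedups from the tables above vary widely with the input distribution. The three pairs below are copied from the tables, and the rest follow the same arithmetic:

```python
# (distribution, batch_size) -> (duration before, duration after), in microseconds,
# copied from the benchmark tables above.
before_after_us = {
    ("normal_distribution(std=1)", 1): (1642.94, 21.60),
    ("gumbel_distribution(beta=0.1)", 16): (12505.41, 23.39),
    ("normal_distribution(std=1)", 512): (12952.24, 259.39),
}
for (distrib, bs), (before, after) in before_after_us.items():
    print(f"{distrib}, batch_size={bs}: {before / after:.0f}x speedup")
# prints 76x, 535x, and 50x respectively
```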
## 🔍 Related Issues

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

```
$ pytest tests/test_sampling.py::test_top_p_renorm_probs -v -s
=============================================================== test session starts ===============================================================
platform linux -- Python 3.10.16, pytest-8.4.1, pluggy-1.5.0 -- /opt/conda/bin/python3
cachedir: .pytest_cache
rootdir: /home/local/code/flashinfer
configfile: pyproject.toml
plugins: anyio-4.10.0
collected 36 items

tests/test_sampling.py::test_top_p_renorm_probs[0.1-111-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-111-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-111-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-32000-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-32000-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-32000-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-128256-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-128256-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.1-128256-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-111-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-111-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-111-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-32000-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-32000-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-32000-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-128256-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-128256-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.5-128256-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-111-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-111-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-111-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-32000-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-32000-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-32000-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-128256-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-128256-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[0.9-128256-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-111-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-111-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-111-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-32000-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-32000-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-32000-989] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-128256-1] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-128256-99] PASSED
tests/test_sampling.py::test_top_p_renorm_probs[1.0-128256-989] PASSED
================================================================ warnings summary =================================================================
tests/test_sampling.py::test_top_p_renorm_probs[0.1-111-1]
  /opt/conda/lib/python3.10/site-packages/torch/utils/cpp_extension.py:2356: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
    warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================== 36 passed, 1 warning in 0.99s ==========================================================
```

## Reviewer Notes
1 parent 74b8785 commit cba02c4

File tree

3 files changed: +56 −3 lines changed


benchmarks/bench_renorm.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,7 +36,7 @@ def main():
         gumbel_distribution(0.1),
         gumbel_distribution(1),
     ]:
-        for p in [0.1, 0.5, 0.9]:
+        for p in [0.1, 0.5, 0.9, 1.0]:
             logits = distrib((batch_size, vocab_size), device="cuda")
             probs = torch.softmax(logits, dim=-1)
             measurements = bench_gpu_time(
```
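As a sanity check on the benchmark's reported numbers, the effective bandwidth is consistent with one fp32 read of the probability matrix plus one fp32 write of the renormalized output. This is an assumed traffic model, not code from `bench_renorm.py`:

```python
def effective_bandwidth_gbs(batch_size: int, vocab_size: int,
                            duration_us: float, dtype_bytes: int = 4) -> float:
    # Assumes the kernel reads probs once and writes renormed probs once.
    bytes_moved = 2 * batch_size * vocab_size * dtype_bytes
    return bytes_moved / (duration_us * 1e-6) / 1e9

# batch_size=1 after the optimization: 21.60 us -> ~47.60 GB/s, matching the table
print(f"{effective_bandwidth_gbs(1, 128512, 21.60):.2f} GB/s")
```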

include/flashinfer/sampling.cuh

Lines changed: 54 additions & 1 deletion
```diff
@@ -1539,6 +1539,7 @@ struct RenormTempStorage {
   struct {
     float max_val;
     float min_val;
+    float row_sum;
     union {
       struct {
         float values[2];
@@ -1565,9 +1566,61 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
   uint8_t smem_renorm[];
   auto& temp_storage =
       reinterpret_cast<RenormTempStorage<BLOCK_THREADS, REDUCE_ALGO>&>(smem_renorm);
-  temp_storage.max_val = 0;
   vec_t<float, VEC_SIZE> probs_vec;

+  // Fast path: when p >= 1.0 (e.g., p == 1.0), perform a simple sum and normalization
+  if (p >= 1.0f) {
+    // Stage A: per-thread float accumulation over assigned lanes (vectorized)
+    float thread_sum = 0.0f;
+    const uint32_t num_iters = ceil_div(d, BLOCK_THREADS * VEC_SIZE);
+    for (uint32_t i = 0; i < num_iters; ++i) {
+      probs_vec.fill(0.0f);
+      const uint32_t base_idx = (i * BLOCK_THREADS + tx) * VEC_SIZE;
+      if (base_idx < d) {
+        probs_vec.cast_load(probs + row_idx * d + base_idx);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        const uint32_t idx = base_idx + j;
+        if (idx < d) thread_sum += probs_vec[j];
+      }
+    }
+
+    // Block reduce (float)
+    float row_sum =
+        BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
+            .Sum(thread_sum);
+    // Broadcast via shared memory
+    if (tx == 0) temp_storage.row_sum = row_sum;
+    __syncthreads();
+    row_sum = temp_storage.row_sum;
+
+    // Guard against a zero sum
+    const float denom = (row_sum <= 1e-8f) ? 1.0f : row_sum;
+    const float normalizer = math::ptx_rcp(denom);
+
+    // Stage B: normalize and store
+    for (uint32_t i = 0; i < num_iters; ++i) {
+      probs_vec.fill(0.0f);
+      const uint32_t base_idx = (i * BLOCK_THREADS + tx) * VEC_SIZE;
+      if (base_idx < d) {
+        probs_vec.cast_load(probs + row_idx * d + base_idx);
+      }
+#pragma unroll
+      for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        const uint32_t idx = base_idx + j;
+        float v = probs_vec[j];
+        probs_vec[j] = (idx < d) ? (v * normalizer) : 0.0f;
+      }
+      if (base_idx < d) {
+        probs_vec.cast_store(renormed_prob + row_idx * d + base_idx);
+      }
+    }
+    return;  // Exit after fast-path processing
+  }
+
+  // Original top-p renormalization logic
+  temp_storage.max_val = 0;
   float max_val = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM,
                               RenormTempStorage<BLOCK_THREADS, REDUCE_ALGORITHM>>(probs, row_idx, d,
                                                                                   temp_storage);
```
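The two-stage structure of the fast path (a block-wide sum, then a second pass that scales and stores) can be mimicked on the host for sanity-checking. A sketch in NumPy, with the chunked loop standing in for the kernel's vectorized loads (names and chunk size are ours):

```python
import numpy as np

def fast_path_host(probs_row: np.ndarray, vec_size: int = 4) -> np.ndarray:
    """Host-side mimic of the p >= 1.0 fast path (illustrative only)."""
    d = probs_row.shape[0]
    # Stage A: accumulate the row sum in fixed-size chunks, mirroring the
    # per-thread vectorized accumulation followed by the block reduce.
    row_sum = 0.0
    for base in range(0, d, vec_size):
        row_sum += float(probs_row[base:base + vec_size].sum())
    # Guard against a zero sum, as the kernel's denom clamp does.
    denom = row_sum if row_sum > 1e-8 else 1.0
    normalizer = 1.0 / denom
    # Stage B: re-read each chunk, scale by the reciprocal, and store.
    out = np.empty_like(probs_row)
    for base in range(0, d, vec_size):
        out[base:base + vec_size] = probs_row[base:base + vec_size] * normalizer
    return out
```

Note that an all-zero row passes through unchanged rather than dividing by zero, matching the kernel's `denom` guard.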

tests/test_sampling.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -388,7 +388,7 @@ def test_top_k_top_p_joint_sampling_from_logits(batch_size, vocab_size, p):


 @pytest.mark.parametrize("batch_size", [1, 99, 989])
 @pytest.mark.parametrize("vocab_size", [111, 32000, 128256])
-@pytest.mark.parametrize("p", [0.1, 0.5, 0.9])
+@pytest.mark.parametrize("p", [0.1, 0.5, 0.9, 1.0])
 def test_top_p_renorm_probs(batch_size, vocab_size, p):
     torch.manual_seed(42)
     pre_norm_prob = torch.rand(batch_size, vocab_size, device="cuda:0")
```
