---
layout: post
title: "Sorting-Free GPU Kernels for LLM Sampling"
date: 2025-03-10
comments: true
author: Shanli Xing (UW), Zihao Ye (UW, NVIDIA), Bohan Hou (CMU), Luis Ceze (UW, NVIDIA), Tianqi Chen (CMU, NVIDIA)
---
## Background
While the algorithm is elegant in theory, implementing it efficiently in a GPU kernel requires careful attention to detail, particularly in the token selection logic of inverse transform sampling. One key challenge lies in the parallel prefix-sum operation used to locate sampled tokens. Due to the non-associative nature of floating-point addition, a parallel prefix sum **cannot guarantee monotonic outputs** even with non-negative inputs. This can lead to invalid token generation if not handled properly, so special care must be taken to ensure numerical stability and correctness in the sampling implementation (we made plenty of mistakes before getting it right).
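To make the pitfall concrete, here is a minimal NumPy sketch of inverse transform sampling with two defensive guards. It illustrates the idea rather than FlashInfer's CUDA kernel: `np.cumsum` is sequential, whereas a GPU scan runs in parallel (which is where monotonicity can break), and the guards shown are generic safeguards, not the exact ones in `sampling.cuh`.

```python
import numpy as np

def inverse_transform_sample(probs: np.ndarray, u: float) -> int:
    # A GPU kernel computes this prefix sum with a parallel block-wide
    # scan; np.cumsum is sequential and merely stands in for it here.
    cdf = np.cumsum(probs, dtype=np.float32)
    # Guard 1: a parallel float32 scan can produce locally decreasing
    # values, so force the CDF to be non-decreasing before searching.
    cdf = np.maximum.accumulate(cdf)
    # Select the first index whose CDF value exceeds the uniform draw u.
    idx = int(np.searchsorted(cdf, np.float32(u), side="right"))
    # Guard 2: rounding can leave cdf[-1] slightly below 1.0, letting u
    # overshoot the table; clamp to the last valid token id.
    return min(idx, len(probs) - 1)

rng = np.random.default_rng(0)
logits = rng.standard_normal(32_000).astype(np.float32)
probs = np.exp(logits - logits.max())
probs /= probs.sum()
token = inverse_transform_sample(probs, rng.random())
```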
For a detailed look at our implementation and how we address these challenges, you can explore the [source code](https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/sampling.cuh). FlashInfer also offers a set of APIs for probability cutoff and renormalization, such as [top_p_renorm_probs](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_p_renorm_probs.html#flashinfer.sampling.top_p_renorm_probs) and [top_k_renorm_probs](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_k_renorm_probs.html#flashinfer.sampling.top_k_renorm_probs), which let developers compose multiple sampling filters into custom strategies, as sketched below.
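As an illustration, the following sketch chains the two renormalization APIs before drawing a token. It assumes a CUDA device and a recent FlashInfer build; argument names and the final sampling entry point (`sampling_from_probs` here) may differ across versions, so treat it as a sketch of the composition pattern rather than canonical usage.

```python
import torch
import flashinfer

batch_size, vocab_size = 4, 32_000
# Normalized probabilities, e.g. straight out of a softmax.
probs = torch.softmax(
    torch.randn(batch_size, vocab_size, device="cuda"), dim=-1
)

# Keep the 50 most probable tokens per row and renormalize.
probs = flashinfer.sampling.top_k_renorm_probs(probs, top_k=50)
# Then apply a nucleus (top-p) cutoff on the renormalized distribution.
probs = flashinfer.sampling.top_p_renorm_probs(probs, top_p=0.9)
# Draw one token id per row from the composed, filtered distribution.
token_ids = flashinfer.sampling.sampling_from_probs(probs)
```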
## Acknowledgement
This blog post is written by [Shanli Xing](https://xsl.ing/). We thank the FlashInfer team for their contributions to the `flashinfer.sampling` module:
* Zihao Ye: design and implementation of sampling kernels in CUDA.
* Bohan Hou: design and implementation of sampling kernels in TVM.
* Shanli Xing: design and implementation of min-p sampling kernels in CUDA.
* Tianqi Chen: proposed the idea of rejection sampling for top-p.
## Footnotes
[^1]: FlashInfer provides both "Top-K First" and "Joint" filtering options, the latter applying Top-K and Top-P simultaneously at each round; see the [documentation](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_k_top_p_sampling_from_probs.html) for details.