diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh
index 6b134630cf..41c6dfaf14 100644
--- a/include/flashinfer/sampling.cuh
+++ b/include/flashinfer/sampling.cuh
@@ -16,6 +16,43 @@
 #ifndef FLASHINFER_SAMPLING_CUH_
 #define FLASHINFER_SAMPLING_CUH_
 
+/**
+ * @file sampling.cuh
+ * @brief CUDA kernels and host utilities for softmax and token sampling in FlashInfer.
+ *
+ * This header implements high-performance CUDA primitives to transform logits into
+ * probabilities (softmax) and to draw samples from categorical distributions under
+ * a variety of constraints used in LLM decoding (greedy, multinomial, top-k, and top-p).
+ *
+ * Design highlights:
+ * - Online softmax with numerical stability and optional temperature scaling, supporting
+ *   both per-batch scalar temperature and per-row temperature arrays.
+ * - Vocab-splitting map-reduce path optimized for small-batch/large-vocab regimes.
+ * - Fused path that optionally caches logits in shared memory to reduce global traffic.
+ * - Deterministic and nondeterministic sampling paths, using Philox RNG for reproducibility.
+ * - Parameterized by vector width, block size, and CUDA compute capability for portability.
+ *
+ * Threading model and performance:
+ * - Each batch element (row) is processed by one thread block.
+ * - Vectorized loads/stores via vec_t to maximize memory throughput when d is aligned.
+ * - CUB block primitives (scan/reduce) are used for efficient intra-block reductions.
+ * - Programmatic Dependent Launch (PDL) can be enabled to improve launch ordering on
+ *   supported CUDA toolkits.
+ *
+ * Error handling:
+ * - All host entry points return cudaError_t. cudaSuccess indicates success; otherwise
+ *   an appropriate CUDA error code is returned (e.g., cudaErrorInvalidValue for insufficient
+ *   workspace in the map-reduce softmax path).
+ *
+ * Usage overview:
+ * - Call OnlineSoftmax to convert logits into probabilities written to an output buffer.
+ * - Call SamplingFromLogits or SamplingFromProb to draw one token id per row.
+ * - Use TopKSamplingFromProb or TopPSamplingFromProb to restrict the sampling set.
+ *
+ * Template parameters on the entry points allow using different numeric types for
+ * probabilities (float, half, bfloat16) and integer types for token ids. See the
+ * individual function docs below.
+ */
+
 #include
 #include
 #include
@@ -242,6 +279,22 @@ __device__ __forceinline__ void DeterministicInclusiveSum(
   }
 }
 
+/**
+ * @brief Compute per-row minimum and maximum values across a vectorized tile loop.
+ *
+ * Loads elements of a row in vectorized fashion and performs block-wide reductions
+ * to find both min and max, storing intermediates in temp_storage.
+ *
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam BLOCK_THREADS Number of threads per block.
+ * @tparam REDUCE_ALGORITHM CUB block reduce algorithm.
+ * @tparam TempStorage Shared temporary storage type containing reduce buffers.
+ * @param in_data [in] Pointer to matrix data (row-major).
+ * @param row_idx Index of the row to scan.
+ * @param d Number of columns in the row.
+ * @param temp_storage [in,out] Temporary storage instance.
+ * @return Tuple of (min_val, max_val).
+ */
 template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
           typename TempStorage>
 __device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_data, uint32_t row_idx,
@@ -281,6 +334,19 @@ __device__ __forceinline__ std::tuple<float, float> GetMinMaxValue(float* in_dat
   return std::make_tuple(min_val, max_val);
 }
 
+/**
+ * @brief Compute the maximum value of a row using vectorized loads and block reduction.
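+ *
+ * A minimal usage sketch (illustrative only; the call site and concrete template
+ * arguments are assumptions, not code from this header):
+ * @code
+ * // inside a kernel, one thread block per row:
+ * float row_max = GetMaxValue<VEC_SIZE, BLOCK_THREADS, REDUCE_ALGORITHM>(
+ *     in_data, blockIdx.x, d, temp_storage);
+ * @endcode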
+ *
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam BLOCK_THREADS Number of threads per block.
+ * @tparam REDUCE_ALGORITHM CUB block reduce algorithm.
+ * @tparam TempStorage Shared temporary storage type containing reduce buffers.
+ * @param in_data [in] Pointer to matrix data (row-major).
+ * @param row_idx Index of the row to scan.
+ * @param d Number of columns in the row.
+ * @param temp_storage [in,out] Temporary storage instance.
+ * @return Maximum value in the row.
+ */
 template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM,
           typename TempStorage>
 __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, uint32_t d,
@@ -311,6 +377,23 @@ __device__ __forceinline__ float GetMaxValue(float* in_data, uint32_t row_idx, u
   return temp_storage.max_val;
 }
 
+/**
+ * @brief Fused single-pass online softmax kernel.
+ *
+ * Processes one row per block to compute numerically stable softmax with optional
+ * temperature scaling. Optionally caches the input logits in shared memory to reduce
+ * global memory traffic when the vocabulary is small enough.
+ *
+ * @tparam BLOCK_THREADS Number of threads per block.
+ * @tparam VEC_SIZE Vector width used for loads/stores.
+ * @tparam DType Element type for logits/probabilities.
+ * @tparam CACHE_INPUT Whether to cache logits in shared memory.
+ * @param logits [in] Pointer to input logits for all rows.
+ * @param output [out] Pointer to output probabilities for all rows.
+ * @param temperature_arr [in,opt] Per-row temperatures or nullptr.
+ * @param temperature_val [in] Scalar temperature when temperature_arr is nullptr.
+ * @param d Vocabulary size (number of columns per row).
+ */
 template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, typename DType, bool CACHE_INPUT>
 __global__ void OnlineSoftmaxFusedKernel(DType* logits, DType* output, DType* temperature_arr,
                                          DType temperature_val, uint32_t d) {
@@ -434,16 +517,36 @@ __global__ void OnlineSoftmaxFusedKernel(DType* logits, DType* output, DType* te
 #endif
 }
 
+/**
+ * @brief Phase-1 map kernel for vocab-sliced online softmax.
+ *
+ * Each block processes a slice of the vocabulary for one row, computing a partial
+ * max and denominator contribution for numerically stable softmax with temperature.
+ * Results are written to partial_results for a later reduction.
+ *
+ * @tparam BLOCK_THREADS Number of threads per block.
+ * @tparam VEC_SIZE Vector width used for loads/stores.
+ * @tparam DType Element type for logits/probabilities.
+ * @param logits [in] Pointer to input logits (all rows).
+ * @param partial_results [out] Per-(row,slice) partial max and denominator.
+ * @param temperature_arr [in,opt] Per-row temperatures or nullptr.
+ * @param temperature_val [in] Scalar temperature when temperature_arr is nullptr.
+ * @param d Vocabulary size.
+ * @param num_slices Number of vocab slices used.
+ */
 template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, typename DType>
 __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* partial_results,
                                        DType* temperature_arr, float temperature_val, uint32_t d,
                                        uint32_t num_slices) {
+  // Map phase: each block handles one (row, slice). We compute a slice-local max and
+  // denominator contribution using the online softmax update so we need only a single pass.
   const uint32_t bx = blockIdx.x;
   const uint32_t by = blockIdx.y;  // slice index
   const uint32_t tx = threadIdx.x;
   float temperature = temperature_arr == nullptr ? temperature_val : temperature_arr[bx];
   const float inv_temp = (temperature == 0.f) ? 0.f : 1.f / temperature;
+  // Compute slice bounds with vector-aligned stride so vec_t loads are naturally aligned.
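+  // Worked example with assumed values (not taken from this header): d = 50257,
+  // num_slices = 4, DType = float, VEC_SIZE = 4 gives vec_alignment_elems = 4 and
+  // slice_stride = round_up(ceil_div(50257, 4), 4) = round_up(12565, 4) = 12568,
+  // so slices begin at columns 0, 12568, 25136, 37704 and the last slice is clamped to d.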
   const uint32_t vec_alignment_elems = alignof(vec_t<DType, VEC_SIZE>) / sizeof(DType);
   const uint32_t slice_stride = round_up(ceil_div(d, num_slices), vec_alignment_elems);
   const uint32_t slice_start = by * slice_stride;
@@ -456,6 +559,7 @@ __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* part
   auto& temp_storage = reinterpret_cast(smem);
   vec_t<float, VEC_SIZE> logits_vec;
+  // Running max/denom maintain numerical stability across tiles; see reduce op for merge rule.
   float running_max = -cuda::std::numeric_limits<float>::infinity();
   float running_denominator = 0.0f;
@@ -522,11 +626,30 @@ __global__ void OnlineSoftmaxMapKernel(DType* logits, PartialSoftmaxResult* part
 #endif
 }
 
+/**
+ * @brief Phase-2 reduce kernel for vocab-sliced online softmax.
+ *
+ * For each row, reduces partial (max, denominator) across slices to obtain the
+ * final normalization, then writes normalized probabilities to output.
+ *
+ * @tparam BLOCK_THREADS Number of threads per block.
+ * @tparam VEC_SIZE Vector width used for loads/stores.
+ * @tparam DType Element type for logits/probabilities.
+ * @param logits [in] Pointer to input logits (all rows).
+ * @param output [out] Pointer to output probabilities (all rows).
+ * @param partial_results [in] Per-(row,slice) partial results from map phase.
+ * @param temperature_arr [in,opt] Per-row temperatures or nullptr.
+ * @param temperature_val [in] Scalar temperature when temperature_arr is nullptr.
+ * @param d Vocabulary size.
+ * @param num_slices Number of vocab slices used.
+ */
 template <uint32_t BLOCK_THREADS, uint32_t VEC_SIZE, typename DType>
 __global__ void OnlineSoftmaxReduceKernel(DType* logits, DType* output,
                                           PartialSoftmaxResult* partial_results,
                                           DType* temperature_arr, float temperature_val,
                                           uint32_t d, uint32_t num_slices) {
+  // Reduce phase: merge per-slice (max, denom) pairs into a single row-level pair, then
+  // normalize the entire row using the final numerically-stable parameters.
   const uint32_t bx = blockIdx.x;
   const uint32_t tx = threadIdx.x;
   float temperature = temperature_arr == nullptr ? temperature_val : temperature_arr[bx];
@@ -539,6 +662,8 @@ __global__ void OnlineSoftmaxReduceKernel(DType* logits, DType* output,
 
   const Float2SoftmaxReduceOp reduce_op;
 
+  // Each thread accumulates a subset of slices; reduce_op encodes the online merge rule
+  // for softmax: (m1,d1) ⊕ (m2,d2) = (m, d1*exp(m1-m) + d2*exp(m2-m)), where m = max(m1,m2).
   float2 thread_aggregate = make_float2(-cuda::std::numeric_limits<float>::infinity(), 0.0f);
 
 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -560,9 +685,7 @@ __global__ void OnlineSoftmaxReduceKernel(DType* logits, DType* output,
   }
   __syncthreads();
 
-  block_result =
-      make_float2(temp_storage.shared_state.max_val, temp_storage.shared_state.denominator);
-
+  // Broadcast row-level parameters through shared memory.
   const float final_max = temp_storage.shared_state.max_val;
   const float inv_denominator = 1.0f / temp_storage.shared_state.denominator;
 
@@ -579,6 +702,7 @@ __global__ void OnlineSoftmaxReduceKernel(DType* logits, DType* output,
 
 #pragma unroll
     for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+      // Apply temperature scaling and stable exponentiation against final_max.
       logits_vec[j] *= inv_temp;
       float p = __expf(static_cast<float>(logits_vec[j]) - final_max) * inv_denominator;
       prob_vec[j] = static_cast<DType>(p);
@@ -593,6 +717,27 @@ __global__ void OnlineSoftmaxReduceKernel(DType* logits, DType* output,
 #endif
 }
 
+/**
+ * @brief Per-iteration device helper that advances sampling from probabilities.
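+ *
+ * Conceptually (a sketch of the logic, not the exact implementation), tile i contributes
+ *   cdf[j] = aggregate + inclusive_prefix_sum_{k <= j}(pred(prob[k]) ? prob[k] : 0)
+ * and the first position j whose cdf[j] exceeds u is recorded in temp_storage.sampled_id.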
+ *
+ * Computes block-wide mass over a predicate-filtered segment, optionally performs
+ * a stable inclusive prefix-sum to locate the first index where the CDF exceeds u,
+ * and updates shared temp_storage with the sampled id and last valid id.
+ *
+ * @tparam VEC_SIZE Vector width used for loads/stores.
+ * @tparam BLOCK_THREADS Number of threads per block.
+ * @tparam SCAN_ALGORITHM CUB block scan algorithm.
+ * @tparam REDUCE_ALGORITHM CUB block reduce algorithm.
+ * @tparam DETERMINISTIC Whether to use deterministic inclusive sum.
+ * @tparam Predicate Unary predicate used to filter candidates.
+ * @param i Block-iteration index over tiles.
+ * @param d Vocabulary size.
+ * @param pred Predicate determining candidate validity.
+ * @param u Random threshold in [0, q) for CDF comparison.
+ * @param prob_vec Probabilities vector for this tile.
+ * @param aggregate [in,out] Running mass accumulated so far.
+ * @param temp_storage [in,out] Temporary storage (per block).
+ */
 template <uint32_t VEC_SIZE, uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, bool DETERMINISTIC, typename Predicate>
 __device__ __forceinline__ void DeviceSamplingFromProb(
@@ -741,11 +886,33 @@ __device__ __forceinline__ vec_t<float, VEC_SIZE> GenerateGumbelNoise(uint64_t p
   }
 }
 
+/**
+ * @brief Multinomial sampling from unnormalized logits using the Gumbel-max trick.
+ *
+ * Adds independent Gumbel noise to logits, then selects the argmax to draw a
+ * single sample per row. Optionally restricts to a subset of rows via indices.
+ *
+ * @tparam BLOCK_THREADS Threads per block.
+ * @tparam SCAN_ALGORITHM Unused here (for interface consistency).
+ * @tparam REDUCE_ALGORITHM CUB reduce algorithm.
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam DETERMINISTIC Whether to use the deterministic RNG stream.
+ * @tparam DType Logit element type.
+ * @tparam IdType Index type for token ids.
+ * @param logits [in] Logits matrix.
+ * @param output [out] Sampled token id per row.
+ * @param indices [in,opt] Optional row remapping array or nullptr.
+ * @param d Vocabulary size.
+ * @param philox_seed RNG seed.
+ * @param philox_offset RNG counter offset.
+ */
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
 __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType* indices,
                                          uint32_t d, uint64_t philox_seed, uint64_t philox_offset) {
+  // Gumbel-max sampling: sample argmax(logits + G), where G ~ Gumbel(0,1) iid.
+  // We avoid materializing a full G tensor by generating PRNG noise per tile and lane.
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = indices == nullptr ? bx : indices[bx];
   using SharedMem = typename BlockReduce<DataAndIndex<float, IdType>, BLOCK_THREADS,
@@ -754,13 +921,17 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*
   auto& temp_storage = reinterpret_cast<SharedMem&>(smem_sampling_logit);
   vec_t<float, VEC_SIZE> logits_vec;
+  // Track the running blockwise argmax as (value, index).
   DataAndIndex<float, IdType> max_data = {-cuda::std::numeric_limits<float>::infinity(), 0};
 
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    // Load a vectorized tile of logits; pad with -inf so padding never wins in argmax.
     logits_vec.fill(-cuda::std::numeric_limits<float>::infinity());
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
       logits_vec.cast_load(logits + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
     }
+    // Generate VEC_SIZE independent uniforms via Philox and transform them to Gumbel noise.
+    // The subsequence is derived from (block, row tile, lane) so draws are unique and reproducible.
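+    // (Recall: if U ~ Uniform(0,1), then -log(-log(U)) is Gumbel(0,1)-distributed, and
+    // argmax_j(logit_j + G_j) follows softmax(logits); this is the Gumbel-max trick.)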
     vec_t<float, VEC_SIZE> gumbel_noise = GenerateGumbelNoise<VEC_SIZE>(
         philox_seed, philox_offset,
         static_cast<uint64_t>(bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE));
@@ -773,22 +944,51 @@ __global__ void SamplingFromLogitsKernel(DType* logits, IdType* output, IdType*
       cur_data[j].index = (i * BLOCK_THREADS + tx) * VEC_SIZE + j;
     }
 
+    // Reduce VEC_SIZE lane-local candidates to a blockwise argmax using a custom (+) operator
+    // that selects the larger .data and carries its .index.
     max_data +=
         BlockReduce<DataAndIndex<float, IdType>, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage)
            .template Sum<VEC_SIZE>(cur_data);
   }
 
   if (tx == 0) {
+    // The final winning index is the sampled token for this row.
     output[bx] = max_data.index;
   }
 }
 
+/**
+ * @brief Multinomial sampling from probabilities using a block CDF search.
+ *
+ * Computes a per-row sample by forming the inclusive CDF across tiles until the
+ * random threshold u is surpassed. Supports a deterministic scan for reproducibility.
+ *
+ * @tparam BLOCK_THREADS Threads per block.
+ * @tparam SCAN_ALGORITHM CUB scan algorithm.
+ * @tparam REDUCE_ALGORITHM CUB reduce algorithm.
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam DETERMINISTIC Whether to use deterministic inclusive sum.
+ * @tparam DType Probability element type (promoted to float internally).
+ * @tparam IdType Index type for token ids.
+ * @param probs [in] Probability matrix.
+ * @param output [out] Sampled token id per row.
+ * @param indices [in,opt] Optional row remapping array or nullptr.
+ * @param d Vocabulary size.
+ * @param philox_seed RNG seed.
+ * @param philox_offset RNG counter offset.
+ */
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
 __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* indices, uint32_t d,
                                        uint64_t philox_seed, uint64_t philox_offset) {
+  // Each block samples one row. We generate one uniform u ~ U(0,1) per row,
+  // then sweep the row in vectorized tiles, computing an inclusive CDF until it
+  // exceeds u. The scan and reductions are implemented by CUB block primitives
+  // in SamplingTempStorage and DeviceSamplingFromProb.
   curandStatePhilox4_32_10_t state;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
+  // Use (seed, subsequence=bx, offset) so rows have independent streams and are
+  // reproducible across launches for the same (seed, offset).
   curand_init(philox_seed, bx, philox_offset, &state);
   const uint32_t row_idx = indices == nullptr ? bx : indices[bx];
 
@@ -798,43 +998,75 @@ __global__ void SamplingFromProbKernel(DType* probs, IdType* output, IdType* ind
   auto& temp_storage =
       reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(
           smem_sampling);
+  // Initialize to sentinel d meaning "not found yet"; last_valid_id tracks the
+  // last index that satisfied the predicate so we can recover from edge cases.
   temp_storage.sampled_id = d;
   __syncthreads();
 
   vec_t<float, VEC_SIZE> probs_vec;
   float aggregate(0);
+  // Draw the target threshold u once for the entire row.
   float u = curand_uniform(&state);
 
#pragma unroll 2
   for (uint32_t i = 0; i < ceil_div(d, BLOCK_THREADS * VEC_SIZE); ++i) {
+    // Vectorized load of a tile; out-of-bounds lanes are zeroed to keep the math simple.
     probs_vec.fill(0);
     if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) {
       probs_vec.cast_load(probs + row_idx * d + i * BLOCK_THREADS * VEC_SIZE + tx * VEC_SIZE);
     }
 
+    // Compute a masked inclusive prefix-sum over this tile and update shared state.
+    // The predicate x > 0 prevents negative or NaN probabilities from entering the CDF.
     DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
                            DETERMINISTIC>(
        i, d, [](float x) { return x > 0; }, u, probs_vec, aggregate, &temp_storage);
+    // Early exit once this tile's CDF crosses u; the sampled id is recorded in shared memory.
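+    // (Inverse-transform sampling: selecting the first index whose inclusive CDF
+    // exceeds u ~ U(0,1) picks token j with probability p_j.)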
     if (float(aggregate) > u) {
       break;
     }
   }
 
   int sampled_id = temp_storage.sampled_id;
   if (sampled_id == d) {
-    // NOTE(Zihao): this would happen when u is very close to 1
-    // and the sum of probabilities is smaller than u
-    // In this case, we use the last valid index as the sampled id
+    // If we never crossed u (e.g., due to numerical underflow, or u close to 1 while the
+    // row sums to less than 1), fall back to the last valid index observed in the scan
+    // for deterministic behavior.
     sampled_id = temp_storage.last_valid_id;
   }
   output[bx] = sampled_id;
 }
 
+/**
+ * @brief Top-K restricted sampling from probabilities.
+ *
+ * Performs iterative pivoting to determine a probability threshold such that fewer
+ * than K tokens exceed it, then samples within that set using a CDF search.
+ *
+ * @tparam BLOCK_THREADS Threads per block.
+ * @tparam SCAN_ALGORITHM CUB scan algorithm.
+ * @tparam REDUCE_ALGORITHM CUB reduce algorithm.
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam DETERMINISTIC Whether to use deterministic inclusive sum.
+ * @tparam DType Probability element type.
+ * @tparam IdType Index type for token ids.
+ * @param probs [in] Probability matrix.
+ * @param output [out] Sampled token id per row.
+ * @param indices [in,opt] Optional row remapping array or nullptr.
+ * @param top_k_arr [in,opt] Per-row K array or nullptr to use top_k_val.
+ * @param top_k_val Global K when top_k_arr is nullptr.
+ * @param d Vocabulary size.
+ * @param philox_seed RNG seed.
+ * @param philox_offset RNG counter offset.
+ */
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
 __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices,
                                            IdType* top_k_arr, uint32_t top_k_val, uint32_t d,
                                            uint64_t philox_seed, uint64_t philox_offset) {
+  // Top-K via threshold search: we search for a pivot such that strictly fewer than K
+  // probabilities exceed it. Once found, we sample within that set using a CDF search.
+  // The variable q tracks the total probability mass of candidates above the current low,
+  // allowing us to reuse RNG draws (u ~ U(0,q)) between refinement rounds.
   const uint32_t batch_size = gridDim.x;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   curandStatePhilox4_32_10_t state;
@@ -859,6 +1091,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
     round += 1;
     temp_storage.sampled_id = d;
     __syncthreads();
+    // Draw u ∈ (0, q]; the CDF search only considers mass from entries > low.
     float u = curand_uniform(&state) * q;
     aggregate = 0;
#pragma unroll 2
@@ -868,6 +1101,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
         probs_vec.cast_load(probs + row_idx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE);
       }
 
+      // Mask elements ≤ low and compute the inclusive CDF over the candidate set.
       DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM,
                              DETERMINISTIC>(
          i, d, [&](float x) { return x > low; }, u, probs_vec, aggregate, &temp_storage);
@@ -878,11 +1112,10 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
     __syncthreads();
     sampled_id = temp_storage.sampled_id;
     if (sampled_id == d) {
-      // NOTE(Zihao): this would happen when u is very close to 1
-      // and the sum of probabilities is smaller than u
-      // In this case, we use the last valid index as the sampled id
+      // If no crossing occurred, use last_valid_id for determinism.
       sampled_id = temp_storage.last_valid_id;
     }
+    // Use the sampled value as pivot_0 and the midpoint pivot_1 to refine [low, high].
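+    // (Each round either accepts one of the two pivots or tightens [low, high] around
+    // the K-th largest probability, so the bracket shrinks monotonically.)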
     double pivot_0 = probs[row_idx * d + sampled_id];
     double pivot_1 = (pivot_0 + high) / 2;
 
@@ -897,6 +1130,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       ValueCount<float> probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
#pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        // For each pivot, accumulate both mass and count of tokens above the pivot.
         probs_gt_pivot_0[j] = {
             (probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
             (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
@@ -924,16 +1158,16 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       aggregate_gt_pivot_1 = temp_storage.block_aggregate.pair;
     }
     if (aggregate_gt_pivot_0.count < k) {
-      // case 1: pivot_0 accepted
+      // Case 1: pivot_0 already yields fewer than k candidates; accept it.
 
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
 __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType* indices,
                                            float* top_p_arr, float top_p_val, uint32_t d,
                                            uint64_t philox_seed, uint64_t philox_offset) {
+  // Nucleus (Top-P) sampling: search for a probability threshold whose cumulative mass is ≥ P.
+  // Similar to Top-K, we iteratively refine a pivot range [low, high], and reuse the mass q
+  // of candidates above the current low to scale the RNG range.
   const uint32_t batch_size = gridDim.x;
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   curandStatePhilox4_32_10_t state;
@@ -972,6 +1231,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
   do {
     temp_storage.sampled_id = d;
     __syncthreads();
+    // Draw u ∈ (0, q] and run a CDF search over tokens with prob > low.
     float u = curand_uniform(&state) * q;
     aggregate = 0;
#pragma unroll 2
@@ -991,11 +1251,10 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
     __syncthreads();
     sampled_id = temp_storage.sampled_id;
     if (sampled_id == d) {
-      // NOTE(Zihao): this would happen when u is very close to 1
-      // and the sum of probabilities is smaller than u
-      // In this case, we use the last valid index as the sampled id
+      // If no crossing occurred, use last_valid_id for determinism.
       sampled_id = temp_storage.last_valid_id;
     }
+    // Use the sampled prob as pivot_0 and the midpoint pivot_1 to narrow the search window.
     double pivot_0 = probs[row_idx * d + sampled_id];
     double pivot_1 = (pivot_0 + high) / 2;
 
@@ -1010,6 +1269,7 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       float probs_gt_pivot_0[VEC_SIZE], probs_gt_pivot_1[VEC_SIZE];
#pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        // Accumulate the total mass above the two candidate pivots.
         probs_gt_pivot_0[j] = (probs_vec[j] > pivot_0) ? probs_vec[j] : 0;
         probs_gt_pivot_1[j] = (probs_vec[j] > pivot_1) ? probs_vec[j] : 0;
       }
@@ -1031,16 +1291,16 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, IdType* output, IdType*
       aggregate_gt_pivot_1 = temp_storage.block_aggregate.value;
     }
     if (aggregate_gt_pivot_0 < top_p) {
-      // case 1: pivot_0 accepted
+      // Case 1: the mass above pivot_0 is already below top_p; accept pivot_0.
@@ -1129,6 +1411,30 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, float* min_p_arr, IdTyp
   output[bx] = sampled_id;
 }
 
+/**
+ * @brief Combined Top-K and Top-P sampling from probabilities.
+ *
+ * Iteratively searches for a threshold that satisfies both constraints: fewer than
+ * K tokens above the threshold, and cumulative mass above it below P. Samples within
+ * the resulting set.
+ *
+ * @tparam BLOCK_THREADS Threads per block.
+ * @tparam SCAN_ALGORITHM CUB scan algorithm.
+ * @tparam REDUCE_ALGORITHM CUB reduce algorithm.
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam DETERMINISTIC Whether to use deterministic inclusive sum.
+ * @tparam DType Probability element type.
+ * @tparam IdType Index type for token ids.
+ * @param probs [in] Probability matrix.
+ * @param top_k_arr [in,opt] Per-row K values or nullptr to use top_k_val.
+ * @param top_p_arr [in,opt] Per-row P values or nullptr to use top_p_val.
+ * @param output [out] Sampled token id per row.
+ * @param indices [in,opt] Optional row remapping array or nullptr.
+ * @param top_k_val Global K when top_k_arr is nullptr.
+ * @param top_p_val Global P when top_p_arr is nullptr.
+ * @param d Vocabulary size.
+ * @param philox_seed RNG seed.
+ * @param philox_offset RNG counter offset.
+ */
 template <uint32_t BLOCK_THREADS, BlockScanAlgorithm SCAN_ALGORITHM,
           BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE, bool DETERMINISTIC,
           typename DType, typename IdType>
@@ -1244,6 +1550,28 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, IdType* top_k_arr,
   }
 }
 
+/**
+ * @brief Compute softmax over the vocabulary for each row using an online, numerically
+ * stable algorithm with optional temperature scaling.
+ *
+ * This function chooses between a vocab-splitting map-reduce pipeline (optimized for
+ * small batch size and large vocab) and a fused single-pass kernel. It writes the
+ * resulting probabilities to the output buffer; logits are read-only.
+ *
+ * @tparam DType Floating point type of logits/probabilities (e.g., float, half, bf16).
+ * @param logits [in] Pointer to input logits of shape (batch_size, d).
+ * @param output [out] Pointer to output probabilities of shape (batch_size, d).
+ * @param batch_size Number of rows in the batch.
+ * @param d Vocabulary size (number of columns per row).
+ * @param temperature_arr [in,opt] Per-row temperatures of shape (batch_size), or nullptr.
+ * @param temperature_val [in] Scalar temperature used when temperature_arr is nullptr.
+ * @param workspace_buffer [in,opt] Scratch buffer used by the map-reduce path.
+ * @param workspace_buffer_size_in_bytes Size in bytes of workspace_buffer.
+ * @param enable_pdl Enable Programmatic Dependent Launch when supported by the toolkit.
+ * @param stream CUDA stream to launch kernels on (default stream 0).
+ * @return cudaSuccess on success; cudaErrorInvalidValue if the workspace is insufficient
+ * for the selected path; or other CUDA errors from kernel launches.
+ */
 template <typename DType>
 cudaError_t OnlineSoftmax(DType* logits, DType* output, uint32_t batch_size, uint32_t d,
                           DType* temperature_arr, DType temperature_val, void* workspace_buffer,
@@ -1387,6 +1715,27 @@ cudaError_t OnlineSoftmax(DType* logits, DType* output, uint32_t batch_size, uin
   return cudaSuccess;
 }
 
+/**
+ * @brief Sample one token id per row from unnormalized logits using multinomial sampling.
+ *
+ * Internally draws Gumbel noise and selects the per-row argmax of (logits + noise), which
+ * is equivalent to one multinomial draw from softmax(logits) without materializing the
+ * probabilities. The draw uses either deterministic (Philox-based, reproducible) or
+ * nondeterministic RNG. If indices is non-null, each batch entry bx is remapped to row
+ * indices[bx] of the logits matrix.
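+ *
+ * Example (illustrative only; the buffer names and their setup are assumptions):
+ * @code
+ * cudaError_t status = SamplingFromLogits<float, int32_t>(
+ *     d_logits, d_output, /*indices=*/nullptr, batch_size, vocab_size,
+ *     /*deterministic=*/true, philox_seed, philox_offset, stream);
+ * @endcode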
+ *
+ * @tparam T Floating point type of logits (e.g., float).
+ * @tparam IdType Integer type for token ids (e.g., int32).
+ * @param logits [in] Pointer to logits (batch_size, d).
+ * @param output [out] Pointer to sampled token ids (batch_size).
+ * @param indices [in,opt] Optional row remapping array, or nullptr.
+ * @param batch_size Number of rows to sample.
+ * @param d Vocabulary size per row.
+ * @param deterministic If true, use the deterministic (reproducible) RNG path.
+ * @param philox_seed Philox RNG seed.
+ * @param philox_offset Philox RNG counter offset.
+ * @param stream CUDA stream (default 0).
+ * @return cudaSuccess on success, otherwise a CUDA error from kernel launch.
+ */
 template <typename T, typename IdType>
 cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint32_t batch_size,
                                uint32_t d, bool deterministic, uint64_t philox_seed,
@@ -1412,6 +1761,27 @@ cudaError_t SamplingFromLogits(T* logits, IdType* output, IdType* indices, uint3
   });
 }
 
+/**
+ * @brief Sample one token id per row from probabilities.
+ *
+ * Assumes probs contains per-row categorical distributions (not necessarily exactly
+ * normalized due to numerical noise); the kernel computes an inclusive prefix-sum per
+ * row and draws one sample according to the CDF. If indices is non-null, each batch
+ * entry bx is remapped to row indices[bx] of probs.
+ *
+ * @tparam T Floating point type of probabilities (e.g., float).
+ * @tparam IdType Integer type for token ids (e.g., int32).
+ * @param probs [in] Pointer to probabilities (batch_size, d).
+ * @param output [out] Pointer to sampled token ids (batch_size).
+ * @param indices [in,opt] Optional row remapping array, or nullptr.
+ * @param batch_size Number of rows to sample.
+ * @param d Vocabulary size per row.
+ * @param deterministic If true, use the deterministic inclusive-sum path (bitwise reproducible).
+ * @param philox_seed Philox RNG seed.
+ * @param philox_offset Philox RNG counter offset.
+ * @param stream CUDA stream (default 0).
+ * @return cudaSuccess on success, otherwise a CUDA error from kernel launch.
+ */
 template <typename T, typename IdType>
 cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t batch_size,
                              uint32_t d, bool deterministic, uint64_t philox_seed,
@@ -1436,6 +1806,29 @@ cudaError_t SamplingFromProb(T* probs, IdType* output, IdType* indices, uint32_t
   });
 }
 
+/**
+ * @brief Sample one token id per row from probabilities restricted to the Top-K set.
+ *
+ * For each row, candidates are limited to the K tokens with the highest probability
+ * (either a per-row K via top_k_arr or a global K via top_k_val). Sampling is then
+ * performed over the renormalized restricted set. If indices is non-null, each batch
+ * entry bx is remapped to row indices[bx] of probs.
+ *
+ * @tparam T Floating point type of probabilities (e.g., float).
+ * @tparam IdType Integer type for token ids (e.g., int32).
+ * @param probs [in] Pointer to probabilities (batch_size, d).
+ * @param output [out] Pointer to sampled token ids (batch_size).
+ * @param indices [in,opt] Optional row remapping array, or nullptr.
+ * @param top_k_arr [in,opt] Optional per-row K values; if null, top_k_val is used.
+ * @param batch_size Number of rows to sample.
+ * @param top_k_val Global K used when top_k_arr is null.
+ * @param d Vocabulary size per row.
+ * @param deterministic If true, use the deterministic inclusive-sum path (bitwise reproducible).
+ * @param philox_seed Philox RNG seed.
+ * @param philox_offset Philox RNG counter offset.
+ * @param stream CUDA stream (default 0).
+ * @return cudaSuccess on success, otherwise a CUDA error from kernel launch.
+ */
 template <typename T, typename IdType>
 cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_k_arr,
                                  uint32_t batch_size, uint32_t top_k_val, uint32_t d,
@@ -1464,6 +1857,29 @@ cudaError_t TopKSamplingFromProb(T* probs, IdType* output, IdType* indices, T* t
   });
 }
 
+/**
+ * @brief Sample one token id per row from probabilities restricted by a nucleus (Top-P) filter.
+ *
+ * For each row, candidates are limited to the smallest set of tokens whose cumulative
+ * probability mass is at least P (either per-row via top_p_arr or global via top_p_val).
+ * Sampling is then performed over the renormalized restricted set. If indices is non-null,
+ * each batch entry bx is remapped to row indices[bx] of probs.
+ *
+ * @tparam T Floating point type of probabilities (e.g., float).
+ * @tparam IdType Integer type for token ids (e.g., int32).
+ * @param probs [in] Pointer to probabilities (batch_size, d).
+ * @param output [out] Pointer to sampled token ids (batch_size).
+ * @param indices [in,opt] Optional row remapping array, or nullptr.
+ * @param top_p_arr [in,opt] Optional per-row P values in [0,1]; if null, top_p_val is used.
+ * @param batch_size Number of rows to sample.
+ * @param top_p_val Global P used when top_p_arr is null.
+ * @param d Vocabulary size per row.
+ * @param deterministic If true, use the deterministic inclusive-sum path (bitwise reproducible).
+ * @param philox_seed Philox RNG seed.
+ * @param philox_offset Philox RNG counter offset.
+ * @param stream CUDA stream (default 0).
+ * @return cudaSuccess on success, otherwise a CUDA error from kernel launch.
+ */
 template <typename T, typename IdType>
 cudaError_t TopPSamplingFromProb(T* probs, IdType* output, IdType* indices, T* top_p_arr,
                                  uint32_t batch_size, T top_p_val, uint32_t d, bool deterministic,
@@ -1575,6 +1991,23 @@ struct RenormTempStorage {
   };
 };
 
+/**
+ * @brief Renormalize probabilities within the Top-P nucleus set.
+ *
+ * Computes the cumulative probability threshold for Top-P and rescales values
+ * within the nucleus so that the selected set sums to 1. Includes a fast path
+ * for p >= 1.0 where only a row-wise normalization is needed.
+ *
+ * @tparam BLOCK_THREADS Threads per block.
+ * @tparam REDUCE_ALGORITHM CUB reduce algorithm.
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam DType Probability element type.
+ * @param probs [in] Input probabilities.
+ * @param renormed_prob [out] Output buffer for renormalized probabilities.
+ * @param top_p_arr [in,opt] Per-row P values or nullptr to use top_p_val.
+ * @param top_p_val Global P when top_p_arr is nullptr.
+ * @param d Vocabulary size.
+ */
 template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           typename DType>
 __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float* top_p_arr,
@@ -1744,6 +2177,23 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, float*
   }
 }
 
+/**
+ * @brief Apply a Top-K mask to logits by setting values outside Top-K to -inf.
+ *
+ * Determines the Top-K threshold and writes a masked logits tensor where only
+ * tokens within the Top-K set retain their original values; all others are set to -inf.
+ *
+ * @tparam BLOCK_THREADS Threads per block.
+ * @tparam REDUCE_ALGORITHM CUB reduce algorithm.
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam DType Logit element type.
+ * @tparam IdType Integer type for K values.
+ * @param logits [in] Input logits.
+ * @param masked_logits [out] Output logits with the Top-K mask applied.
+ * @param top_k_arr [in,opt] Per-row K array or nullptr to use top_k_val.
+ * @param top_k_val Global K when top_k_arr is nullptr.
+ * @param d Vocabulary size.
+ */
 template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           typename DType, typename IdType>
 __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType* top_k_arr,
@@ -1864,10 +2314,29 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType
   }
 }
 
+/**
+ * @brief Renormalize probabilities within the Top-K set.
+ *
+ * Finds the probability threshold corresponding to the K-th largest value and rescales
+ * probabilities of tokens above the threshold so they sum to 1. If K >= d, performs
+ * a simple row-wise normalization.
+ *
+ * @tparam BLOCK_THREADS Threads per block.
+ * @tparam REDUCE_ALGORITHM CUB reduce algorithm.
+ * @tparam VEC_SIZE Vector width for loads.
+ * @tparam DType Probability element type.
+ * @tparam IdType Integer type for K values.
+ * @param probs [in] Input probabilities.
+ * @param renormed_prob [out] Output buffer for renormalized probabilities.
+ * @param top_k_arr [in,opt] Per-row K array or nullptr to use top_k_val.
+ * @param top_k_val Global K when top_k_arr is nullptr.
+ * @param d Vocabulary size.
+ */
 template <uint32_t BLOCK_THREADS, BlockReduceAlgorithm REDUCE_ALGORITHM, uint32_t VEC_SIZE,
           typename DType, typename IdType>
 __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* top_k_arr,
                                      uint32_t top_k_val, uint32_t d) {
+  // Find the K-th largest probability threshold (pivot) and renormalize probs above it.
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
   const uint32_t row_idx = bx;
   uint32_t k = top_k_arr == nullptr ? top_k_val : top_k_arr[bx];
@@ -1887,13 +2356,9 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
     double low = 0, high = max_val;
     float min_gt_low, max_le_high;
     float sum_low = 1;
-    // f(x) = len(nonzero(probs > x)), f(x) is non-increasing
-    // min_gt_low = min{p \in probs | p > low}, max_le_high = max{p \in probs | p <= high}
-    // loop invariant:
-    // - f(low) >= k, f(high) < k
-    // - f(low) > f(min_gt_low) >= f(max_le_high) == f(high)
-    // stopping condition: min_gt_low == max_le_high
-    // - f(low) >= k, f(min_gt_low) == f(max_le_high) == f(high) < k
+    // Invariant-driven search over pivots:
+    //   f(x) = number of tokens with prob > x, which is non-increasing.
+    // Maintain f(low) >= k and f(high) < k, narrowing until min_gt_low == max_le_high.
     do {
       double pivot_0 = (high + 2 * low) / 3;
      double pivot_1 = (2 * high + low) / 3;
@@ -1910,6 +2375,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
       ValueCount<float> probs_gt_pivot_0_pair[VEC_SIZE], probs_gt_pivot_1_pair[VEC_SIZE];
#pragma unroll
       for (uint32_t j = 0; j < VEC_SIZE; ++j) {
+        // Accumulate both count and mass for elements above each tentative pivot.
         probs_gt_pivot_0_pair[j] = {
             (probs_vec[j] > pivot_0) ? probs_vec[j] : 0,
             (probs_vec[j] > pivot_0 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
@@ -1917,6 +2383,8 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
             (probs_vec[j] > pivot_1) ? probs_vec[j] : 0,
             (probs_vec[j] > pivot_1 && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)};
 
+        // Track the smallest value strictly above low and the largest value ≤ high
+        // to compute the termination condition without sorting.
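+        // (Termination: when no probability value lies strictly between low and high,
+        // min_gt_low == max_le_high and the search has isolated the K-th largest value.)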
         if (probs_vec[j] > low && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d) {
           min_gt_low = min(min_gt_low, probs_vec[j]);
         }
@@ -1935,6 +2403,7 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType*
                                    .template Sum<VEC_SIZE>(probs_gt_pivot_1_pair);
       __syncthreads();
     }
+    // Reduce across the block to finalize the min_gt_low / max_le_high sentinels.
     min_gt_low = BlockReduce<float, BLOCK_THREADS, REDUCE_ALGORITHM>(temp_storage.block_prim.reduce)
                      .Reduce(min_gt_low, MinReduceOp{});