
Commit aac66dd

[CUDA] sampling kernel improvements (microsoft#1732)
Key improvements:

1. Unified fused kernel: a single `FusedSamplingKernel` is now used for all `k <= 256`. It performs every sampling step (temperature scaling, softmax, CDF calculation, filtering, re-normalization, and final selection) in one pass; an illustrative sketch follows the benchmark output below.
2. Refactored the top-k related code into separate files.
3. Added unit tests and benchmarks for the top-k and sampling CUDA kernels.
4. Improved selection sort for `k = 1` with a special kernel that does not write to global memory.
5. Updated selection sort to copy and then write, which avoids modifying the input logits.
6. Updated CPU sampling to use the same temperature-handling logic.

### Sampling Macro Benchmark: performance improvement of the fused kernel

We saw roughly a 5x speedup of the sampling kernel in the macro benchmark for batch_size=1, vocab_size=204800, k=50, stride=50 on an RTX 5060 Ti GPU. Only sampling latency is measured (top-k selection is excluded):

```
--- Sampling Cuda Kernel Benchmark Summary ---
Batch Size   K     Stride    Algorithm      Latency(us)   Stdev(us)   P95(us)
--------------------------------------------------------------------------------
1            1     1         FUSED          27.41         12.80       58.12
1            1     1         MULTI_STAGE    102.55        39.03       161.63

--- Running Benchmarks with batch_size=1, vocab_size=204800, k=8, stride=8 ---

--- Sampling Cuda Kernel Benchmark Summary ---
Batch Size   K     Stride    Algorithm      Latency(us)   Stdev(us)   P95(us)
--------------------------------------------------------------------------------
1            8     8         FUSED          19.71         10.25       39.52
1            8     8         MULTI_STAGE    93.47         35.82       155.23

--- Running Benchmarks with batch_size=1, vocab_size=204800, k=50, stride=50 ---

--- Sampling Cuda Kernel Benchmark Summary ---
Batch Size   K     Stride    Algorithm      Latency(us)   Stdev(us)   P95(us)
--------------------------------------------------------------------------------
1            50    50        FUSED          12.96         2.54        14.62
1            50    50        MULTI_STAGE    76.41         30.91       127.65

--- Running Benchmarks with batch_size=1, vocab_size=204800, k=64, stride=64 ---

--- Sampling Cuda Kernel Benchmark Summary ---
Batch Size   K     Stride    Algorithm      Latency(us)   Stdev(us)   P95(us)
--------------------------------------------------------------------------------
1            64    64        FUSED          25.32         18.82       57.03
1            64    64        MULTI_STAGE    84.16         35.26       152.48

--- Running Benchmarks with batch_size=1, vocab_size=204800, k=128, stride=204800 ---

--- Sampling Cuda Kernel Benchmark Summary ---
Batch Size   K     Stride    Algorithm      Latency(us)   Stdev(us)   P95(us)
--------------------------------------------------------------------------------
1            128   204800    FUSED          27.91         11.98       52.68
1            128   204800    MULTI_STAGE    88.58         35.27       150.78

--- Running Benchmarks with batch_size=1, vocab_size=204800, k=192, stride=204800 ---

--- Sampling Cuda Kernel Benchmark Summary ---
Batch Size   K     Stride    Algorithm      Latency(us)   Stdev(us)   P95(us)
--------------------------------------------------------------------------------
1            192   204800    FUSED          13.40         3.55        16.23
1            192   204800    MULTI_STAGE    83.21         31.18       146.12

--- Running Benchmarks with batch_size=1, vocab_size=204800, k=256, stride=204800 ---

--- Sampling Cuda Kernel Benchmark Summary ---
Batch Size   K     Stride    Algorithm      Latency(us)   Stdev(us)   P95(us)
--------------------------------------------------------------------------------
1            256   204800    FUSED          19.85         14.90       42.67
1            256   204800    MULTI_STAGE    69.77         24.24       112.90
```
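For context, a minimal sketch of how such a fused kernel can be structured is shown below, assuming one thread block per batch row operating on `k <= 256` pre-selected top-k logits and using CUB block primitives. The kernel name, the buffer names (`scores`, `indices`, `thresholds`), and the omission of the filtering step are illustrative assumptions, not the implementation added by this commit.

```cuda
// Sketch only: one block samples one batch row from its k <= 256 top-k logits.
// Launch with blockDim.x == kBlockSize and gridDim.x == batch_size.
#include <cfloat>
#include <cub/cub.cuh>
#include <cuda_runtime.h>

template <int kBlockSize>
__global__ void FusedTopKSampleSketch(const float* scores,      // [batch, stride] top-k logits
                                      const int* indices,       // [batch, stride] token ids
                                      const float* thresholds,  // [batch] uniform draws in [0, 1)
                                      int* sampled_tokens,      // [batch] output token ids
                                      int k, int stride, float temperature) {
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  using BlockScan = cub::BlockScan<float, kBlockSize>;
  __shared__ union {
    typename BlockReduce::TempStorage reduce;
    typename BlockScan::TempStorage scan;
  } temp;
  __shared__ float row_max;
  __shared__ float cdf[kBlockSize];

  const int batch = blockIdx.x;
  const int tid = threadIdx.x;

  // 1) Temperature scaling; lanes beyond k act as -inf padding.
  float logit = (tid < k) ? scores[batch * stride + tid] / temperature : -FLT_MAX;

  // 2) Numerically stable softmax numerator (normalization is folded into step 4).
  float m = BlockReduce(temp.reduce).Reduce(logit, cub::Max());
  if (tid == 0) row_max = m;
  __syncthreads();
  float p = (tid < k) ? expf(logit - row_max) : 0.0f;

  // 3) CDF via an inclusive prefix sum, kept in shared memory.
  float running;
  BlockScan(temp.scan).InclusiveSum(p, running);
  cdf[tid] = running;
  __syncthreads();

  // 4) Scale the uniform draw by the unnormalized total (implicit re-normalization);
  //    exactly one lane owns the interval (cdf[tid-1], cdf[tid]] that contains it.
  float u = thresholds[batch] * cdf[k - 1];
  if (tid < k && u <= cdf[tid] && (tid == 0 || u > cdf[tid - 1])) {
    sampled_tokens[batch] = indices[batch * stride + tid];
  }
}
```

With a fixed block size of 256 threads, a launch such as `FusedTopKSampleSketch<256><<<batch_size, 256, 0, stream>>>(...)` covers every `k <= 256` with one configuration, which is the same regime the unified kernel above targets.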
1 parent 1499203 commit aac66dd

18 files changed (+1924 / -938 lines)

src/cuda/cuda_common.h

Lines changed: 64 additions & 0 deletions
```diff
@@ -1,3 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <stdexcept>
+#include <string>
+#include <sstream>
+#include <memory>
+#include <cassert>
+
+#include <cuda_runtime.h>
+#include "span.h"
+
 namespace Generators {
 
 cudaStream_t GetStream();
@@ -91,4 +105,54 @@ cuda_unique_ptr<T> CudaMallocArray(size_t count, std::span<T>* p_span = nullptr)
   return cuda_unique_ptr<T>{p};
 }
 
+inline int CeilDiv(int a, int b) { return (a + (b - 1)) / b; }
+
+class CudaError : public std::runtime_error {
+ public:
+  explicit CudaError(const std::string& msg, cudaError_t code)
+      : std::runtime_error(msg), code_(code) {}
+
+  cudaError_t code() const noexcept { return code_; }
+
+ private:
+  cudaError_t code_;
+};
+
+#define CUDA_CHECK(call)                                             \
+  do {                                                               \
+    cudaError_t err = (call);                                        \
+    if (err != cudaSuccess) {                                        \
+      std::stringstream ss;                                          \
+      ss << "CUDA error in " << __func__ << " at " << __FILE__       \
+         << ":" << __LINE__ << " - " << cudaGetErrorString(err);     \
+      throw Generators::CudaError(ss.str(), err);                    \
+    }                                                                \
+  } while (0)
+
+#ifdef NDEBUG
+#define CUDA_CHECK_LAUNCH()                                          \
+  do {                                                               \
+    cudaError_t err = cudaPeekAtLastError();                         \
+    if (err != cudaSuccess) {                                        \
+      std::stringstream ss;                                          \
+      ss << "CUDA launch error in " << __func__ << " at "            \
+         << __FILE__ << ":" << __LINE__ << " - "                     \
+         << cudaGetErrorString(err);                                 \
+      throw Generators::CudaError(ss.str(), err);                    \
+    }                                                                \
+  } while (0)
+#else
+#define CUDA_CHECK_LAUNCH()                                          \
+  do {                                                               \
+    cudaError_t err = cudaGetLastError();                            \
+    if (err != cudaSuccess) {                                        \
+      std::stringstream ss;                                          \
+      ss << "CUDA launch error in " << __func__ << " at "            \
+         << __FILE__ << ":" << __LINE__ << " - "                     \
+         << cudaGetErrorString(err);                                 \
+      throw Generators::CudaError(ss.str(), err);                    \
+    }                                                                \
+  } while (0)
+#endif
+
 } // namespace Generators
```
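The new helpers are used roughly as in the sketch below; `ScaleKernel`, `Scale`, and their parameters are hypothetical and only illustrate `Generators::CeilDiv`, `CUDA_CHECK`, and `CUDA_CHECK_LAUNCH`. Note the asymmetry above: with `NDEBUG` (release builds) `CUDA_CHECK_LAUNCH` calls `cudaPeekAtLastError()`, which reports a launch failure without clearing the sticky error state, while debug builds call `cudaGetLastError()`, which also resets it.

```cuda
// Hedged usage sketch; ScaleKernel and Scale are made up for illustration.
#include <cuda_runtime.h>
#include "cuda_common.h"

__global__ void ScaleKernel(float* data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void Scale(float* device_data, float factor, int n, cudaStream_t stream) {
  constexpr int kBlockSize = 256;
  const int grid = Generators::CeilDiv(n, kBlockSize);  // round the grid size up
  ScaleKernel<<<grid, kBlockSize, 0, stream>>>(device_data, factor, n);
  CUDA_CHECK_LAUNCH();                         // launch failures surface as Generators::CudaError
  CUDA_CHECK(cudaStreamSynchronize(stream));   // runtime API calls are checked the same way
}
```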
