
Commit 4b85028

bbeckca authored and facebook-github-bot committed
Vectorize RMSNorm CUDA kernel (vllm-project#22602)
Summary:

What: Make RMSNorm faster by reading data in bigger aligned chunks, caching fp16 rows in shared memory, making the FP8-quant version work with strided inputs, and using the same launch settings as the unfused path.

Why: Cut global memory traffic using aligned vector inputs/outputs and shared-mem reuse (avoids a second read), make the FP8 path safe for strided inputs, and preserve numerics by matching the unfused reduction/launch order.

Test Plan:

1) Run tests

```
[[email protected] /data/users/benjibeck/fbsource/fbcode/vllm (1043a27694)]$ buck2 test :test_kernels_layernorm
Buck UI: https://www.internalfb.com/buck2/054ebad3-ad92-4676-a4d2-3bf43e44f31a
Test UI: https://www.internalfb.com/intern/testinfra/testrun/10414574240710255
Network: Up: 152MiB  Down: 2.9GiB  (reSessionID-14af330c-26bf-41d5-87b0-5775bf7d6f8a)
Loading targets.   Remaining 0/7    150 dirs read, 69 targets declared
Analyzing targets. Remaining 0/32   772 actions, 819 artifacts declared
Executing actions. Remaining 0/245  48.3s exec time total
Command: test.     Finished 1 local, 14 remote, 131 cache (90% hit)  45.2s exec time cached (93%)
Time elapsed: 4:53.4s
Tests finished: Pass 3169. Fail 0. Fatal 0. Skip 0. Build failure 0
```

2) Run benchmark

```
buck run :benchmark_layernorm -- --num-tokens 16384 --hidden-size 1024 --dtype half --num-iters 500
```

Performance (selected, non-strided / HND):

| T | H | dtype | baseline (µs) | current (µs) | Δ |
|---|---|---|---|---|---|
| 4096 | 1024 | fp16 | 24.592 | 16.552 | -32.7% |
| 16384 | 1024 | fp16 | 106.699 | 42.739 | -60.0% |
| 4096 | 8192 | fp16 | 118.566 | 97.059 | -18.1% |
| 16384 | 8192 | fp16 | 450.738 | 356.125 | -21.0% |
| 4096 | 1024 | bf16 | 24.743 | 16.683 | -32.6% |
| 16384 | 1024 | bf16 | 107.009 | 56.946 | -46.8% |
| 4096 | 8192 | bf16 | 119.293 | 96.774 | -18.9% |
| 16384 | 8192 | bf16 | 451.181 | 357.799 | -20.7% |

Strided (NHD) overhead of the current kernel, shown as the penalty vs. HND at the same T/H/dtype:

- 4096×1024 fp16: 1.39× (22.983 / 16.552)
- 16384×1024 fp16: 2.13× (90.995 / 42.739)
- 4096×8192 fp16: 1.93× (186.931 / 97.059)

Rollback Plan:

Differential Revision: D79969610
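The "bigger aligned chunks" in the summary are 128-bit (16-byte) vector accesses: with 2-byte fp16/bf16 elements, one `uint4` load moves 8 values, so each pass over a row issues roughly one eighth as many global transactions. The sketch below only illustrates that access pattern for the sum-of-squares pass; it is not the PR's kernel, it omits the block reduction and the normalization, and it assumes a 16-byte-aligned row, `hidden_size % 8 == 0`, and a zero-initialized `out` buffer.

```cuda
#include <cuda_fp16.h>

// Illustration only: accumulate sum(x*x) for one fp16 row per block using
// 128-bit (uint4) loads, i.e. 8 halves per global memory transaction.
// Assumes: row pointer 16B-aligned, hidden_size % 8 == 0, out[] zeroed.
__global__ void row_sumsq_vec8(const __half* __restrict__ in,
                               float* __restrict__ out, int hidden_size) {
  const __half* row = in + static_cast<size_t>(blockIdx.x) * hidden_size;
  const uint4* row_v = reinterpret_cast<const uint4*>(row);
  const int nvec = hidden_size / 8;  // 8 halves per 16-byte packet

  float sumsq = 0.f;
  for (int v = threadIdx.x; v < nvec; v += blockDim.x) {
    const uint4 p = row_v[v];  // one 128-bit load
    const __half* h = reinterpret_cast<const __half*>(&p);
#pragma unroll
    for (int j = 0; j < 8; ++j) {
      const float x = __half2float(h[j]);
      sumsq += x * x;
    }
  }
  // Crude per-thread accumulation into the row's slot; the real kernel
  // instead does a warp-shuffle + shared-memory block reduction.
  atomicAdd(&out[blockIdx.x], sumsq);
}
```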
1 parent 3253ae7 commit 4b85028

File tree: 3 files changed (+347, −72 lines)

benchmarks/kernels/benchmark_layernorm.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -11,7 +11,7 @@
 
 
 @torch.inference_mode()
-def main(
+def run_benchmark(
     num_tokens: int,
     hidden_size: int,
     add_residual: bool,
@@ -59,7 +59,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     print(f"Kernel running time: {latency * 1000000:.3f} us")
 
 
-if __name__ == "__main__":
+def main():
+    """Main function for Buck compatibility."""
     parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.")
     parser.add_argument("--num-tokens", type=int, default=4096)
     parser.add_argument("--hidden-size", type=int, default=8192)
@@ -81,7 +82,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
     args = parser.parse_args()
     print(args)
 
-    main(
+    run_benchmark(
         num_tokens=args.num_tokens,
         hidden_size=args.hidden_size,
         add_residual=args.add_residual,
@@ -91,3 +92,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
         num_warmup_iters=args.num_warmup_iters,
         num_iters=args.num_iters,
     )
+
+
+if __name__ == "__main__":
+    main()
```

csrc/layernorm_kernels.cu

Lines changed: 247 additions & 26 deletions
```diff
@@ -3,6 +3,7 @@
 
 #include <torch/cuda.h>
 #include <c10/cuda/CUDAGuard.h>
+#include "quantization/vectorization_utils.cuh"
 
 #ifndef USE_ROCM
 #include <cub/cub.cuh>
@@ -12,35 +13,227 @@
 
 namespace vllm {
 
-// TODO(woosuk): Further optimize this kernel.
+constexpr int kVecBytes = 16;  // 128-bit phase
+
+template <typename T>
+__device__ __forceinline__ T warp_sum(T v) {
+#ifdef __HIP_PLATFORM_AMD__
+  const unsigned long long m = 0xffffffffffffffffull;
+#else
+  const unsigned m = 0xffffffffu;
+#endif
+  constexpr int kWidth = 32;
+  v += __shfl_down_sync(m, v, 16, kWidth);
+  v += __shfl_down_sync(m, v, 8, kWidth);
+  v += __shfl_down_sync(m, v, 4, kWidth);
+  v += __shfl_down_sync(m, v, 2, kWidth);
+  v += __shfl_down_sync(m, v, 1, kWidth);
+  return v;
+}
+
+template <typename T>
+__device__ __forceinline__ bool same_phase(const T* a, const T* b, int bytes) {
+  const auto ai = reinterpret_cast<uintptr_t>(a);
+  const auto bi = reinterpret_cast<uintptr_t>(b);
+  return ((ai ^ bi) & (bytes - 1)) == 0;
+}
+
+// copy input row to shared with 16B phase when possible
+template <typename T>
+__device__ __forceinline__ void copy_row_to_shared_aligned(
+    const T* __restrict__ src, T* __restrict__ dst, int n_elems, int tid) {
+  const auto sa = reinterpret_cast<uintptr_t>(src);
+  const auto da = reinterpret_cast<uintptr_t>(dst);
+  const bool same = (((sa ^ da) & (kVecBytes - 1)) == 0);
+
+  if (!same) {
+    for (int i = tid; i < n_elems; i += blockDim.x) dst[i] = src[i];
+    __syncthreads();
+    return;
+  }
+
+  const int ebytes = sizeof(T);
+  const int perVec = kVecBytes / ebytes;
+
+  int prefix = 0;
+  const int mis = sa & (kVecBytes - 1);
+  if (mis) prefix = (kVecBytes - mis) / ebytes;
+  if (prefix > n_elems) prefix = n_elems;
+
+  for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];
+
+  const int remain = n_elems - prefix;
+  const int main_elems = (remain / perVec) * perVec;
+  if (main_elems > 0) {
+    const uint4* __restrict__ vsrc =
+        reinterpret_cast<const uint4*>(src + prefix);
+#if defined(__HIP_PLATFORM_AMD__)
+    uint32_t* __restrict__ s32 = reinterpret_cast<uint32_t*>(dst + prefix);
+    const int nvec = main_elems / perVec;
+    constexpr int WORDS_PER_PKT = kVecBytes / sizeof(uint32_t);  // 4
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      const uint4 p = vsrc[v];
+      const int base = v * WORDS_PER_PKT;
+      s32[base + 0] = p.x;
+      s32[base + 1] = p.y;
+      s32[base + 2] = p.z;
+      s32[base + 3] = p.w;
+    }
+#else
+    uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst + prefix);
+    const int nvec = main_elems / perVec;
+    for (int v = tid; v < nvec; v += blockDim.x) {
+      vdst[v] = vsrc[v];
+    }
+#endif
+  }
+
+  const int tail = prefix + main_elems;
+  for (int i = tid + tail; i < n_elems; i += blockDim.x) dst[i] = src[i];
+  __syncthreads();
+}
+
+// functors for vectorized write
+template <int V, typename T>
+struct VecMulNormWeight {
+  const vec_n_t<T, V>* __restrict__ wv;
+  float inv_rms;
+  int stride_vec;
+  mutable int64_t vec_idx;
+  __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                             const vec_n_t<T, V>& src) const {
+    const vec_n_t<T, V> w = wv[vec_idx];
+#pragma unroll
+    for (int j = 0; j < V; ++j) {
+      const T xn = static_cast<T>(static_cast<float>(src.val[j]) * inv_rms);
+      dst.val[j] = xn * w.val[j];
+    }
+    vec_idx += stride_vec;
+  }
+};
+
+template <typename T>
+struct ScalarMulNormWeight {
+  const T* __restrict__ w_base;
+  T* __restrict__ out_base;
+  float inv_rms;
+  __device__ __forceinline__ void operator()(T& dst, const T src) const {
+    const int i = static_cast<int>(&dst - out_base);
+    const T xn = static_cast<T>(static_cast<float>(src) * inv_rms);
+    dst = xn * w_base[i];
+  }
+};
+
+template <int V, typename T>
+struct VecNormMulWeightScalarW {
+  const T* __restrict__ w_base;  // offset by prefix
+  float inv_rms;
+  int stride_vec;
+  mutable int vec_idx;
+  __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                             const vec_n_t<T, V>& src) const {
+    const int base = vec_idx * V;
+#pragma unroll
+    for (int j = 0; j < V; ++j) {
+      const float x = static_cast<float>(src.val[j]) * inv_rms;
+      dst.val[j] = static_cast<T>(x * static_cast<float>(w_base[base + j]));
+    }
+    vec_idx += stride_vec;
+  }
+};
+
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ out,
+    const scalar_t* __restrict__ input,
     const int64_t input_stride,
-    const scalar_t* __restrict__ weight,  // [hidden_size]
-    const float epsilon, const int num_tokens, const int hidden_size) {
-  __shared__ float s_variance;
-  float variance = 0.0f;
+    const scalar_t* __restrict__ weight,
+    const float epsilon, const int /*num_tokens*/, const int hidden_size,
+    int smem_elems) {
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * input_stride + idx];
-    variance += x * x;
+  const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+  scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
+
+  extern __shared__ unsigned char smem_raw[];
+  scalar_t* s_in = reinterpret_cast<scalar_t*>(smem_raw);
+
+#ifdef __HIP_PLATFORM_AMD__
+  constexpr bool kAllowCache = false;
+#else
+  constexpr bool kAllowCache = true;
+#endif
+  const bool use_cached =
+      kAllowCache && (sizeof(scalar_t) == 2) && (smem_elems > 0);
+
+#if !defined(__HIP_PLATFORM_AMD__)
+  if (use_cached) copy_row_to_shared_aligned(in_row, s_in, hidden_size, threadIdx.x);
+#endif
+
+  float sumsq = 0.f;
+  {
+    const scalar_t* base = use_cached ? s_in : in_row;
+    for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+      const float x = static_cast<float>(base[i]);
+      sumsq += x * x;
+    }
   }
 
-  using BlockReduce = cub::BlockReduce<float, 1024>;
-  __shared__ typename BlockReduce::TempStorage reduceStore;
-  variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
+  float wsum = warp_sum<float>(sumsq);
+  __shared__ float warp_sums_sh[32];
+  if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
+  __syncthreads();
 
-  if (threadIdx.x == 0) {
-    s_variance = rsqrtf(variance / hidden_size + epsilon);
+  if (threadIdx.x < 32) {
+    const int nwarps = (blockDim.x + 31) / 32;
+    const float v = (threadIdx.x < nwarps) ? warp_sums_sh[threadIdx.x] : 0.f;
+    const float total = warp_sum<float>(v);
+    if (threadIdx.x == 0) warp_sums_sh[0] = total;
   }
   __syncthreads();
 
-  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * input_stride + idx];
-    out[blockIdx.x * hidden_size + idx] =
-        ((scalar_t)(x * s_variance)) * weight[idx];
+  const float inv_rms =
+      rsqrtf(warp_sums_sh[0] / static_cast<float>(hidden_size) + epsilon);
+
+  if (hidden_size == blockDim.x) {
+    const int i = threadIdx.x;
+    const float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    const scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
+    return;
+  }
+
+  constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B
+  constexpr int WIDTH = V * sizeof(scalar_t);
+  const bool vec_store_ok = (hidden_size % V == 0) && same_phase(in_row, out_row, WIDTH);
+
+  const bool s_same = use_cached && same_phase(in_row, s_in, kVecBytes);
+  const scalar_t* vin = s_same ? s_in : in_row;
+
+  if (vec_store_ok) {
+    ScalarMulNormWeight<scalar_t> sca_op{weight, out_row, inv_rms};
+
+    const auto addr = reinterpret_cast<uintptr_t>(vin);
+    const int mis = addr & (WIDTH - 1);
+    const int prefix = mis ? (WIDTH - mis) / static_cast<int>(sizeof(scalar_t)) : 0;
+
+    if (same_phase(in_row, weight, WIDTH)) {
+      using VecT = vec_n_t<scalar_t, V>;
+      const VecT* __restrict__ wv =
+          reinterpret_cast<const VecT*>(weight + prefix);
+      VecMulNormWeight<V, scalar_t> vec_op{wv, inv_rms, (int)blockDim.x, (int64_t)threadIdx.x};
+      vectorize_with_alignment<V>(vin, out_row, hidden_size, threadIdx.x, blockDim.x, vec_op, sca_op);
+    } else {
+      VecNormMulWeightScalarW<V, scalar_t> vec_op{weight + prefix, inv_rms, (int)blockDim.x, (int)threadIdx.x};
+      vectorize_with_alignment<V>(vin, out_row, hidden_size, threadIdx.x, blockDim.x, vec_op, sca_op);
+    }
+    return;
+  }
+
+  // scalar fallback (keeps op order identical to fused path)
+  for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+    const float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+    const scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+    out_row[i] = xn * weight[i];
   }
 }
 
```
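One detail worth spelling out from the hunk above is the pointer "phase" logic: `same_phase` only requires that two addresses agree modulo the 16-byte vector width, and the scalar `prefix` then advances both pointers to the next 16-byte boundary in lockstep. Below is a small host-side sketch of that arithmetic using made-up addresses (they are not values from the PR):

```cuda
#include <cassert>
#include <cstdint>
#include <cstdio>

// Host-side illustration of the pointer "phase" arithmetic used above.
// Two pointers share a 16-byte phase when their addresses agree modulo 16;
// a single scalar prefix then brings both onto a 16B boundary together.
static bool same_phase(uintptr_t a, uintptr_t b, int bytes) {
  return ((a ^ b) & (bytes - 1)) == 0;
}

int main() {
  constexpr int kVecBytes = 16;
  constexpr int ebytes = 2;  // fp16

  // Hypothetical addresses: both are 8 bytes past a 16B boundary.
  uintptr_t src = 0x1008, dst = 0x2008;
  assert(same_phase(src, dst, kVecBytes));  // same offset mod 16 -> vectorizable

  // Misalignment of 8 bytes -> copy (16 - 8) / 2 = 4 scalar elements first,
  // after which src + prefix*ebytes and dst + prefix*ebytes are 16B aligned.
  const int mis = static_cast<int>(src & (kVecBytes - 1));  // 8
  const int prefix = mis ? (kVecBytes - mis) / ebytes : 0;  // 4
  printf("mis=%d bytes, prefix=%d fp16 elements\n", mis, prefix);

  // An address only 4 bytes past the boundary has a different phase, so the
  // vectorized copy cannot be used between src and it.
  assert(!same_phase(src, 0x2004, kVecBytes));
  return 0;
}
```

If the phases differ (the last assert), no common prefix exists, which is why `copy_row_to_shared_aligned` falls back to a plain element-wise copy in that case.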

```diff
@@ -142,6 +335,13 @@ fused_add_rms_norm_kernel(
 
 }  // namespace vllm
 
+static inline int ln_block_threads_unified(int H) {
+  int threads = (H >= 1024) ? 256
+              : (H >= 512)  ? 512
+                            : std::min(1024, ((H + 31) / 32) * 32);
+  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
+}
+
 void rms_norm(torch::Tensor& out,     // [..., hidden_size]
               torch::Tensor& input,   // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
```
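To make the new launch heuristic concrete, the snippet below simply re-evaluates the same arithmetic as `ln_block_threads_unified` on the host for a few representative hidden sizes (the sizes are arbitrary examples, not from the PR):

```cuda
#include <algorithm>
#include <cstdio>

// Same arithmetic as ln_block_threads_unified above, reproduced only to show
// what it evaluates to: round to a warp multiple, clamp to [128, 1024].
static int ln_block_threads_unified(int H) {
  int threads = (H >= 1024) ? 256
              : (H >= 512)  ? 512
                            : std::min(1024, ((H + 31) / 32) * 32);
  return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
}

int main() {
  const int hs[] = {100, 512, 1000, 1024, 4096, 8192};
  for (int H : hs) printf("H=%5d -> %d threads\n", H, ln_block_threads_unified(H));
  // Prints: H=100 -> 128, H=512 -> 512, H=1000 -> 512,
  //         H=1024 -> 256, H=4096 -> 256, H=8192 -> 256.
  return 0;
}
```

Rows of 1024 or more elements get 256 threads rather than one thread per element, so each thread covers several elements per pass; per the summary, this is meant to match the launch settings of the unfused path.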
```diff
@@ -150,21 +350,42 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
   TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
-  int hidden_size = input.size(-1);
-  int num_tokens = input.numel() / hidden_size;
-  int64_t input_stride = input.stride(-2);
+  const int hidden_size = input.size(-1);
+  const int num_tokens = input.numel() / hidden_size;
+  const int64_t in_stride = input.stride(-2);
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(hidden_size, 1024));
+  dim3 block(ln_block_threads_unified(hidden_size));
+
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  // Optional per-block row cache in dynamic shared memory.
+  // If enabled (FP16 and HS <= 4096), the kernel copies the row to smem once
+  // and reuses it on the second pass to cut a global read. If shmem_bytes == 0,
+  // the kernel takes the non-cached path
+  size_t shmem_bytes = 0;
+  int smem_elems = 0;
+  if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
+    shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
+    smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
+  }
+
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-    vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
-        weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
+    vllm::rms_norm_kernel<scalar_t>
+        <<<grid, block, shmem_bytes, stream>>>(
+            out.data_ptr<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            in_stride,
+            weight.data_ptr<scalar_t>(),
+            static_cast<float>(epsilon),
+            num_tokens,
+            hidden_size,
+            smem_elems);
   });
 }
 
+
 #define LAUNCH_FUSED_ADD_RMS_NORM(width) \
   VLLM_DISPATCH_FLOATING_TYPES( \
       input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
```
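The launcher above only provisions dynamic shared memory for fp16 rows of at most 4096 elements, i.e. at most 4096 × 2 B = 8 KiB per block, and signals this to the kernel through `smem_elems`. The sketch below is a minimal, hypothetical pairing of such a launch with an `extern __shared__` buffer; it is a toy, not the PR's kernel:

```cuda
#include <cuda_fp16.h>
#include <cstdio>

// Toy kernel: the dynamic shared-memory size passed at launch time backs the
// unsized `extern __shared__` array, mirroring how shmem_bytes reaches s_in
// inside rms_norm_kernel above.
__global__ void cache_row_demo(const __half* __restrict__ in, int n) {
  extern __shared__ unsigned char smem_raw[];           // sized at launch time
  __half* s_in = reinterpret_cast<__half*>(smem_raw);   // holds n halves
  for (int i = threadIdx.x; i < n; i += blockDim.x) s_in[i] = in[i];
  __syncthreads();
  // ... subsequent passes would read s_in instead of global memory.
}

int main() {
  const int hidden_size = 4096;  // the upper bound used by the launcher above
  const size_t shmem_bytes = hidden_size * sizeof(__half);  // 8 KiB per block
  printf("dynamic smem per block: %zu bytes\n", shmem_bytes);
  // Launch shape (device buffers omitted here):
  // cache_row_demo<<<num_rows, 256, shmem_bytes, stream>>>(in, hidden_size);
  return 0;
}
```

Passing 0 as the third launch argument, as the old code did, just means the `extern __shared__` array has zero size, which is why the kernel also gates `use_cached` on `smem_elems > 0`.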
