#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
+ #include "quantization/vectorization_utils.cuh"

#ifndef USE_ROCM
  #include <cub/cub.cuh>

namespace vllm {

- // TODO(woosuk): Further optimize this kernel.
+ constexpr int kVecBytes = 16;  // 128-bit phase
+
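+ // Intra-warp tree reduction via shuffle-down: after the five steps below,
+ // lane 0 of each 32-lane group holds the sum of all 32 lane values.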
+ template <typename T>
+ __device__ __forceinline__ T warp_sum(T v) {
+ #ifdef __HIP_PLATFORM_AMD__
+   const unsigned long long m = 0xffffffffffffffffull;
+ #else
+   const unsigned m = 0xffffffffu;
+ #endif
+   constexpr int kWidth = 32;
+   v += __shfl_down_sync(m, v, 16, kWidth);
+   v += __shfl_down_sync(m, v, 8, kWidth);
+   v += __shfl_down_sync(m, v, 4, kWidth);
+   v += __shfl_down_sync(m, v, 2, kWidth);
+   v += __shfl_down_sync(m, v, 1, kWidth);
+   return v;
+ }
+
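+ // True when a and b have the same address offset modulo `bytes` (a power of
+ // two), i.e. both pointers can be advanced with the same vector width.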
+ template <typename T>
+ __device__ __forceinline__ bool same_phase(const T* a, const T* b, int bytes) {
+   const auto ai = reinterpret_cast<uintptr_t>(a);
+   const auto bi = reinterpret_cast<uintptr_t>(b);
+   return ((ai ^ bi) & (bytes - 1)) == 0;
+ }
+
+ // copy input row to shared with 16B phase when possible
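+ // Structure: scalar copy up to the first 16B boundary, 128-bit copies for the
+ // aligned main body, then a scalar tail; plain scalar copy when src and dst
+ // do not share a 16B phase.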
+ template <typename T>
+ __device__ __forceinline__ void copy_row_to_shared_aligned(
+     const T* __restrict__ src, T* __restrict__ dst, int n_elems, int tid) {
+   const auto sa = reinterpret_cast<uintptr_t>(src);
+   const auto da = reinterpret_cast<uintptr_t>(dst);
+   const bool same = (((sa ^ da) & (kVecBytes - 1)) == 0);
+
+   if (!same) {
+     for (int i = tid; i < n_elems; i += blockDim.x) dst[i] = src[i];
+     __syncthreads();
+     return;
+   }
+
+   const int ebytes = sizeof(T);
+   const int perVec = kVecBytes / ebytes;
+
+   int prefix = 0;
+   const int mis = sa & (kVecBytes - 1);
+   if (mis) prefix = (kVecBytes - mis) / ebytes;
+   if (prefix > n_elems) prefix = n_elems;
+
+   for (int i = tid; i < prefix; i += blockDim.x) dst[i] = src[i];
+
+   const int remain = n_elems - prefix;
+   const int main_elems = (remain / perVec) * perVec;
+   if (main_elems > 0) {
+     const uint4* __restrict__ vsrc =
+         reinterpret_cast<const uint4*>(src + prefix);
+ #if defined(__HIP_PLATFORM_AMD__)
+     uint32_t* __restrict__ s32 = reinterpret_cast<uint32_t*>(dst + prefix);
+     const int nvec = main_elems / perVec;
+     constexpr int WORDS_PER_PKT = kVecBytes / sizeof(uint32_t);  // 4
+     for (int v = tid; v < nvec; v += blockDim.x) {
+       const uint4 p = vsrc[v];
+       const int base = v * WORDS_PER_PKT;
+       s32[base + 0] = p.x;
+       s32[base + 1] = p.y;
+       s32[base + 2] = p.z;
+       s32[base + 3] = p.w;
+     }
+ #else
+     uint4* __restrict__ vdst = reinterpret_cast<uint4*>(dst + prefix);
+     const int nvec = main_elems / perVec;
+     for (int v = tid; v < nvec; v += blockDim.x) {
+       vdst[v] = vsrc[v];
+     }
+ #endif
+   }
+
+   const int tail = prefix + main_elems;
+   for (int i = tid + tail; i < n_elems; i += blockDim.x) dst[i] = src[i];
+   __syncthreads();
+ }
+
+ // functors for vectorized write
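+ // VecMulNormWeight: 128-bit loads for both the row and the weight (weight
+ // shares the row's phase). VecNormMulWeightScalarW: 128-bit row loads with
+ // scalar weight loads (phases differ). ScalarMulNormWeight: scalar prefix/tail.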
+ template <int V, typename T>
+ struct VecMulNormWeight {
+   const vec_n_t<T, V>* __restrict__ wv;
+   float inv_rms;
+   int stride_vec;
+   mutable int64_t vec_idx;
+   __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                              const vec_n_t<T, V>& src) const {
+     const vec_n_t<T, V> w = wv[vec_idx];
+ #pragma unroll
+     for (int j = 0; j < V; ++j) {
+       const T xn = static_cast<T>(static_cast<float>(src.val[j]) * inv_rms);
+       dst.val[j] = xn * w.val[j];
+     }
+     vec_idx += stride_vec;
+   }
+ };
+
+ template <typename T>
+ struct ScalarMulNormWeight {
+   const T* __restrict__ w_base;
+   T* __restrict__ out_base;
+   float inv_rms;
+   __device__ __forceinline__ void operator()(T& dst, const T src) const {
+     const int i = static_cast<int>(&dst - out_base);
+     const T xn = static_cast<T>(static_cast<float>(src) * inv_rms);
+     dst = xn * w_base[i];
+   }
+ };
+
+ template <int V, typename T>
+ struct VecNormMulWeightScalarW {
+   const T* __restrict__ w_base;  // offset by prefix
+   float inv_rms;
+   int stride_vec;
+   mutable int vec_idx;
+   __device__ __forceinline__ void operator()(vec_n_t<T, V>& dst,
+                                              const vec_n_t<T, V>& src) const {
+     const int base = vec_idx * V;
+ #pragma unroll
+     for (int j = 0; j < V; ++j) {
+       const float x = static_cast<float>(src.val[j]) * inv_rms;
+       dst.val[j] = static_cast<T>(x * static_cast<float>(w_base[base + j]));
+     }
+     vec_idx += stride_vec;
+   }
+ };
+
template <typename scalar_t>
__global__ void rms_norm_kernel(
-     scalar_t* __restrict__ out,          // [..., hidden_size]
-     const scalar_t* __restrict__ input,  // [..., hidden_size]
+     scalar_t* __restrict__ out,
+     const scalar_t* __restrict__ input,
    const int64_t input_stride,
-     const scalar_t* __restrict__ weight,  // [hidden_size]
-     const float epsilon, const int num_tokens, const int hidden_size) {
-   __shared__ float s_variance;
-   float variance = 0.0f;
+     const scalar_t* __restrict__ weight,
+     const float epsilon, const int /*num_tokens*/, const int hidden_size,
+     int smem_elems) {

-   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-     const float x = (float)input[blockIdx.x * input_stride + idx];
-     variance += x * x;
+   const scalar_t* __restrict__ in_row = input + blockIdx.x * input_stride;
+   scalar_t* __restrict__ out_row = out + blockIdx.x * hidden_size;
+
+   extern __shared__ unsigned char smem_raw[];
+   scalar_t* s_in = reinterpret_cast<scalar_t*>(smem_raw);
+
+ #ifdef __HIP_PLATFORM_AMD__
+   constexpr bool kAllowCache = false;
+ #else
+   constexpr bool kAllowCache = true;
+ #endif
+   const bool use_cached =
+       kAllowCache && (sizeof(scalar_t) == 2) && (smem_elems > 0);
+
+ #if !defined(__HIP_PLATFORM_AMD__)
+   if (use_cached) copy_row_to_shared_aligned(in_row, s_in, hidden_size, threadIdx.x);
+ #endif
+
+   float sumsq = 0.f;
+   {
+     const scalar_t* base = use_cached ? s_in : in_row;
+     for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+       const float x = static_cast<float>(base[i]);
+       sumsq += x * x;
+     }
  }

-   using BlockReduce = cub::BlockReduce<float, 1024>;
-   __shared__ typename BlockReduce::TempStorage reduceStore;
-   variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
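+   // Two-stage block reduction without cub: per-warp shuffle reduction, warp
+   // leaders write their partial sums to shared memory, then warp 0 combines them.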
+   float wsum = warp_sum<float>(sumsq);
+   __shared__ float warp_sums_sh[32];
+   if ((threadIdx.x & 31) == 0) warp_sums_sh[threadIdx.x >> 5] = wsum;
+   __syncthreads();

-   if (threadIdx.x == 0) {
-     s_variance = rsqrtf(variance / hidden_size + epsilon);
+   if (threadIdx.x < 32) {
+     const int nwarps = (blockDim.x + 31) / 32;
+     const float v = (threadIdx.x < nwarps) ? warp_sums_sh[threadIdx.x] : 0.f;
+     const float total = warp_sum<float>(v);
+     if (threadIdx.x == 0) warp_sums_sh[0] = total;
  }
  __syncthreads();

-   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-     float x = (float)input[blockIdx.x * input_stride + idx];
-     out[blockIdx.x * hidden_size + idx] =
-         ((scalar_t)(x * s_variance)) * weight[idx];
+   const float inv_rms =
+       rsqrtf(warp_sums_sh[0] / static_cast<float>(hidden_size) + epsilon);
+
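+   // Fast path: the block exactly covers the row, so each thread handles one
+   // element and the vectorization bookkeeping below is unnecessary.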
+   if (hidden_size == blockDim.x) {
+     const int i = threadIdx.x;
+     const float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+     const scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+     out_row[i] = xn * weight[i];
+     return;
+   }
+
+   constexpr int V = (sizeof(scalar_t) == 2) ? 8 : 4;  // 16B
+   constexpr int WIDTH = V * sizeof(scalar_t);
+   const bool vec_store_ok = (hidden_size % V == 0) && same_phase(in_row, out_row, WIDTH);
+
+   const bool s_same = use_cached && same_phase(in_row, s_in, kVecBytes);
+   const scalar_t* vin = s_same ? s_in : in_row;
+
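+   // Vectorized write path: input and output rows share a 16B phase, so write
+   // with 128-bit stores; pick the weight functor by whether `weight` shares
+   // that phase as well.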
+   if (vec_store_ok) {
+     ScalarMulNormWeight<scalar_t> sca_op{weight, out_row, inv_rms};
+
+     const auto addr = reinterpret_cast<uintptr_t>(vin);
+     const int mis = addr & (WIDTH - 1);
+     const int prefix = mis ? (WIDTH - mis) / static_cast<int>(sizeof(scalar_t)) : 0;
+
+     if (same_phase(in_row, weight, WIDTH)) {
+       using VecT = vec_n_t<scalar_t, V>;
+       const VecT* __restrict__ wv =
+           reinterpret_cast<const VecT*>(weight + prefix);
+       VecMulNormWeight<V, scalar_t> vec_op{wv, inv_rms, (int)blockDim.x, (int64_t)threadIdx.x};
+       vectorize_with_alignment<V>(vin, out_row, hidden_size, threadIdx.x, blockDim.x, vec_op, sca_op);
+     } else {
+       VecNormMulWeightScalarW<V, scalar_t> vec_op{weight + prefix, inv_rms, (int)blockDim.x, (int)threadIdx.x};
+       vectorize_with_alignment<V>(vin, out_row, hidden_size, threadIdx.x, blockDim.x, vec_op, sca_op);
+     }
+     return;
+   }
+
+   // scalar fallback (keeps op order identical to fused path)
+   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
+     const float x = static_cast<float>(use_cached ? s_in[i] : in_row[i]);
+     const scalar_t xn = static_cast<scalar_t>(x * inv_rms);
+     out_row[i] = xn * weight[i];
  }
}

@@ -142,6 +335,13 @@ fused_add_rms_norm_kernel(

}  // namespace vllm

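+ // Thread-count heuristic for the norm launches: e.g. H=4096 -> 256 threads,
+ // H=768 -> 512, H=100 -> 128; the result is a multiple of 32 in [128, 1024].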
+ static inline int ln_block_threads_unified(int H) {
+   int threads = (H >= 1024) ? 256
+                 : (H >= 512) ? 512
+                              : std::min(1024, ((H + 31) / 32) * 32);
+   return std::min(1024, std::max(128, ((threads + 31) / 32) * 32));
+ }
+
void rms_norm(torch::Tensor& out,     // [..., hidden_size]
              torch::Tensor& input,   // [..., hidden_size]
              torch::Tensor& weight,  // [hidden_size]
@@ -150,21 +350,39 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
  TORCH_CHECK(input.stride(-1) == 1);
  TORCH_CHECK(weight.is_contiguous());

-   int hidden_size = input.size(-1);
-   int num_tokens = input.numel() / hidden_size;
-   int64_t input_stride = input.stride(-2);
+   const int hidden_size = input.size(-1);
+   const int num_tokens = input.numel() / hidden_size;
+   const int64_t in_stride = input.stride(-2);

  dim3 grid(num_tokens);
-   dim3 block(std::min(hidden_size, 1024));
+   dim3 block(ln_block_threads_unified(hidden_size));
+
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+   // Optional cached-row for FP16 (recommended). Kernel still works if this is 0.
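+   // At the 4096-element cap this is 8 KB of dynamic shared memory for FP16,
+   // well below the 48 KB default limit, so no opt-in attribute is required.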
+   size_t shmem_bytes = 0;
+   int smem_elems = 0;
+   if (input.scalar_type() == at::kHalf && hidden_size <= 4096) {
+     shmem_bytes = static_cast<size_t>(hidden_size) * sizeof(at::Half);
+     smem_elems = hidden_size;  // flag to kernel that shmem was provisioned
+   }
+
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
-     vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-         out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
-         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
+     vllm::rms_norm_kernel<scalar_t>
+         <<<grid, block, shmem_bytes, stream>>>(
+             out.data_ptr<scalar_t>(),
+             input.data_ptr<scalar_t>(),
+             in_stride,
+             weight.data_ptr<scalar_t>(),
+             static_cast<float>(epsilon),
+             num_tokens,
+             hidden_size,
+             smem_elems);
  });
}

+
#define LAUNCH_FUSED_ADD_RMS_NORM(width)                              \
  VLLM_DISPATCH_FLOATING_TYPES(                                       \
      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {         \