@@ -15,15 +15,16 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,          // [..., hidden_size]
-    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ out,           // [..., hidden_size]
+    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    const int64_t input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }
 
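Note: each thread block handles one token (row), so the read offset becomes `blockIdx.x * input_stride + idx`, while writes to the contiguous `out` tensor still use `hidden_size`. A minimal sketch of the same row-addressing idea on the host side (hypothetical names, not part of this PR):

```cpp
#include <cstdint>

// Row `row` of a row-major buffer whose rows are `row_stride` elements apart.
// When the tensor is padded or is a column slice of a wider matrix,
// row_stride > hidden_size, but the row's valid elements are still the first
// hidden_size entries starting at this pointer.
inline const float* row_ptr(const float* data, std::int64_t row,
                            std::int64_t row_stride) {
  return data + row * row_stride;
}
```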
@@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
@@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ input,     // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
@@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
 
   const int vec_hidden_size = hidden_size / width;
+  const int64_t vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
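Note: the vectorized kernel reinterprets the buffers as `_f16Vec<scalar_t, width>`, so the element stride has to be converted to a vector stride. For example, with `width = 8` and an input row stride of 4096 elements, `vec_input_stride` is 512; the host-side dispatch further down only selects this kernel when `input_stride % 8 == 0`, so the division is exact.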
@@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
 
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
+    _f16Vec<scalar_t, width> temp = input_v[strided_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;
@@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
 
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
     _f16Vec<scalar_t, width> temp = residual_v[id];
     temp *= s_variance;
     temp *= weight_v[idx];
-    input_v[id] = temp;
+    input_v[strided_id] = temp;
   }
 }
 
@@ -103,15 +108,16 @@ fused_add_rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,  // [..., hidden_size]
+    scalar_t* __restrict__ input,     // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,  // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;
@@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)residual[blockIdx.x * hidden_size + idx];
-    input[blockIdx.x * hidden_size + idx] =
+    input[blockIdx.x * input_stride + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }
@@ -141,38 +147,42 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
-  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
+  int64_t input_stride = input.stride(-2);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
     vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
 }
 
-#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
-  VLLM_DISPATCH_FLOATING_TYPES(                                                \
-      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
-        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
-                                         residual.data_ptr<scalar_t>(),        \
-                                         weight.data_ptr<scalar_t>(), epsilon, \
-                                         num_tokens, hidden_size);             \
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                     \
+  VLLM_DISPATCH_FLOATING_TYPES(                                              \
+      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                \
+        vllm::fused_add_rms_norm_kernel<scalar_t, width>                     \
＋            <<<grid, block, 0, stream>>>(                                    \
+                input.data_ptr<scalar_t>(), input_stride,                    \
+                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),  \
+                epsilon, num_tokens, hidden_size);                           \
       });
 
 void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                         torch::Tensor& residual,  // [..., hidden_size]
                         torch::Tensor& weight,    // [hidden_size]
                         double epsilon) {
+  TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);
+  int64_t input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
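Note: the host wrappers now only require the last dimension to be contiguous and forward `input.stride(-2)` to the kernels as the row stride. A minimal sketch of a caller that produces such an input (libtorch, hypothetical shapes, not part of this PR):

```cpp
#include <torch/torch.h>

int main() {
  // A column slice is a typical tensor that is non-contiguous overall but
  // still contiguous in its last dimension.
  auto buf = torch::empty({16, 6144}, torch::kFloat16);      // [tokens, 3 * hidden]
  auto x = buf.slice(/*dim=*/1, /*start=*/0, /*end=*/2048);  // [tokens, hidden]
  // stride(-1) == 1 satisfies the relaxed TORCH_CHECK, while
  // stride(-2) == 6144 differs from hidden_size == 2048 and is what the
  // kernels receive as input_stride.
  TORCH_CHECK(x.stride(-1) == 1 && x.stride(-2) == 6144);
  return 0;
}
```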
@@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
-  bool ptrs_are_aligned =
-      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  constexpr int vector_width = 8;
+  constexpr int req_alignment_bytes =
+      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
+                         // falls back to non-vectorized version anyway)
+  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
+                          res_ptr % req_alignment_bytes == 0 &&
+                          wt_ptr % req_alignment_bytes == 0;
+  bool offsets_are_multiple_of_vector_width =
+      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
+  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
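Note: the 8-wide path is now gated on both pointer alignment and the row offsets being multiples of the vector width. A condensed restatement of that predicate (hypothetical helper, mirrors the logic above):

```cpp
#include <cstdint>

// The 8-wide _f16Vec path needs 16-byte-aligned pointers (8 two-byte
// elements) and both hidden_size and the row stride to be multiples of 8 so
// that every row starts on a vector boundary.
bool can_use_vectorized(std::uintptr_t inp, std::uintptr_t res,
                        std::uintptr_t wt, int hidden_size,
                        std::int64_t input_stride) {
  constexpr int kWidth = 8;
  constexpr int kAlignBytes = kWidth * 2;  // fp16/bf16: 2 bytes per element
  return inp % kAlignBytes == 0 && res % kAlignBytes == 0 &&
         wt % kAlignBytes == 0 && hidden_size % kWidth == 0 &&
         input_stride % kWidth == 0;
}
```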