
Commit da8f8fe

mickaelseznec authored and mgoin committed
[perf] Add fused MLA QKV + strided layernorm (vllm-project#21116)
Signed-off-by: Mickael Seznec <[email protected]>
Co-authored-by: mgoin <[email protected]>
Signed-off-by: avigny <[email protected]>
1 parent 1b24d04 commit da8f8fe

File tree

7 files changed: +214 -66 lines changed
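
Before the per-file diffs: the kernel changes below thread an explicit row stride through the RMS-norm kernels so they can read a non-contiguous row-major view in place, presumably a slice of the fused MLA QKV projection output named in the title, instead of requiring a contiguous copy first. A minimal standalone CUDA sketch of the indexing idea (one thread per row for clarity; the kernel name, shapes, and the 3x fused row width are illustrative, not taken from this commit):

// rms_norm_strided_sketch.cu -- illustrative only, not the vLLM kernel.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void rms_norm_strided_sketch(float* out, const float* in,
                                        const float* weight, long input_stride,
                                        int num_rows, int hidden_size,
                                        float eps) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= num_rows) return;
  float sum_sq = 0.f;
  for (int i = 0; i < hidden_size; ++i) {
    // Strided read: the row pitch of the input view is input_stride,
    // which may be larger than hidden_size.
    float x = in[row * input_stride + i];
    sum_sq += x * x;
  }
  float inv_rms = rsqrtf(sum_sq / hidden_size + eps);
  for (int i = 0; i < hidden_size; ++i) {
    // Contiguous write: the output tensor is packed.
    out[row * (long)hidden_size + i] =
        in[row * input_stride + i] * inv_rms * weight[i];
  }
}

int main() {
  // Hypothetical shapes: normalize the first `hidden` columns of each row of
  // a wider fused buffer (row pitch = 3 * hidden) without copying them out.
  const int rows = 4, hidden = 8;
  const long pitch = 3 * hidden;
  std::vector<float> h_in(rows * pitch, 1.f), h_w(hidden, 2.f), h_out(rows * hidden);
  float *d_in, *d_w, *d_out;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_w, h_w.size() * sizeof(float));
  cudaMalloc(&d_out, h_out.size() * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_w, h_w.data(), h_w.size() * sizeof(float), cudaMemcpyHostToDevice);
  rms_norm_strided_sketch<<<1, 32>>>(d_out, d_in, d_w, pitch, rows, hidden, 1e-6f);
  cudaMemcpy(h_out.data(), d_out, h_out.size() * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("out[0] = %f\n", h_out[0]);  // ~2.0 for all-ones input and weight 2
  cudaFree(d_in); cudaFree(d_w); cudaFree(d_out);
  return 0;
}

The actual kernels keep their block-per-row parallel reduction; the diffs only change the row pitch used when indexing `input` from `hidden_size` to `input_stride`, while the other tensors stay contiguous.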

csrc/layernorm_kernels.cu

Lines changed: 40 additions & 23 deletions
@@ -15,15 +15,16 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t>
 __global__ void rms_norm_kernel(
-    scalar_t* __restrict__ out,           // [..., hidden_size]
-    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    scalar_t* __restrict__ out,          // [..., hidden_size]
+    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int64_t input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }
 
@@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
   __syncthreads();
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     out[blockIdx.x * hidden_size + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
@@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
@@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
 
   const int vec_hidden_size = hidden_size / width;
+  const int64_t vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
@@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
 
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
+    _f16Vec<scalar_t, width> temp = input_v[strided_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;
@@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
 
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
     int id = blockIdx.x * vec_hidden_size + idx;
+    int64_t strided_id = blockIdx.x * vec_input_stride + idx;
     _f16Vec<scalar_t, width> temp = residual_v[id];
     temp *= s_variance;
     temp *= weight_v[idx];
-    input_v[id] = temp;
+    input_v[strided_id] = temp;
   }
 }
 
@@ -103,15 +108,16 @@ fused_add_rms_norm_kernel(
 template <typename scalar_t, int width>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_kernel(
-    scalar_t* __restrict__ input,         // [..., hidden_size]
+    scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int64_t input_stride,
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;
@@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
     float x = (float)residual[blockIdx.x * hidden_size + idx];
-    input[blockIdx.x * hidden_size + idx] =
+    input[blockIdx.x * input_stride + idx] =
         ((scalar_t)(x * s_variance)) * weight[idx];
   }
 }
@@ -141,38 +147,42 @@ void rms_norm(torch::Tensor& out,    // [..., hidden_size]
               torch::Tensor& weight,  // [hidden_size]
               double epsilon) {
   TORCH_CHECK(out.is_contiguous());
-  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
 
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
+  int64_t input_stride = input.stride(-2);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(hidden_size, 1024));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
     vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
+        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
         weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
   });
 }
 
-#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                       \
-  VLLM_DISPATCH_FLOATING_TYPES(                                                \
-      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {                  \
-        vllm::fused_add_rms_norm_kernel<scalar_t, width>                       \
-            <<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(),           \
-                                         residual.data_ptr<scalar_t>(),        \
-                                         weight.data_ptr<scalar_t>(), epsilon, \
-                                         num_tokens, hidden_size);             \
+#define LAUNCH_FUSED_ADD_RMS_NORM(width)                                    \
+  VLLM_DISPATCH_FLOATING_TYPES(                                             \
+      input.scalar_type(), "fused_add_rms_norm_kernel", [&] {               \
+        vllm::fused_add_rms_norm_kernel<scalar_t, width>                    \
+            <<<grid, block, 0, stream>>>(                                   \
+                input.data_ptr<scalar_t>(), input_stride,                   \
+                residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
+                epsilon, num_tokens, hidden_size);                          \
       });
 
 void fused_add_rms_norm(torch::Tensor& input,     // [..., hidden_size]
                         torch::Tensor& residual,  // [..., hidden_size]
                         torch::Tensor& weight,    // [hidden_size]
                         double epsilon) {
+  TORCH_CHECK(residual.is_contiguous());
+  TORCH_CHECK(weight.is_contiguous());
   int hidden_size = input.size(-1);
+  int64_t input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
@@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input,  // [..., hidden_size]
   auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
   auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
-  bool ptrs_are_aligned =
-      inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  constexpr int vector_width = 8;
+  constexpr int req_alignment_bytes =
+      vector_width * 2;  // vector_width * sizeof(bfloat16 or float16) (float32
+                         // falls back to non-vectorized version anyway)
+  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
                          res_ptr % req_alignment_bytes == 0 &&
                          wt_ptr % req_alignment_bytes == 0;
+  bool offsets_are_multiple_of_vector_width =
+      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
+  if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
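
The updated dispatch at the end of fused_add_rms_norm takes the width-8 vectorized kernel only when all three pointers are aligned to vector_width * 2 = 16 bytes (8 elements of a 2-byte dtype) and both hidden_size and the new input_stride are multiples of the vector width; anything else falls back to the width-0 scalar variant. A small host-side C++ sketch of that predicate with made-up pointer values and sizes (the helper name is hypothetical):

#include <cstdint>
#include <cstdio>

// Sketch of the vectorization guard: 8 fp16/bf16 elements per vector => 16 B.
bool can_use_vectorized_path(std::uintptr_t inp_ptr, std::uintptr_t res_ptr,
                             std::uintptr_t wt_ptr, int hidden_size,
                             std::int64_t input_stride) {
  constexpr int vector_width = 8;                         // elements per vector
  constexpr int req_alignment_bytes = vector_width * 2;   // 2-byte dtypes
  bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
                          res_ptr % req_alignment_bytes == 0 &&
                          wt_ptr % req_alignment_bytes == 0;
  bool offsets_ok =
      hidden_size % vector_width == 0 && input_stride % vector_width == 0;
  return ptrs_are_aligned && offsets_ok;
}

int main() {
  // e.g. hidden_size 512 inside a fused row of stride 1536: both divisible by
  // 8, so the width-8 kernel is eligible (given aligned pointers).
  std::printf("%d\n", can_use_vectorized_path(0x1000, 0x2000, 0x3000, 512, 1536));
  // A stride such as 1537 would force the non-vectorized fallback (width 0).
  std::printf("%d\n", can_use_vectorized_path(0x1000, 0x2000, 0x3000, 512, 1537));
  return 0;
}

For a contiguous tensor, input_stride equals hidden_size, so the extra stride test does not change behaviour for existing call sites.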

csrc/layernorm_quant_kernels.cu

Lines changed: 25 additions & 14 deletions
@@ -23,16 +23,17 @@ namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template <typename scalar_t, typename fp8_type>
 __global__ void rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,           // [..., hidden_size]
-    const scalar_t* __restrict__ input,   // [..., hidden_size]
+    fp8_type* __restrict__ out,          // [..., hidden_size]
+    const scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int input_stride,
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
     const float epsilon, const int num_tokens, const int hidden_size) {
   __shared__ float s_variance;
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    const float x = (float)input[blockIdx.x * hidden_size + idx];
+    const float x = (float)input[blockIdx.x * input_stride + idx];
     variance += x * x;
   }
 
@@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
   float const scale_inv = 1.0f / *scale;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    float x = (float)input[blockIdx.x * hidden_size + idx];
+    float x = (float)input[blockIdx.x * input_stride + idx];
     float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
     out[blockIdx.x * hidden_size + idx] =
         scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
@@ -63,8 +64,9 @@ __global__ void rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,           // [..., hidden_size]
-    scalar_t* __restrict__ input,         // [..., hidden_size]
+    fp8_type* __restrict__ out,    // [..., hidden_size]
+    scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int input_stride,
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
 
   const int vec_hidden_size = hidden_size / width;
+  const int vec_input_stride = input_stride / width;
   __shared__ float s_variance;
   float variance = 0.0f;
   /* These and the argument pointers are all declared `restrict` as they are
@@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
       reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
 
   for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
+    int stride_id = blockIdx.x * vec_input_stride + idx;
     int id = blockIdx.x * vec_hidden_size + idx;
-    _f16Vec<scalar_t, width> temp = input_v[id];
+    _f16Vec<scalar_t, width> temp = input_v[stride_id];
     temp += residual_v[id];
     variance += temp.sum_squares();
     residual_v[id] = temp;
@@ -125,8 +129,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
 template <typename scalar_t, int width, typename fp8_type>
 __global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
 fused_add_rms_norm_static_fp8_quant_kernel(
-    fp8_type* __restrict__ out,           // [..., hidden_size]
-    scalar_t* __restrict__ input,         // [..., hidden_size]
+    fp8_type* __restrict__ out,    // [..., hidden_size]
+    scalar_t* __restrict__ input,  // [..., hidden_size]
+    const int input_stride,
     scalar_t* __restrict__ residual,      // [..., hidden_size]
     const scalar_t* __restrict__ weight,  // [hidden_size]
     const float* __restrict__ scale,      // [1]
@@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
   float variance = 0.0f;
 
   for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
-    scalar_t z = input[blockIdx.x * hidden_size + idx];
+    scalar_t z = input[blockIdx.x * input_stride + idx];
     z += residual[blockIdx.x * hidden_size + idx];
     float x = (float)z;
     variance += x * x;
@@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
                               torch::Tensor& weight,  // [hidden_size]
                               torch::Tensor& scale,   // [1]
                               double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
   int hidden_size = input.size(-1);
+  int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
@@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
      vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
          <<<grid, block, 0, stream>>>(
              out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
-              weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
-              epsilon, num_tokens, hidden_size);
+              input_stride, weight.data_ptr<scalar_t>(),
+              scale.data_ptr<float>(), epsilon, num_tokens,
+              hidden_size);
    });
  });
 }
@@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out,  // [..., hidden_size]
             width, fp8_t>                                              \
             <<<grid, block, 0, stream>>>(                              \
                 out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),     \
-                residual.data_ptr<scalar_t>(),                         \
+                input_stride, residual.data_ptr<scalar_t>(),           \
                 weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),  \
                 epsilon, num_tokens, hidden_size);                     \
       });                                                              \
@@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant(
     torch::Tensor& weight,  // [hidden_size]
     torch::Tensor& scale,   // [1]
     double epsilon) {
+  TORCH_CHECK(out.is_contiguous());
+  TORCH_CHECK(residual.is_contiguous());
   int hidden_size = input.size(-1);
+  int input_stride = input.stride(-2);
   int num_tokens = input.numel() / hidden_size;
 
   dim3 grid(num_tokens);
@@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant(
   auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
   bool ptrs_are_aligned =
       inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
-  if (ptrs_are_aligned && hidden_size % 8 == 0) {
+  if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
     LAUNCH_FUSED_ADD_RMS_NORM(8);
   } else {
     LAUNCH_FUSED_ADD_RMS_NORM(0);
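
In the vectorized fused kernels above, the stride is divided once by the vector width (vec_input_stride) and two indices then coexist in the inner loop: a strided one for input and the packed one for residual and out. A standalone CUDA sketch of that index split, using float4 as a stand-in for _f16Vec (names and sizes are illustrative, not from this commit):

// index_split_sketch.cu -- illustrative stand-in (float4 instead of _f16Vec).
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__global__ void fused_add_sketch(float4* input_v, float4* residual_v,
                                 long vec_input_stride, int vec_hidden_size) {
  for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
    long strided_id = blockIdx.x * vec_input_stride + idx;  // strided tensor
    int id = blockIdx.x * vec_hidden_size + idx;            // packed tensors
    float4 t = input_v[strided_id];
    float4 r = residual_v[id];
    t.x += r.x; t.y += r.y; t.z += r.z; t.w += r.w;
    residual_v[id] = t;       // residual stays contiguous
    input_v[strided_id] = t;  // result written back through the stride
  }
}

int main() {
  // Hypothetical: 2 rows, 16 floats of interest per row (4 float4 vectors),
  // input row pitch of 48 floats (12 vectors); residual is packed.
  const int rows = 2, vec_hidden = 4, vec_stride = 12;
  std::vector<float> h_in(rows * vec_stride * 4, 1.f), h_res(rows * vec_hidden * 4, 1.f);
  float4 *d_in, *d_res;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_res, h_res.size() * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_res, h_res.data(), h_res.size() * sizeof(float), cudaMemcpyHostToDevice);
  fused_add_sketch<<<rows, 32>>>(d_in, d_res, vec_stride, vec_hidden);
  cudaMemcpy(h_res.data(), d_res, h_res.size() * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("residual[0] = %f\n", h_res[0]);  // 2.0 = 1 + 1
  cudaFree(d_in); cudaFree(d_res);
  return 0;
}

Keeping residual and out packed means the reduction and the residual write-back are untouched; only loads from and stores to the strided input switch to the strided index.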

csrc/quantization/fp8/common.cu

Lines changed: 4 additions & 0 deletions
@@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
   int const block_size = 256;
   int const num_tokens = input.numel() / input.size(-1);
   int const num_elems = input.numel();
@@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out,  // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
+  TORCH_CHECK(input.is_contiguous());
+  TORCH_CHECK(out.is_contiguous());
   int const block_size = 256;
   int const num_tokens = input.numel() / input.size(-1);
   int const num_elems = input.numel();
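
The two quant entry points above appear to treat input as a flat buffer (num_elems = input.numel() in the surrounding context), so a strided view would pull padding elements between rows into the quantization; the added TORCH_CHECKs make that an explicit error rather than silent corruption. A toy host-only C++ sketch of the failure mode the checks guard against (buffer layout and numbers are made up):

#include <cstdio>
#include <vector>

int main() {
  // A 2 x 4 "view" stored inside rows of pitch 6: elements left at 0 are
  // padding that does not belong to the view.
  const int rows = 2, hidden = 4, pitch = 6;
  std::vector<float> buf(rows * pitch, 0.f);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < hidden; ++c) buf[r * pitch + c] = 1.f;

  // A flat kernel assumes numel() == rows * hidden contiguous values; walking
  // the first rows*hidden slots of the strided buffer picks up padding and
  // misses part of the second row -- hence TORCH_CHECK(input.is_contiguous()).
  int misread = 0;
  for (int i = 0; i < rows * hidden; ++i)
    if (buf[i] != 1.f) ++misread;
  std::printf("elements misread by flat indexing: %d\n", misread);  // prints 2
  return 0;
}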
