
Commit e31446b

[Perf] Tune scaled_fp8_quant by increasing vectorization (vllm-project#18844)
Signed-off-by: mgoin <[email protected]>
1 parent bdf1396 commit e31446b

4 files changed: +115 -110 lines changed

csrc/quantization/fp8/common.cu

Lines changed: 19 additions & 16 deletions
@@ -39,33 +39,33 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
   fp8_type* __restrict__ token_output = &out[offset];
 
   // For vectorization, token_input and token_output pointers need to be
-  // aligned at 8-byte and 4-byte addresses respectively.
-  bool const can_vectorize = hidden_size % 4 == 0;
+  // aligned at 32-byte and 16-byte addresses respectively.
+  bool const can_vectorize = hidden_size % 16 == 0;
 
   float absmax_val = 0.0f;
   if (can_vectorize) {
     absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
   } else {
     for (int i = tid; i < hidden_size; i += blockDim.x) {
       float const x = static_cast<float>(token_input[i]);
-      absmax_val = max(absmax_val, fabs(x));
+      absmax_val = fmaxf(absmax_val, fabsf(x));
     }
   }
 
-  using BlockReduce = cub::BlockReduce<float, 1024>;
+  using BlockReduce = cub::BlockReduce<float, 256>;
   __shared__ typename BlockReduce::TempStorage reduceStorage;
   float const block_absmax_val_maybe =
       BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x);
   __shared__ float token_scale;
   if (tid == 0) {
     if (scale_ub) {
-      token_scale = min(block_absmax_val_maybe, *scale_ub);
+      token_scale = fminf(block_absmax_val_maybe, *scale_ub);
     } else {
       token_scale = block_absmax_val_maybe;
     }
     // token scale computation
-    token_scale = max(token_scale / quant_type_max_v<fp8_type>,
-                      min_scaling_factor<fp8_type>::val());
+    token_scale = fmaxf(token_scale / quant_type_max_v<fp8_type>,
+                        min_scaling_factor<fp8_type>::val());
     scale[token_idx] = token_scale;
   }
   __syncthreads();

@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor const& scale)  // [1]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int64_t num_elems = input.numel();
-  dim3 grid(num_tokens);
-  dim3 block(1024);
+  int const block_size = 256;
+  int const num_tokens = input.numel() / input.size(-1);
+  int const num_elems = input.numel();
+  dim3 const grid(num_tokens);
+  dim3 const block(block_size);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(

@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
                               torch::Tensor const& input,  // [..., d]
                               torch::Tensor& scale)        // [1]
 {
-  int64_t num_tokens = input.numel() / input.size(-1);
-  int64_t num_elems = input.numel();
-  dim3 grid(num_tokens);
-  dim3 block(1024);
+  int const block_size = 256;
+  int const num_tokens = input.numel() / input.size(-1);
+  int const num_elems = input.numel();
+  dim3 const grid(num_tokens);
+  dim3 const block(block_size);
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   VLLM_DISPATCH_FLOATING_TYPES(

@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
 
   int const hidden_size = input.size(-1);
   int const num_tokens = input.numel() / hidden_size;
+  int const block_size = 256;
   dim3 const grid(num_tokens);
-  dim3 const block(std::min(hidden_size, 1024));
+  dim3 const block(std::min(hidden_size, block_size));
 
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
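
Below is a minimal, self-contained CUDA sketch of the per-token absmax/scale pattern this file tunes: one 256-thread block per token, fmaxf/fabsf for the thread-local maximum, and a cub::BlockReduce whose template block size matches the launch. It is an illustration only, not the vLLM kernel; kFp8Max and the plain-float scale output are stand-ins for quant_type_max_v<fp8_type> and the real fp8 path.

// Minimal CUDA sketch (not the vLLM kernel) of the per-token absmax/scale
// pattern tuned above.
#include <cub/cub.cuh>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int kBlockSize = 256;
constexpr float kFp8Max = 448.0f;  // e4m3 max, used here only for illustration

__global__ void per_token_scale_kernel(float* __restrict__ scale,
                                       const float* __restrict__ input,
                                       int hidden_size) {
  int const token_idx = blockIdx.x;
  int const tid = threadIdx.x;
  const float* token_input =
      input + static_cast<int64_t>(token_idx) * hidden_size;

  // Thread-local absmax over a strided slice of this token's hidden dim.
  float absmax_val = 0.0f;
  for (int i = tid; i < hidden_size; i += blockDim.x) {
    absmax_val = fmaxf(absmax_val, fabsf(token_input[i]));
  }

  // Block-wide max; the BlockReduce template size matches the launch size.
  using BlockReduce = cub::BlockReduce<float, kBlockSize>;
  __shared__ typename BlockReduce::TempStorage reduce_storage;
  float const block_absmax =
      BlockReduce(reduce_storage).Reduce(absmax_val, cub::Max{});

  if (tid == 0) {
    scale[token_idx] = block_absmax / kFp8Max;
  }
}

int main() {
  int const num_tokens = 4, hidden_size = 1024;
  std::vector<float> h_in(num_tokens * hidden_size);
  for (size_t i = 0; i < h_in.size(); ++i) h_in[i] = (i % 97) * 0.25f;

  float *d_in, *d_scale;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_scale, num_tokens * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  // One block per token, as in the launch configuration above.
  per_token_scale_kernel<<<num_tokens, kBlockSize>>>(d_scale, d_in,
                                                     hidden_size);

  std::vector<float> h_scale(num_tokens);
  cudaMemcpy(h_scale.data(), d_scale, num_tokens * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (int t = 0; t < num_tokens; ++t)
    printf("token %d scale = %f\n", t, h_scale[t]);
  cudaFree(d_in);
  cudaFree(d_scale);
  return 0;
}

Each block reduces one token's hidden dimension; the commit essentially drops this block size from 1024 to 256 and lets the 16-wide vectorized absmax in common.cuh cover the reduced thread count.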

csrc/quantization/fp8/common.cuh

Lines changed: 35 additions & 33 deletions
@@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
   }
 
   float r =
-      fmax(-quant_type_max_v<fp8_type>, fmin(x, quant_type_max_v<fp8_type>));
+      fmaxf(-quant_type_max_v<fp8_type>, fminf(x, quant_type_max_v<fp8_type>));
 #ifndef USE_ROCM
   return static_cast<fp8_type>(r);
 #else

@@ -65,15 +65,15 @@ template <typename scalar_t, typename fp8_type>
 __global__ void segmented_max_reduction(float* __restrict__ scale,
                                         const scalar_t* __restrict__ input,
                                         int64_t num_elems) {
-  __shared__ float cache[1024];
+  __shared__ float cache[256];
   int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
 
   // First store maximum for all values processes by
   // the current thread in cache[threadIdx.x]
   scalar_t tmp = 0.0;
   while (i < num_elems) {
     float x = static_cast<float>(input[i]);
-    tmp = max(tmp, fabs(x));
+    tmp = fmaxf(tmp, fabsf(x));
     i += blockDim.x * gridDim.x;
   }
   cache[threadIdx.x] = tmp;

@@ -100,25 +100,27 @@ template <typename scalar_t>
 __device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                 int64_t const num_elems, int const tid,
                                 int const step) {
+  constexpr size_t VEC_SIZE = 16;
+  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
   // Vectorized input/output to better utilize memory bandwidth.
-  vec4_t<scalar_t> const* vectorized_in =
-      reinterpret_cast<vec4_t<scalar_t> const*>(input);
+  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
 
-  int64_t const num_vec_elems = num_elems >> 2;
+  // num_elems / VEC_SIZE (which is 16)
+  int64_t const num_vec_elems = num_elems >> 4;
   float absmax_val = 0.0f;
 
-#pragma unroll 4
+#pragma unroll
   for (int64_t i = tid; i < num_vec_elems; i += step) {
-    vec4_t<scalar_t> in_vec = vectorized_in[i];
-    absmax_val = max(absmax_val, fabs(in_vec.x));
-    absmax_val = max(absmax_val, fabs(in_vec.y));
-    absmax_val = max(absmax_val, fabs(in_vec.z));
-    absmax_val = max(absmax_val, fabs(in_vec.w));
+    scalarxN_t in_vec = vectorized_in[i];
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      absmax_val = fmaxf(absmax_val, fabsf(in_vec.val[j]));
+    }
   }
 
-  // Handle the remaining elements if num_elems is not divisible by 4
-  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
-    absmax_val = max(absmax_val, fabs(input[i]));
+  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
+    absmax_val = fmaxf(absmax_val, fabsf(input[i]));
   }
 
   return absmax_val;

@@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
                                           float const scale,
                                           int64_t const num_elems,
                                           int const tid, int const step) {
-  using float8x4_t = q8x4_t<fp8_type>;
+  constexpr size_t VEC_SIZE = 16;
+  using scalarxN_t = vec_n_t<scalar_t, VEC_SIZE>;
+  using float8xN_t = q8_n_t<fp8_type, VEC_SIZE>;
   // Vectorized input/output to better utilize memory bandwidth.
-  auto const* vectorized_in = reinterpret_cast<vec4_t<scalar_t> const*>(input);
-  auto* vectorized_out = reinterpret_cast<float8x4_t*>(out);
+  auto const* vectorized_in = reinterpret_cast<scalarxN_t const*>(input);
+  auto* vectorized_out = reinterpret_cast<float8xN_t*>(out);
 
-  int64_t const num_vec_elems = num_elems >> 2;
+  // num_elems / VEC_SIZE (which is 16)
+  int64_t const num_vec_elems = num_elems >> 4;
 
-#pragma unroll 4
+#pragma unroll
   for (int64_t i = tid; i < num_vec_elems; i += step) {
-    vec4_t<scalar_t> in_vec = vectorized_in[i];
-    float8x4_t out_vec;
-
-    out_vec.x = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.x), scale);
-    out_vec.y = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.y), scale);
-    out_vec.z = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.z), scale);
-    out_vec.w = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
-        static_cast<float>(in_vec.w), scale);
+    scalarxN_t in_vec = vectorized_in[i];
+    float8xN_t out_vec;
+
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      out_vec.val[j] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
+          static_cast<float>(in_vec.val[j]), scale);
+    }
     vectorized_out[i] = out_vec;
   }
 
-  // Handle the remaining elements if num_elems is not divisible by 4
-  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
+  // Handle the remaining elements if num_elems is not divisible by VEC_SIZE
+  for (int64_t i = num_vec_elems * VEC_SIZE + tid; i < num_elems; i += step) {
     out[i] = scaled_fp8_conversion<is_scale_inverted, fp8_type>(
         static_cast<float>(input[i]), scale);
   }
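
For reference, here is a host-side sketch of the access pattern thread_max_vec now uses: a locally re-declared vec_n_t<T, 16>, a reinterpret_cast of an aligned buffer into 16-wide vectors (num_elems >> 4 full vectors), an inner loop over val[j], and a scalar tail loop for the remainder. It assumes a 64-byte-aligned allocation and illustrates the memory layout only; it is not the vLLM helper.

// Host-side sketch of the 16-wide vectorized absmax pattern above.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>

template <typename scalar_t, size_t vec_size>
struct alignas(vec_size * sizeof(scalar_t)) vec_n_t {
  scalar_t val[vec_size];
};

float absmax_vec16(const float* input, int64_t num_elems) {
  constexpr int64_t VEC_SIZE = 16;
  using f32x16 = vec_n_t<float, VEC_SIZE>;
  auto const* vectorized_in = reinterpret_cast<f32x16 const*>(input);
  // Same as num_elems >> 4 in the kernel helper.
  int64_t const num_vec_elems = num_elems / VEC_SIZE;

  float absmax_val = 0.0f;
  for (int64_t i = 0; i < num_vec_elems; ++i) {
    f32x16 in_vec = vectorized_in[i];
    for (int64_t j = 0; j < VEC_SIZE; ++j) {
      absmax_val = std::fmax(absmax_val, std::fabs(in_vec.val[j]));
    }
  }
  // Remaining elements when num_elems is not divisible by VEC_SIZE.
  for (int64_t i = num_vec_elems * VEC_SIZE; i < num_elems; ++i) {
    absmax_val = std::fmax(absmax_val, std::fabs(input[i]));
  }
  return absmax_val;
}

int main() {
  constexpr int64_t n = 1000;  // not a multiple of 16, exercises the tail loop
  // 64-byte aligned allocation so the cast to 64-byte vectors is valid.
  size_t const bytes = ((n * sizeof(float) + 63) / 64) * 64;
  float* data = static_cast<float*>(std::aligned_alloc(64, bytes));
  for (int64_t i = 0; i < n; ++i) data[i] = static_cast<float>(i % 37) - 18.0f;
  printf("absmax = %f\n", absmax_vec16(data, n));
  std::free(data);
  return 0;
}

With optimization enabled, the inner loop over val[j] typically lowers to wide vector loads, which is the point of moving the kernels above from 4-wide to 16-wide vectors.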

csrc/quantization/fused_kernels/layernorm_utils.cuh

Lines changed: 50 additions & 49 deletions
@@ -140,29 +140,31 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
   // sum of squares
   float ss = 0.0f;
 
+  const int VEC_SIZE = 4;
   int32_t const num_vec_elems = hidden_size >> 2;
 
 #pragma unroll 4
   for (auto i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
     vec4_t<scalar_t> in = vec_input[i];
 
     vec4_t<float> x;
-    x.x = static_cast<float>(in.x);
-    x.y = static_cast<float>(in.y);
-    x.z = static_cast<float>(in.z);
-    x.w = static_cast<float>(in.w);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      x.val[j] = static_cast<float>(in.val[j]);
+    }
+
     if constexpr (has_residual) {
       vec4_t<scalar_t> r = vec_residual[i];
-      x.x += static_cast<float>(r.x);
-      x.y += static_cast<float>(r.y);
-      x.z += static_cast<float>(r.z);
-      x.w += static_cast<float>(r.w);
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        x.val[j] += static_cast<float>(r.val[j]);
+      }
     }
 
-    ss += x.x * x.x;
-    ss += x.y * x.y;
-    ss += x.z * x.z;
-    ss += x.w * x.w;
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      ss += x.val[j] * x.val[j];
+    }
   }
 
   using BlockReduce = cub::BlockReduce<float, 1024>;

@@ -203,6 +205,7 @@ __device__ void compute_dynamic_per_token_scales(
 
   constexpr scalar_out_t qmax{quant_type_max_v<scalar_out_t>};
 
+  const int VEC_SIZE = 4;
   int32_t const num_vec_elems = hidden_size >> 2;
   float block_absmax_val_maybe = 0.0f;
 

@@ -212,26 +215,25 @@ __device__ void compute_dynamic_per_token_scales(
     vec4_t<scalar_t> const w = vec_weight[i];
 
     vec4_t<float> x;
-    x.x = static_cast<float>(in.x);
-    x.y = static_cast<float>(in.y);
-    x.z = static_cast<float>(in.z);
-    x.w = static_cast<float>(in.w);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      x.val[j] = static_cast<float>(in.val[j]);
+    }
+
     if constexpr (has_residual) {
       vec4_t<scalar_t> r = vec_residual[i];
-      x.x += static_cast<float>(r.x);
-      x.y += static_cast<float>(r.y);
-      x.z += static_cast<float>(r.z);
-      x.w += static_cast<float>(r.w);
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        x.val[j] += static_cast<float>(r.val[j]);
+      }
     }
 
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.x * rms) * w.x));
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.y * rms) * w.y));
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.z * rms) * w.z));
-    block_absmax_val_maybe = fmaxf(
-        block_absmax_val_maybe, fabs(static_cast<scalar_t>(x.w * rms) * w.w));
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      block_absmax_val_maybe =
+          fmaxf(block_absmax_val_maybe,
+                fabs(static_cast<scalar_t>(x.val[j] * rms) * w.val[j]));
+    }
   }
 
   using BlockReduce = cub::BlockReduce<float, 1024>;

@@ -282,6 +284,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
     vec_residual = reinterpret_cast<vec4_t<scalar_t>*>(&residual[token_offset]);
   }
 
+  const int VEC_SIZE = 4;
   int32_t const num_vec_elems = hidden_size >> 2;
 
   // TODO(luka/varun) extract into type-agnostic vectorized quant function to

@@ -292,33 +295,31 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
     vec4_t<scalar_t> const w = vec_weight[i];
 
     vec4_t<float> x;
-    x.x = static_cast<float>(in.x);
-    x.y = static_cast<float>(in.y);
-    x.z = static_cast<float>(in.z);
-    x.w = static_cast<float>(in.w);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      x.val[j] = static_cast<float>(in.val[j]);
+    }
+
     if constexpr (has_residual) {
       vec4_t<scalar_t> r = vec_residual[i];
-      x.x += static_cast<float>(r.x);
-      x.y += static_cast<float>(r.y);
-      x.z += static_cast<float>(r.z);
-      x.w += static_cast<float>(r.w);
-      // Update residual
-      r.x = static_cast<scalar_t>(x.x);
-      r.y = static_cast<scalar_t>(x.y);
-      r.z = static_cast<scalar_t>(x.z);
-      r.w = static_cast<scalar_t>(x.w);
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        x.val[j] += static_cast<float>(r.val[j]);
+      }
+      // Update residual
+#pragma unroll
+      for (int j = 0; j < VEC_SIZE; ++j) {
+        r.val[j] = static_cast<scalar_t>(x.val[j]);
+      }
      vec_residual[i] = r;
     }
 
     q8x4_t<scalar_out_t> out;
-    out.x = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.x * rms) * w.x, scale);
-    out.y = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.y * rms) * w.y, scale);
-    out.z = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.z * rms) * w.z, scale);
-    out.w = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
-        static_cast<scalar_t>(x.w * rms) * w.w, scale);
+#pragma unroll
+    for (int j = 0; j < VEC_SIZE; ++j) {
+      out.val[j] = ScaledQuant<scalar_out_t, is_scale_inverted>::quant_fn(
+          static_cast<scalar_t>(x.val[j] * rms) * w.val[j], scale);
+    }
     vec_output[i] = out;
   }
 }
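
The sketch below shows the same refactor pattern applied in this file, where explicit .x/.y/.z/.w accesses become an unrolled loop over val[j], using a simplified per-token RMS kernel. The single-warp shuffle reduction, float-only types, locally re-declared vec_n_t, and the assumption that hidden_size is a multiple of 4 are all simplifications; this is not the fused vLLM kernel.

// Simplified per-token RMS kernel illustrating the val[j] loop pattern.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T, int N>
struct __align__(N * sizeof(T)) vec_n_t {
  T val[N];
};

__global__ void rms_per_token(float* __restrict__ rms,
                              const float* __restrict__ input,
                              int hidden_size, float epsilon) {
  constexpr int VEC_SIZE = 4;
  using vec4 = vec_n_t<float, VEC_SIZE>;
  int const token_idx = blockIdx.x;
  auto const* vec_input = reinterpret_cast<vec4 const*>(
      input + static_cast<int64_t>(token_idx) * hidden_size);
  int const num_vec_elems = hidden_size >> 2;  // assumes hidden_size % 4 == 0

  // Sum of squares, four elements at a time via the val[j] loop.
  float ss = 0.0f;
  for (int i = threadIdx.x; i < num_vec_elems; i += blockDim.x) {
    vec4 x = vec_input[i];
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) {
      ss += x.val[j] * x.val[j];
    }
  }

  // Single-warp block: reduce ss across the 32 threads with shuffles.
  for (int offset = 16; offset > 0; offset >>= 1) {
    ss += __shfl_down_sync(0xffffffff, ss, offset);
  }
  if (threadIdx.x == 0) {
    rms[token_idx] = rsqrtf(ss / hidden_size + epsilon);
  }
}

int main() {
  int const num_tokens = 2, hidden_size = 128;
  std::vector<float> h_in(num_tokens * hidden_size, 0.5f);
  float *d_in, *d_rms;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_rms, num_tokens * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);
  rms_per_token<<<num_tokens, 32>>>(d_rms, d_in, hidden_size, 1e-6f);
  std::vector<float> h_rms(num_tokens);
  cudaMemcpy(h_rms.data(), d_rms, num_tokens * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (int t = 0; t < num_tokens; ++t)
    printf("token %d inv_rms = %f\n", t, h_rms[t]);
  cudaFree(d_in);
  cudaFree(d_rms);
  return 0;
}

The same val[j] loops appear in compute_rms, compute_dynamic_per_token_scales, and norm_and_quant above; with #pragma unroll they should compile to essentially the same code as the hand-unrolled .x/.y/.z/.w version while working for any vector width.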

csrc/quantization/vectorization.cuh

Lines changed: 11 additions & 12 deletions
@@ -10,23 +10,22 @@
 namespace vllm {
 
 // Vectorization containers
-template <typename scalar_t>
-struct __align__(8) vec4_t {
-  scalar_t x;
-  scalar_t y;
-  scalar_t z;
-  scalar_t w;
+template <typename scalar_t, size_t vec_size>
+struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
+  scalar_t val[vec_size];
 };
 
-template <typename quant_type_t>
-struct __align__(4) q8x4_t {
+template <typename quant_type_t, size_t vec_size>
+struct __align__(vec_size * sizeof(quant_type_t)) q8_n_t {
   static_assert(std::is_same_v<quant_type_t, int8_t> ||
                 std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                 std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
-  quant_type_t x;
-  quant_type_t y;
-  quant_type_t z;
-  quant_type_t w;
+  quant_type_t val[vec_size];
 };
 
+template <typename scalar_t>
+using vec4_t = vec_n_t<scalar_t, 4>;
+template <typename quant_type_t>
+using q8x4_t = q8_n_t<quant_type_t, 4>;
+
 }  // namespace vllm
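
A quick compile-time sanity sketch for the generalized containers: vec_n_t's alignment tracks vec_size * sizeof(scalar_t), so a 16-wide half vector is one 32-byte aligned load/store and the vec4_t alias keeps the previous 8-byte layout. The struct and alias are re-declared locally for illustration, with int8_t standing in for the fp8 element types since the real q8_n_t static_asserts against the c10 types.

// Compile-time layout checks for the generalized vector containers.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

template <typename scalar_t, size_t vec_size>
struct __align__(vec_size * sizeof(scalar_t)) vec_n_t {
  scalar_t val[vec_size];
};

template <typename scalar_t>
using vec4_t = vec_n_t<scalar_t, 4>;

// 16 halves = 32 bytes: matches the "aligned at 32-byte addresses" comment in
// dynamic_per_token_scaled_fp8_quant_kernel for 16-bit inputs.
static_assert(sizeof(vec_n_t<__half, 16>) == 32, "16 x half should be 32B");
static_assert(alignof(vec_n_t<__half, 16>) == 32, "and 32B aligned");
// The 4-wide alias keeps the old vec4_t size/alignment (8 bytes for half).
static_assert(sizeof(vec4_t<__half>) == 8, "vec4_t<half> stays 8B");
static_assert(alignof(vec4_t<__half>) == 8, "and 8B aligned");
// 16 one-byte quantized values = one 16-byte store per output vector.
static_assert(sizeof(vec_n_t<int8_t, 16>) == 16, "16 x byte should be 16B");

int main() {
  printf("vec_n_t<half, 16>: size=%zu align=%zu\n",
         sizeof(vec_n_t<__half, 16>), alignof(vec_n_t<__half, 16>));
  return 0;
}

If any layout assumption changes, the build fails at these asserts rather than at runtime.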
