
Commit 0776d55

yewentao256 authored and npanpaliya committed
[Perf] Optimize reshape_and_cache_flash CUDA Kernel (vllm-project#22036)
Signed-off-by: yewentao256 <[email protected]>
1 parent 3e1a9b5 commit 0776d55

File tree: 2 files changed, +225 −23 lines

New file (benchmark script): 156 additions & 0 deletions
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random_flash,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    kv_cache_layout: str,
    num_iters: int,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random_flash(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
        cache_layout=kv_cache_layout,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            ops.reshape_and_cache_flash(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
                kv_cache_dtype,
                k_scale,
                v_scale,
            )
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for layout in ["NHD", "HND"]:
        for exp in range(1, 17):
            n_tok = 2**exp
            lat = run_benchmark(
                num_tokens=n_tok,
                num_heads=args.num_heads,
                head_size=args.head_size,
                block_size=args.block_size,
                num_blocks=args.num_blocks,
                dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
                kv_cache_dtype=args.kv_cache_dtype,
                kv_cache_layout=layout,
                num_iters=args.iters,
                device="cuda",
            )
            rows.append([n_tok, layout, f"{lat * 1e6:.3f}"])

    print(tabulate(rows, headers=["num_tokens", "layout", "latency (µs)"]))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 512)

    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )

    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    parser.add_argument("--iters", type=int, default=100)
    args = parser.parse_args()

    main(args)

csrc/cache_kernels.cu

Lines changed: 69 additions & 23 deletions
@@ -5,6 +5,7 @@
 #include "cuda_utils.h"
 #include "cuda_compat.h"
 #include "dispatch_utils.h"
+#include "quantization/vectorization_utils.cuh"
 
 #ifdef USE_ROCM
 #include "quantization/fp8/amd/quant_utils.cuh"
@@ -261,14 +262,26 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+// Used by vectorization_utils to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
     const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
     const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
-    cache_t* __restrict__ key_cache,    // [num_blocks, block_size, num_heads,
-                                        // head_size]
-    cache_t* __restrict__ value_cache,  // [num_blocks, block_size, num_heads,
-                                        // head_size]
+    cache_t* __restrict__ key_cache,    // NHD or HND, shape see comments below
+    cache_t* __restrict__ value_cache,  // same above
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int64_t block_stride, const int64_t page_stride,
     const int64_t head_stride, const int64_t key_stride,
@@ -282,25 +295,58 @@ __global__ void reshape_and_cache_flash_kernel(
   }
   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int64_t tgt_key_value_idx = block_idx * block_stride +
-                                      block_offset * page_stride +
-                                      head_idx * head_stride + head_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_value_idx] = tgt_key;
-      value_cache[tgt_key_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
+  const int n_elems = num_heads * head_size;
+
+  // pointers to the beginning of the source row for this token.
+  const scalar_t* __restrict__ key_src = key + token_idx * key_stride;
+  const scalar_t* __restrict__ value_src = value + token_idx * value_stride;
+
+  // find the start position inside the kv-cache for this token.
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * block_stride + block_offset * page_stride;
+  cache_t* __restrict__ value_dst =
+      value_cache + block_idx * block_stride + block_offset * page_stride;
+
+  // this is true for the NHD layout where `head_stride == head_size`
+  const bool is_contiguous_heads = (head_stride == head_size);
+
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+  if (is_contiguous_heads) {
+    // NHD layout
+    // kv cache: [num_blocks, block_size, num_heads, head_size]
+    vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, n_elems, threadIdx.x,
+                                       blockDim.x, k_op);
+
+    vectorize_with_alignment<VEC_SIZE>(value_src, value_dst, n_elems,
+                                       threadIdx.x, blockDim.x, v_op);
+
+  } else {
+    // HND layout: heads are strided, but each head_size segment is contiguous
+    // kv cache: [num_blocks, num_heads, block_size, head_size]
+    const int lane = threadIdx.x & 31;            // 0..31 within warp
+    const int warp_id = threadIdx.x >> 5;         // warp index within block
+    const int warps_per_block = blockDim.x >> 5;
+
+    for (int head = warp_id; head < num_heads; head += warps_per_block) {
+      const scalar_t* __restrict__ k_src_h = key_src + head * head_size;
+      const scalar_t* __restrict__ v_src_h = value_src + head * head_size;
+
+      cache_t* __restrict__ k_dst_h =
+          key_dst + static_cast<int64_t>(head) * head_stride;
+      cache_t* __restrict__ v_dst_h =
+          value_dst + static_cast<int64_t>(head) * head_stride;
+
+      // within each head, let the 32 threads of the warp perform the vector
+      // copy
+      vectorize_with_alignment<VEC_SIZE>(k_src_h, k_dst_h, head_size, lane, 32,
+                                         k_op);
+
+      vectorize_with_alignment<VEC_SIZE>(v_src_h, v_dst_h, head_size, lane, 32,
+                                         v_op);
     }
   }
 }
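Both layout paths delegate the actual data movement to vectorize_with_alignment<VEC_SIZE>(src, dst, n, thread_offset, stride, op) from quantization/vectorization_utils.cuh, with VEC_SIZE chosen so one vector spans 16 bytes (8 elements for 2-byte half/bfloat16, 4 for float). The snippet below is only a minimal, self-contained sketch of that copy-with-functor pattern, under the assumption that the helper performs wide loads/stores over VEC_SIZE contiguous elements per thread step plus a scalar tail; vectorized_copy_sketch, CopyOp, and the small test harness are invented names for illustration, and the misalignment handling implied by the real helper's name is omitted here.

// copy_vectorized_sketch.cu -- illustrative only; the real helper lives in
// csrc/quantization/vectorization_utils.cuh and may differ in detail.
#include <cuda_runtime.h>
#include <cstdio>

// Element-wise copy functor, playing the role of CopyWithScaleOp above
// (the fp8 branch is dropped so the sketch stays self-contained).
struct CopyOp {
  __device__ __forceinline__ void operator()(float& dst, const float src) const {
    dst = src;
  }
};

// Sketch of a vectorized range copy: each participating thread handles
// VEC_SIZE contiguous elements per step through a packed struct (one wide
// load and one wide store), and a scalar loop covers any tail elements.
template <int VEC_SIZE, typename T, typename Op>
__device__ void vectorized_copy_sketch(const T* src, T* dst, int n,
                                       int tid, int num_threads, Op op) {
  struct alignas(sizeof(T) * VEC_SIZE) Vec { T v[VEC_SIZE]; };
  const int num_vecs = n / VEC_SIZE;
  const Vec* vsrc = reinterpret_cast<const Vec*>(src);
  Vec* vdst = reinterpret_cast<Vec*>(dst);
  for (int i = tid; i < num_vecs; i += num_threads) {
    Vec in = vsrc[i];  // one wide load
    Vec out;
#pragma unroll
    for (int j = 0; j < VEC_SIZE; ++j) op(out.v[j], in.v[j]);
    vdst[i] = out;     // one wide store
  }
  // scalar tail for the last n % VEC_SIZE elements
  for (int i = num_vecs * VEC_SIZE + tid; i < n; i += num_threads) {
    op(dst[i], src[i]);
  }
}

__global__ void copy_kernel(const float* src, float* dst, int n) {
  // all threads of the block cooperate on one contiguous range,
  // mirroring the NHD path of the kernel above
  vectorized_copy_sketch<4>(src, dst, n, threadIdx.x, blockDim.x, CopyOp{});
}

int main() {
  const int n = 1 << 12;
  float *src, *dst;
  cudaMallocManaged(&src, n * sizeof(float));
  cudaMallocManaged(&dst, n * sizeof(float));
  for (int i = 0; i < n; ++i) src[i] = static_cast<float>(i);
  copy_kernel<<<1, 256>>>(src, dst, n);
  cudaDeviceSynchronize();
  printf("dst[123] = %.1f\n", dst[123]);  // expect 123.0
  cudaFree(src);
  cudaFree(dst);
  return 0;
}

In the HND path of the committed kernel, the same helper is called per head with lane/32 as the thread offset and stride, so each warp copies one contiguous head_size segment at a time.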
