
Commit eb0fa43

ZJY0516 and Liu-congo authored
[Perf] Optimize reshape_and_cache CUDA Kernel (vllm-project#25955)
Signed-off-by: zjy0516 <[email protected]>
Co-authored-by: Liu-congo <[email protected]>
1 parent 0ad9951 commit eb0fa43


2 files changed: +225 additions, -45 deletions

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations

import random
import time

import torch
from tabulate import tabulate

from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import (
    STR_DTYPE_TO_TORCH_DTYPE,
    FlexibleArgumentParser,
    create_kv_caches_with_random,
)

logger = init_logger(__name__)


@torch.inference_mode()
def run_benchmark(
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    kv_cache_dtype: str,
    num_iters: int,
    benchmark_mode: str,
    device: str = "cuda",
) -> float:
    """Return latency (seconds) for given num_tokens."""

    if kv_cache_dtype == "fp8" and head_size % 16:
        raise ValueError("fp8 kv-cache requires head_size to be a multiple of 16.")

    current_platform.seed_everything(42)
    torch.set_default_device(device)

    # create random key / value tensors [T, H, D].
    key = torch.randn(num_tokens, num_heads, head_size, dtype=dtype, device=device)
    value = torch.randn_like(key)

    # prepare the slot mapping.
    # each token is assigned a unique slot in the KV-cache.
    num_slots = block_size * num_blocks
    if num_tokens > num_slots:
        raise ValueError("num_tokens cannot exceed the total number of cache slots")
    slot_mapping_lst = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping_lst, dtype=torch.long, device=device)

    key_caches, value_caches = create_kv_caches_with_random(
        num_blocks,
        block_size,
        1,  # num_layers
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]
    # to free unused memory
    del key_caches, value_caches

    # compute per-kernel scaling factors for fp8 conversion (if used).
    k_scale = (key.amax() / 64.0).to(torch.float32)
    v_scale = (value.amax() / 64.0).to(torch.float32)

    function_under_test = lambda: ops.reshape_and_cache(
        key,  # noqa: F821
        value,  # noqa: F821
        key_cache,  # noqa: F821
        value_cache,  # noqa: F821
        slot_mapping,  # noqa: F821
        kv_cache_dtype,
        k_scale,
        v_scale,
    )

    if benchmark_mode == "cudagraph":
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            function_under_test()
        torch.cuda.synchronize()
        function_under_test = lambda: g.replay()

    def run_cuda_benchmark(n_iters: int) -> float:
        nonlocal key, value, key_cache, value_cache, slot_mapping
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(n_iters):
            function_under_test()
        torch.cuda.synchronize()
        end = time.perf_counter()
        return (end - start) / n_iters

    # warm-up
    run_cuda_benchmark(3)

    lat = run_cuda_benchmark(num_iters)

    # free tensors to mitigate OOM when sweeping
    del key, value, key_cache, value_cache, slot_mapping
    torch.cuda.empty_cache()

    return lat


def main(args):
    rows = []
    for exp in range(1, 17):
        n_tok = 2**exp
        lat = run_benchmark(
            num_tokens=n_tok,
            num_heads=args.num_heads,
            head_size=args.head_size,
            block_size=args.block_size,
            num_blocks=args.num_blocks,
            dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
            kv_cache_dtype=args.kv_cache_dtype,
            num_iters=args.iters,
            benchmark_mode=args.mode,
            device="cuda",
        )
        rows.append([n_tok, lat * 1e6])  # convert to microseconds

    print(f"Benchmark results for implementation cuda (measuring with {args.mode}):")
    print(tabulate(rows, headers=["num_tokens", "latency (µs)"], floatfmt=".3f"))


if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    parser.add_argument("--num-heads", type=int, default=128)
    parser.add_argument(
        "--head-size",
        type=int,
        choices=[64, 80, 96, 112, 120, 128, 192, 256],
        default=128,
    )
    parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
    parser.add_argument("--num-blocks", type=int, default=128 * 128)

    parser.add_argument(
        "--dtype",
        type=str,
        choices=["half", "bfloat16", "float"],
        default="bfloat16",
    )

    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8"],
        default="auto",
    )

    parser.add_argument("--iters", type=int, default=200)

    parser.add_argument(
        "--mode",
        type=str,
        choices=["cudagraph", "no_graph"],
        default="cudagraph",
    )

    args = parser.parse_args()

    main(args)
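For orientation, a minimal example of calling run_benchmark directly, assuming the script's definitions above are in scope (for instance, appended to the file or imported); the parameter values are illustrative only and not part of the commit:

# Illustration only (not part of the commit): time 4096 tokens with an fp8 KV cache,
# reusing run_benchmark exactly as defined above.
latency_s = run_benchmark(
    num_tokens=4096,
    num_heads=64,
    head_size=128,
    block_size=16,
    num_blocks=128 * 128,
    dtype=torch.bfloat16,
    kv_cache_dtype="fp8",
    num_iters=200,
    benchmark_mode="cudagraph",
    device="cuda",
)
print(f"reshape_and_cache latency: {latency_s * 1e6:.3f} µs per call")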

csrc/cache_kernels.cu

Lines changed: 51 additions & 45 deletions
@@ -16,8 +16,7 @@
 
 #include <algorithm>
 #include <cassert>
-#include <map>
-#include <vector>
+#include <cfloat>  // FLT_MIN
 
 #ifdef USE_ROCM
   #include <hip/hip_bf16.h>
@@ -209,6 +208,20 @@ void copy_blocks_mla(std::vector<torch::Tensor> const& kv_caches,
 
 namespace vllm {
 
+// Used to copy/convert one element
+template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
+struct CopyWithScaleOp {
+  float scale;
+
+  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
+    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
+      dst = static_cast<OutT>(src);
+    } else {
+      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
+    }
+  }
+};
+
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_kernel(
     const scalar_t* __restrict__ key,  // [num_tokens, num_heads, head_size]
@@ -224,58 +237,50 @@ __global__ void reshape_and_cache_kernel(
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
-    // Padding token that should be ignored.
     return;
   }
 
   const int64_t block_idx = slot_idx / block_size;
   const int64_t block_offset = slot_idx % block_size;
+  const int h_block_count = head_size / x;  // head_size//x
 
-  const int n = num_heads * head_size;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int64_t src_key_idx = token_idx * key_stride + i;
-    const int64_t src_value_idx = token_idx * value_stride + i;
-
-    const int head_idx = i / head_size;
-    const int head_offset = i % head_size;
-    const int x_idx = head_offset / x;
-    const int x_offset = head_offset % x;
-
-    const int64_t tgt_key_idx =
-        block_idx * num_heads * (head_size / x) * block_size * x +
-        head_idx * (head_size / x) * block_size * x + x_idx * block_size * x +
-        block_offset * x + x_offset;
-    const int64_t tgt_value_idx =
-        block_idx * num_heads * head_size * block_size +
-        head_idx * head_size * block_size + head_offset * block_size +
-        block_offset;
-    scalar_t tgt_key = key[src_key_idx];
-    scalar_t tgt_value = value[src_value_idx];
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      key_cache[tgt_key_idx] = tgt_key;
-      value_cache[tgt_value_idx] = tgt_value;
-    } else {
-      key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
-      value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
-    }
+  const int h_block_idx = threadIdx.x;
+  if (h_block_idx >= num_heads * h_block_count) {
+    return;
   }
-}
 
-// Used by vectorization_utils to copy/convert one element
-template <typename OutT, typename InT, Fp8KVCacheDataType kv_dt>
-struct CopyWithScaleOp {
-  float scale;
+  const int head_idx = h_block_idx / h_block_count;
+  const int h_block = h_block_idx % h_block_count;
 
-  __device__ __forceinline__ void operator()(OutT& dst, const InT src) const {
-    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-      dst = static_cast<OutT>(src);
-    } else {
-      dst = fp8::scaled_convert<OutT, InT, kv_dt>(src, scale);
-    }
+  const scalar_t* __restrict__ key_src =
+      key + token_idx * key_stride + head_idx * head_size + h_block * x;
+  const int64_t src_value_start =
+      token_idx * value_stride + head_idx * head_size + h_block * x;
+
+  cache_t* __restrict__ key_dst =
+      key_cache + block_idx * num_heads * h_block_count * block_size * x +
+      head_idx * h_block_count * block_size * x + h_block * block_size * x +
+      block_offset * x;
+  const int64_t tgt_value_start =
+      block_idx * num_heads * h_block_count * x * block_size +
+      head_idx * h_block_count * x * block_size + h_block * x * block_size +
+      block_offset;
+
+  constexpr int VEC_SIZE = (sizeof(scalar_t) == 2) ? 8 : 4;
+  float k_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *k_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> k_op{k_scale_val};
+  float v_scale_val = (kv_dt == Fp8KVCacheDataType::kAuto) ? 0.f : *v_scale;
+  CopyWithScaleOp<cache_t, scalar_t, kv_dt> v_op{v_scale_val};
+
+  vectorize_with_alignment<VEC_SIZE>(key_src, key_dst, x, 0, 1, k_op);
+
+  const scalar_t* __restrict__ value_src = value + src_value_start;
+  cache_t* __restrict__ value_dst = value_cache + tgt_value_start;
+#pragma unroll
+  for (int i = 0; i < x; i++) {
+    v_op(value_dst[i * block_size], value_src[i]);
   }
-};
+}
 
 template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
 __global__ void reshape_and_cache_flash_kernel(
@@ -601,9 +606,10 @@ void reshape_and_cache(
 
   int key_stride = key.stride(0);
   int value_stride = value.stride(0);
+  int head_div_x = head_size / x;
 
   dim3 grid(num_tokens);
-  dim3 block(std::min(num_heads * head_size, 512));
+  dim3 block(std::min(num_heads * head_div_x, 512));
   const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
