Commit 17b117f

Finalized async kvcache manager implementation
1 parent 01d13e2 commit 17b117f

21 files changed: +2703 -693 lines

examples/commons/ops/cuda_ops/csrc/kvcache_manager_impl.cpp

Lines changed: 1216 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 373 additions & 0 deletions
@@ -0,0 +1,373 @@
/******************************************************************************
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Implementation based on FlashInfer library.
#
******************************************************************************/

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <driver_types.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <torch/serialize/tensor.h>

#include <barrier>
#include <iomanip>
#include <iostream>
#include <list>
#include <memory>
#include <queue>
#include <thread>
#include <unordered_set>
#include <unordered_map>
#include <vector>

#include "nvcomp/ans.h"

namespace kvcache {

class MultithreadMemcpyProcessor
{
public:
    MultithreadMemcpyProcessor(int num_workers);
    ~MultithreadMemcpyProcessor();

    inline const size_t num_workers() const;
    void memcpy(void* dst, void* src, size_t bytes, size_t bytes_part);

private:
    void memcpy_coworker_loop(const int idx);

private:
    int num_workers_;
    std::vector<std::thread> workers_;
    std::barrier<> start_barrier_;
    std::barrier<> end_barrier_;
    char *dst_;
    char *src_;
    size_t localbytes_;
    bool terminate_;
};

class PinnedDoubleBuffer {
public:
    PinnedDoubleBuffer(size_t buffer_bytes);
    ~PinnedDoubleBuffer();

public:
    std::vector<char *> ptr_;
    std::vector<cudaEvent_t> cuda_event_;
};

class KVCompressor {
public:
    KVCompressor(int max_num_chunks, size_t chunk_numel, size_t chunk_bytes);
    ~KVCompressor();

    void set_compress_input_buffer_ptrs(char *base_ptr, size_t num_chunks);                         // call once
    void set_decompress_output_buffer_ptrs(char *base_ptr, size_t num_chunks, cudaStream_t stream); // call multiple times

    void compress(
        size_t *compressed_bytes_cpu,
        size_t num_chunks,
        cudaStream_t stream);
    void decompress(
        size_t *compressed_bytes_cpu,
        size_t num_chunks,
        cudaStream_t stream);

public:
    char *comp_out_buffer();
    char *comp_out_buffer(int index);

    char *decomp_in_buffer();
    char *decomp_in_buffer(int index);

private:
    int max_num_chunks_;
    size_t chunk_numel_;
    size_t chunk_bytes_;

    size_t max_comp_chunk_bytes_;

    char *comp_out_buffer_;
    void **comp_in_ptrs_;    // setup once
    size_t *comp_in_bytes_;  // setup once
    void **comp_out_ptrs_;   // to internal
    size_t *comp_out_bytes_; // output
    // size_t *comp_out_bytes_cpu_; // output

    char *decomp_in_buffer_;
    void **decomp_in_ptrs_;       // to internal
    size_t *decomp_in_bytes_;     // setup multiple times
    // size_t *decomp_in_bytes_cpu_; // setup multiple times
    void **decomp_out_ptrs_;      // setup multiple times
    size_t *decomp_out_bytes_;    // may ignore
    size_t *decomp_buffer_bytes_; // setup internal

    void *comp_tmp_buffer_;
    size_t comp_tmp_bytes_;
    void *decomp_tmp_buffer_;
    size_t decomp_tmp_bytes_;

    nvcompStatus_t *comp_status_;
    nvcompStatus_t *comp_status_cpu_;
    nvcompStatus_t *decomp_status_;
    nvcompStatus_t *decomp_status_cpu_;

    const nvcompBatchedANSCompressOpts_t k_comp_opts_ = {nvcomp_rANS, NVCOMP_TYPE_FLOAT16, {0}};
    const nvcompBatchedANSDecompressOpts_t k_decomp_opts_ = nvcompBatchedANSDecompressDefaultOpts;
};

class GPUKVCacheMangerImpl;
class HostKVStorageImpl;

class KVOnloadHandle {
public:
    KVOnloadHandle(int num_layers);
    ~KVOnloadHandle();

    void reset();
    void complete_host(int layer_idx);
    void complete_host(int layer_idx, cudaStream_t stream);
    void wait_host(int layer_idx);

public:
    int num_layers;
    std::vector<cudaEvent_t> event;
    std::mutex mtx_;
    std::condition_variable cv_;
    std::vector<int> host_complete;
};

class KVOffloadHandle {
public:
    KVOffloadHandle();
    KVOffloadHandle(
        int num_layers,
        GPUKVCacheMangerImpl& gpu_kv_mgr,
        bool has_offload
    );

    void mark_ready(int layer_idx);
    void set_no_offload();

public:
    GPUKVCacheMangerImpl* gpu_kv_mgr;
    int num_layers;
    std::vector<cudaEvent_t> ready_event;
    int *host_ready;
    bool no_offload;
};

class HostKVStorageImpl
{
public:
    HostKVStorageImpl(
        int num_layers,
        int num_kv_heads,
        int kv_headdim,
        int num_tokens_per_page,
        int64_t num_tokens_per_chunk
    );
    ~HostKVStorageImpl();

    int64_t get_kvdata_length(int64_t user_id);

    void append_kvdata(
        int64_t user_id, int64_t start_position, int64_t length,
        uint16_t *kvdata_buffer, size_t buffer_layer_stride);
    void append_kvdata(
        int64_t user_id, int64_t start_position, int64_t length,
        uint16_t *kvdata_buffer, size_t buffer_layer_stride,
        size_t *kvdata_bytes, size_t bytes_layer_stride);

    std::vector<uint16_t*> get_kvdata(int64_t user_id, int64_t length, int64_t layer_idx);
    std::vector<size_t> get_kvdata_bytes(int64_t user_id, int64_t length, int64_t layer_idx);

public:
    std::vector<at::Tensor> get_kvdata_tensor(std::vector<int64_t> user_ids, bool with_concat = true);
    void init_random_kvdata(int64_t user_id, size_t num_tokens);

public:
    const int num_layers;
    const int num_kv_heads;
    const int kv_headdim;
    const int page_size;

    const int64_t chunk_size;
    size_t chunk_numel;
    size_t page_numel;
    size_t per_token_numel;
    size_t layer_numel;

    std::vector<std::unordered_map<int64_t, std::vector<uintptr_t>>> _uid_to_chunk_id;
    std::vector<std::unordered_map<int64_t, std::vector<size_t>>> _uid_to_chunk_bytes;
    std::unordered_map<int64_t, int64_t> _uid_to_length;
    std::unordered_map<int64_t, std::vector<uintptr_t>> _uid_to_mempool;
    std::mutex host_kvcache_mutex_;
};

class GPUKVCacheMangerImpl
{
public:
    GPUKVCacheMangerImpl(
        int num_layers,
        int num_kv_heads,
        int kv_headdim,
        int num_tokens_per_page,
        int num_primary_cache_pages,
        int num_onload_buffer_pages,
        int num_reserved_buffer_pages,
        int num_tokens_per_chunk,
        int max_num_sequences,
        int max_sequence_length,
        at::Tensor cache_table_tensor,
        HostKVStorageImpl& host_kv_mgr,
        size_t max_queued_offload_tokens,
        int onload_buffer_chunks = 1,
        int offload_buffer_chunks = 8,
        int num_memcpy_workers = 4,
        bool enable_nvcomp = false);
    ~GPUKVCacheMangerImpl();

    int64_t getUIdToEvict(std::unordered_set<int64_t> extra_freezed_uids);

    std::vector<int32_t>& alloc(int64_t uid, int new_total_length, std::unordered_set<int64_t> freezed_uids);
    std::vector<int32_t> get_total_cache_length(std::vector<int64_t>& uids);

    void evict(int64_t uid);
    void evict_all();
    void invalid(int64_t uid);
    bool retain(int64_t uid);

    uint16_t *get_cache_table(void);
    uint16_t *get_cache_table_by_layer(int layer_idx);

public:
    void onload_kvcache(
        std::vector<int64_t>& user_ids,
        KVOnloadHandle& onloadhandle);

    void offload_kvcache(
        KVOffloadHandle& offload_handle,
        at::Tensor offload_user_ids,     // host -> make static
        at::Tensor offload_page_ids,     // gpu -> make static
        at::Tensor new_offload_startpos, // host
        at::Tensor new_offload_lengths); // host

    bool is_busy_offloading();

public:
    void init_random_offload_status(int64_t user_id, size_t length);

private:
    void offload_loop();

public:
    int num_layers;
    int num_kv_heads;
    int kv_headdim;
    int num_tokens_per_page;
    int num_primary_cache_pages;
    int num_onload_buffer_pages;
    int num_reserved_buffer_pages;
    int num_tokens_per_chunk;
    int max_num_sequences;
    int max_sequence_length;

    size_t layer_stride;
    size_t k2v_stride;
    size_t page_stride;
    size_t per_token_kv_stride;

public:
    // kvcache bookkeeping
    std::list<int64_t> _lru_list;
    std::unordered_map<int64_t,
        typename std::list<int64_t>::iterator> _lru_lookup_table;
    std::queue<int64_t> _empty_pages;
    std::unordered_map<int64_t, std::vector<int32_t>> _uid_to_page_id;
    std::unordered_map<int64_t, int32_t> _uid_to_paged_cache_startpos;
    std::unordered_map<int64_t, int32_t> _uid_to_paged_cache_length;
    std::unordered_map<int64_t, int32_t> _uid_to_offloaded_length;

    // threadpool
    std::thread offload_worker;
    bool terminate_;
    std::atomic<bool> offload_busy_;

    // offloading shared objects
    std::queue<std::tuple<std::vector<int>, at::Tensor, std::vector<cudaEvent_t>, int*>> offload_task_queue;
    std::mutex offload_task_mutex_;
    std::condition_variable offload_task_cv_;

    // offloading limiter
    std::unordered_map<int64_t, int> queued_offload_lastpos;
    size_t queued_offload_tokens;
    std::mutex queued_offload_lastpos_mutex_;
    size_t queued_offload_limits;

    // internal device buffer
    uint16_t* onload_device_buffers;
    int num_onload_device_chunks;
    uint16_t* offload_device_buffers;
    int num_offload_device_chunks;

    // external offloading synchronization
    std::mutex offload_ready_mtx_;
    std::condition_variable offload_ready_cv_;

    // allocation-vs-offloading synchronization
    std::unordered_map<int64_t, int> offload_freezed_uids_;
    std::mutex offload_freezed_uids_mtx_;

    bool enable_nvcomp;
    KVCompressor compressor;

    cudaStream_t worker_stream;
    cudaStream_t onload_stream;
    cudaStream_t offload_stream;

    PinnedDoubleBuffer onload_pin_buffer;
    PinnedDoubleBuffer offload_pin_buffer;

    MultithreadMemcpyProcessor onload_memcpy_workers;
    MultithreadMemcpyProcessor offload_memcpy_workers;

    HostKVStorageImpl *host_kv_mgr;

public:
    uint16_t *cache_table;
    c10::Device device;
};

void prepare_kvcache(
    GPUKVCacheMangerImpl& gpu_mgr,
    HostKVStorageImpl& host_mgr,
    std::vector<int64_t>& user_ids,
    std::vector<int64_t>& total_hist_lens, // all history w/o candidates
    at::Tensor page_ids_gpu_buffer,
    at::Tensor offload_page_ids_gpu_buffer,
    at::Tensor offload_uids_buffer,
    at::Tensor metadata_host_buffer,
    at::Tensor metadata_gpu_buffer);

} // namespace kvcache
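
For orientation, here is a minimal usage sketch against the interface declared above. It assumes the semantics suggested by the names (alloc reserving GPU pages while the uids in freezed_uids stay resident, and KVOnloadHandle exposing per-layer completion so a consumer can proceed as soon as its layer has landed on the device); the uids, lengths, and the wrapper function are hypothetical, and the real call sequence lives in kvcache_manager_impl.cpp and its bindings, which are not rendered in this diff.

// Hypothetical sketch, not part of this commit. Include the header shown
// above; its path is not visible in this diff rendering.
#include <cstdint>
#include <unordered_set>
#include <vector>

void sketch_alloc_and_onload(kvcache::GPUKVCacheMangerImpl& gpu_mgr) {
    // Reserve GPU pages for sequence 1001 up to a new total length, while
    // keeping sequence 1002 pinned against eviction (assumed meaning of
    // the freezed_uids argument).
    std::unordered_set<int64_t> freezed_uids = {1002};
    std::vector<int32_t>& page_ids =
        gpu_mgr.alloc(/*uid=*/1001, /*new_total_length=*/4096, freezed_uids);
    (void)page_ids;

    // Stage host-resident KV data for both sequences. The handle appears to
    // track per-layer completion, so attention for layer L can begin once
    // wait_host(L) returns.
    std::vector<int64_t> user_ids = {1001, 1002};
    kvcache::KVOnloadHandle onload_handle(gpu_mgr.num_layers);
    gpu_mgr.onload_kvcache(user_ids, onload_handle);
    for (int layer = 0; layer < gpu_mgr.num_layers; ++layer) {
        onload_handle.wait_host(layer);  // block until this layer is usable
    }
}

The pinned double buffers, dedicated onload/offload streams, and the MultithreadMemcpyProcessor workers suggest that host/device transfers are staged through pinned memory and overlapped with compute, with optional nvCOMP batched-ANS compression of offloaded chunks when enable_nvcomp is set.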
