Commit 6265f43

[feat] support prefix cache clearing when /clear_load_weight is called (#4008)

* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix ipc suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix key/value_cache_scales indent
* [fix] fix ep group all-reduce
* [fix] fix clear/update lock not working when workers > 1
* [chore] add preemption triggered info log
* [fix] fix code style
* [fix] fix max_num_seqs config
* [fix] do not force enable_prefix_caching=False in dynamic loading
* [fix] fix ci
* Revert "[fix] fix ci" (this reverts commit 0bc6d55)
* [fix] initialize available_gpu_block_num with max_gpu_block_num
* [fix] fix config splitwise_role
* [fix] fix clearing caches synchronization and add more logs
* [chore] print cache_ready_signal in log
* [fix] fix scheduler_config.splitwise_role
* [fix] fix cache_messager cache_ready_signal create=True
* [fix] stop cache messager from launching in mixed deployment
1 parent 59313ed commit 6265f43

20 files changed, +699 -215 lines changed
custom_ops/gpu_ops/unset_data_ipc.cu

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"
#include "cuda_multiprocess.h"

#if !defined(_WIN32)
#include <errno.h>
#include <string.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#endif

// Optional: only unlink the named shared-memory object (does not rely on a previously saved addr/fd).
static inline int sharedMemoryUnlinkByName(const char* name) {
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
  // Windows has no shm_unlink semantics; a named object disappears once its last handle is closed.
  // Best effort here: try to open the mapping and close it immediately to drop one reference.
  HANDLE hMap = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
  if (hMap) {
    CloseHandle(hMap);
    return 0;
  }
  // Already gone also counts as success.
  return 0;
#else
  // POSIX: remove the name so it cannot be opened again; existing mappings stay alive until munmap.
  if (shm_unlink(name) != 0) {
    if (errno == ENOENT) return 0;  // Not existing counts as success.
    return errno;
  }
  return 0;
#endif
}

void UnsetDataIpc(const paddle::Tensor& tmp_input,
                  const std::string& shm_name,
                  bool close_ipc,
                  bool unlink_shm) {
  // 1) Close the IPC mapping imported by the consumer (only when close_ipc=true and
  //    the pointer really came from cudaIpcOpenMemHandle).
  if (close_ipc) {
    void* ptr = const_cast<void*>(tmp_input.data());
    checkCudaErrors(cudaIpcCloseMemHandle(ptr));
  }

  // 2) Unlink the named shared-memory object (only removes the name; does not tear down existing mappings).
  if (unlink_shm) {
    int rc = sharedMemoryUnlinkByName(shm_name.c_str());
    if (rc != 0) {
      PD_THROW("Unlink shared memory failed: name=%s, err=%d",
               shm_name.c_str(), rc);
    }
  }
}

PD_BUILD_STATIC_OP(unset_data_ipc)
    .Inputs({"tmp_input"})
    .Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
    .SetKernelFn(PD_KERNEL(UnsetDataIpc));
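
The op is registered as unset_data_ipc and exposed to Python through fastdeploy.model_executor.ops.gpu. A minimal usage sketch, mirroring the call made in cache_transfer_manager.py further down (the tensor and shm name are placeholders):

    # Release a KV cache tensor that was previously exported under shm_name with set_data_ipc.
    from fastdeploy.model_executor.ops.gpu import unset_data_ipc

    # close_ipc=True closes the CUDA IPC memory handle mapped for this tensor;
    # unlink_shm=False leaves the named shared-memory object in place.
    unset_data_ipc(kv_cache_tensor, shm_name, True, False)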

custom_ops/setup_ops.py

Lines changed: 2 additions & 0 deletions
@@ -208,6 +208,7 @@ def find_end_files(directory, end_str):
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/step.cu",
        "gpu_ops/set_data_ipc.cu",
+        "gpu_ops/unset_data_ipc.cu",
        "gpu_ops/moe/tritonmoe_preprocess.cu",
        "gpu_ops/step_system_cache.cu",
        "gpu_ops/get_output_ep.cc",
@@ -278,6 +279,7 @@ def find_end_files(directory, end_str):
        "gpu_ops/beam_search_softmax.cu",
        "gpu_ops/rebuild_padding.cu",
        "gpu_ops/set_data_ipc.cu",
+        "gpu_ops/unset_data_ipc.cu",
        "gpu_ops/read_data_ipc.cu",
        "gpu_ops/enforce_generation.cu",
        "gpu_ops/dequant_int8.cu",

fastdeploy/cache_manager/cache_messager.py

Lines changed: 2 additions & 2 deletions
@@ -152,8 +152,8 @@ def __init__(
        cache_v = []
        self.messager = {}
        for layer_idx in range(self.num_layers):
-            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
-            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
+            key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
+            val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
            cache_k.append(key_cache)
            cache_v.append(val_cache)
            cache_k_ptr_list.append(key_cache.data_ptr())
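
The renamed keys matter because the messager must look up each layer's cache under exactly the name the cache transfer manager now registers it with (see _init_gpu_cache below). Illustrative only, with made-up indices:

    # Shared KV cache naming scheme both processes must agree on.
    layer_idx, rank, gpu_id = 0, 0, 0
    new_name = f"key_caches_{layer_idx}_rank{rank}.device{gpu_id}"   # "key_caches_0_rank0.device0"
    old_name = f"key_caches_{layer_idx}_rank{rank}_device{gpu_id}"   # "key_caches_0_rank0_device0"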

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 179 additions & 46 deletions
@@ -16,21 +16,27 @@

import argparse
import concurrent.futures
+import gc
import json
import queue
+import threading
import time
import traceback

import numpy as np
import paddle

+from fastdeploy import envs
from fastdeploy.cache_manager.cache_data import CacheStatus
from fastdeploy.config import SpeculativeConfig
-from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
+from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
from fastdeploy.model_executor.ops.gpu import (
    cuda_host_alloc,
+    cuda_host_free,
+    set_data_ipc,
    share_external_data,
    swap_cache_all_layers,
+    unset_data_ipc,
)
from fastdeploy.utils import get_logger
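
KVCacheStatus imported here drives the clear/restore state machine added further down. Its actual definition lives in fastdeploy.inter_communicator and is not part of this diff; a purely hypothetical sketch of the states the new code relies on:

    # Hypothetical sketch only; the real values are defined in fastdeploy/inter_communicator.
    class KVCacheStatus:
        NORMAL = 0    # caches are usable
        CLEARING = 1  # engine asked cache processes to drop the kv cache (/clear_load_weight)
        CLEARED = 2   # all ranks have dropped their caches
        UPDATING = 3  # weights reloaded, caches are being re-created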

@@ -93,6 +99,7 @@ def parse_args():
        help="speculative config",
    )
    parser.add_argument("--local_data_parallel_id", type=int, default=0)
+    parser.add_argument("--create_cache_tensor", action="store_true")

    args = parser.parse_args()
    return args
@@ -110,7 +117,6 @@ def __init__(self, args):

        device = args.device_id
        rank = args.rank
-        paddle.set_device(f"gpu:{device}")
        self.gpu_cache_kvs = {}
        self.cpu_cache_kvs = {}
        self.gpu_cache_k_tensors = []
@@ -126,6 +132,7 @@ def __init__(self, args):
        self.n_ranks = args.mp_num
        self.rank = rank
        self.device = device
+        self.engine_pid = args.engine_pid

        address = (args.pod_ip, args.cache_queue_port)
        self.cache_task_queue = EngineCacheQueue(
@@ -136,57 +143,27 @@ def __init__(self, args):
            local_data_parallel_id=args.local_data_parallel_id,
        )

-        self.num_cpu_blocks = args.num_cpu_blocks
-
-        cache_type = args.cache_dtype
-        cache_shape = [
-            args.num_gpu_blocks,
-            args.kv_num_head,
-            args.block_size,
-            args.head_dim,
-        ]
-
-        for i in range(args.num_layers + self.num_extra_layers):
-            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
-            cache_shape[0] = num_gpu_blocks
-            key_name = f"key_caches_{i}_rank{rank}.device{device}"
-            value_name = f"value_caches_{i}_rank{rank}.device{device}"
-            key_cache = paddle.empty(shape=[], dtype=cache_type)
-            value_cache = paddle.empty(shape=[], dtype=cache_type)
-            key_cache = share_external_data(key_cache, key_name, cache_shape)
-            value_cache = share_external_data(value_cache, value_name, cache_shape)
-            self.gpu_cache_kvs[key_name] = key_cache
-            self.gpu_cache_kvs[value_name] = value_cache
-            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
-            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[value_name])
-
-        cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
-        logger.info(f"device :{self.device}")
-        logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
-        logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
-
-        paddle.set_device("cpu")
-        self.k_dst_ptrs = []
-        self.v_dst_ptrs = []
-        for i in range(args.num_layers + self.num_extra_layers):
-            self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
-                args.num_cpu_blocks * args.bytes_per_layer_per_block
-            )
-            self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
-            self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
-                args.num_cpu_blocks * args.bytes_per_layer_per_block
-            )
-            self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
-
        cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
        self.cache_ready_signal = IPCSignal(
            name="cache_ready_signal",
            array=cache_ready_signal_data,
            dtype=np.int32,
-            suffix=args.engine_pid,
+            suffix=self.engine_pid,
+            create=False,
+        )
+        swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
+        self.swap_space_ready_signal = IPCSignal(
+            name="swap_space_ready_signal",
+            array=swap_space_ready_data,
+            dtype=np.int32,
+            suffix=self.engine_pid,
            create=False,
        )
-        self.cache_ready_signal.value[self.rank] = 1
+
+        self.num_cpu_blocks = args.num_cpu_blocks
+
+        self._init_cpu_cache(args)
+        self._init_gpu_cache(args)

        cache_task_broadcast_data = np.zeros(shape=[1], dtype=np.int32)
        self.cache_task_broadcast_signal = IPCSignal(
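
Both new signals are opened with create=False, so a counterpart with the same name and suffix must be created first, presumably by the engine process. A hedged sketch of that creating side (mp_num and engine_pid stand in for the engine's own values):

    # Assumed engine-side setup; the cache transfer manager then attaches with create=False.
    import numpy as np
    from fastdeploy.inter_communicator import IPCSignal

    swap_space_ready_signal = IPCSignal(
        name="swap_space_ready_signal",
        array=np.zeros(shape=[mp_num], dtype=np.int32),  # one slot per rank
        dtype=np.int32,
        suffix=engine_pid,  # must match the suffix used by every attaching process
        create=True,
    )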
@@ -197,6 +174,76 @@ def __init__(self, args):
            create=False,
        )

+        threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
+
+    def _init_gpu_cache(self, args):
+
+        if not args.create_cache_tensor:
+            logger.info(f"[rank {self.rank}/{self.n_ranks}] Waiting for runners to create kv cache.")
+            while self.cache_ready_signal.value[self.rank] != 1:
+                time.sleep(0.1)
+            logger.info(f"[rank {self.rank}/{self.n_ranks}] OK! Stop waiting.")
+
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing kv cache for all layers.")
+        paddle.set_device(f"gpu:{self.device}")
+        for i in range(args.num_layers + self.num_extra_layers):
+            num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
+            cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
+            key_name = f"key_caches_{i}_rank{self.rank}.device{self.device}"
+            val_name = f"value_caches_{i}_rank{self.rank}.device{self.device}"
+
+            if args.create_cache_tensor:
+                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..creating kv cache for layer {i}: {cache_shape}")
+                key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
+                val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
+                set_data_ipc(key_cache, key_name)
+                set_data_ipc(val_cache, val_name)
+            else:
+                logger.info(f"[rank {self.rank}/{self.n_ranks}] ..attaching kv cache for layer {i}: {cache_shape}")
+                key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
+                val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
+                key_cache = share_external_data(key_cache, key_name, cache_shape)
+                val_cache = share_external_data(val_cache, val_name, cache_shape)
+
+            self.gpu_cache_kvs[key_name] = key_cache
+            self.gpu_cache_kvs[val_name] = val_cache
+            self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[key_name])
+            self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[val_name])
+
+        if args.create_cache_tensor:
+            logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ kv cache is ready!")
+            self.cache_ready_signal.value[self.rank] = 1
+
+        cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] device :{self.device}")
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] cache_kv_size_byte : {cache_kv_size_byte}")
+        logger.info(
+            f"[rank {self.rank}/{self.n_ranks}] done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}"
+        )
+
+    def _init_cpu_cache(self, args):
+        if args.num_cpu_blocks == 0:
+            logger.info(f"[rank {self.rank}/{self.n_ranks}] 💡 no swap space (cpu cache) is specified.")
+            self.swap_space_ready_signal.value[self.rank] = 1
+            return
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] Initializing swap space (cpu cache) for all layers.")
+        paddle.set_device("cpu")
+        self.k_dst_ptrs = []
+        self.v_dst_ptrs = []
+        for i in range(args.num_layers + self.num_extra_layers):
+            key_name = f"key_caches_{i}_rank{self.rank}"
+            val_name = f"value_caches_{i}_rank{self.rank}"
+            need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
+            logger.info(
+                f"[rank {self.rank}/{self.n_ranks}] ..creating cpu cache for layer {i}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB"
+            )
+            self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
+            self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
+            self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
+            self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
+        logger.info(f"[rank {self.rank}/{self.n_ranks}] ✅ swap space (cpu cache) is ready!")
+        self.swap_space_ready_signal.value[self.rank] = 1
+
    def _do_swap_to_cpu_task(
        self,
        swap_node_ids,
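
The create/attach branch in _init_gpu_cache pairs two existing custom ops that also appear elsewhere in this diff. A minimal sketch of that pairing, with an illustrative name and shape (in the real flow the two sides run in different processes):

    import paddle
    from fastdeploy.model_executor.ops.gpu import set_data_ipc, share_external_data

    name = "key_caches_0_rank0.device0"
    shape = [1024, 8, 64, 128]  # [num_gpu_blocks, kv_num_head, block_size, head_dim]

    # Producer path (--create_cache_tensor): allocate the buffer and export it under a shared name.
    key_cache = paddle.full(shape=shape, fill_value=0, dtype="bfloat16")
    set_data_ipc(key_cache, name)

    # Consumer path: attach an empty tensor to the buffer already exported under that name.
    attached = paddle.empty(shape=[], dtype="bfloat16")
    attached = share_external_data(attached, name, shape)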
@@ -394,6 +441,92 @@ def _transfer_data(
            transfer_task_id,
        )

+    def clear_or_update_caches(self, args):
+        logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
+        logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
+        kv_cache_status = np.zeros([1], dtype=np.int32)
+        kv_cache_status_signal = IPCSignal(
+            name="kv_cache_status",
+            array=kv_cache_status,
+            dtype=np.int32,
+            suffix=self.engine_pid,
+            create=False,
+        )
+        while True:
+            if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
+                try:
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Start clearing caches {self.cache_ready_signal.value}"
+                    )
+                    # clear cpu caches
+                    if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
+                        paddle.set_device("cpu")
+                        for ptrs in self.k_dst_ptrs + self.v_dst_ptrs:
+                            cuda_host_free(ptrs)
+                        self.cpu_cache_kvs.clear()
+                        self.k_dst_ptrs.clear()
+                        self.v_dst_ptrs.clear()
+                        gc.collect()
+                        # reset swap_space_ready_signal
+                        self.swap_space_ready_signal.value[self.rank] = 0
+                        while np.sum(self.swap_space_ready_signal.value) != 0:
+                            time.sleep(0.1)
+
+                    # clear gpu caches
+                    paddle.set_device(f"gpu:{self.device}")
+                    for name, tensor in self.gpu_cache_kvs.items():
+                        unset_data_ipc(tensor, name, True, False)
+                    self.gpu_cache_kvs.clear()
+                    self.gpu_cache_k_tensors.clear()
+                    self.gpu_cache_v_tensors.clear()
+
+                    # reset cache_ready_signal
+                    self.cache_ready_signal.value[self.rank] = 0
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Finish clearing caches {self.cache_ready_signal.value}"
+                    )
+
+                    # wait for all ranks caches to be cleared
+                    if np.sum(self.cache_ready_signal.value) != 0:
+                        time.sleep(0.1)
+
+                    # reset kv_cache_status_signal
+                    kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED
+                    logger.info("All ranks finish clearing caches")
+
+                except Exception as e:
+                    logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to clear caches: {e}")
+
+            elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
+                try:
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Start restoring caches {self.cache_ready_signal.value}"
+                    )
+                    # restore cpu cache
+                    if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
+                        self._init_cpu_cache(args)
+                        while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
+                            time.sleep(0.1)
+
+                    # restore gpu cache and set cache_ready_signal
+                    self._init_gpu_cache(args)
+                    logger.info(
+                        f"[rank {self.rank}/{self.n_ranks}] Finish restoring caches {self.cache_ready_signal.value}"
+                    )
+
+                    # wait for all ranks caches to be ready
+                    while np.sum(self.cache_ready_signal.value) != args.mp_num:
+                        time.sleep(0.1)
+
+                    # set kv_cache_status_signal
+                    logger.info("All ranks finish restoring caches")
+                    kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL
+
+                except Exception as e:
+                    logger.error(f"[rank {self.rank}/{self.n_ranks}] Failed to restore caches: {e}")
+
+            time.sleep(0.1)
+

def main():
    """
