Commit 7cf5fdb
Parent: c69e522
Commit message: fix

2 files changed, +69 -44 lines
lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py (40 additions, 9 deletions)
@@ -5,6 +5,9 @@
 
 from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
 from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
 
 
 class HybridMemManager(MemoryManager):
@@ -30,6 +33,10 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
         super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
         # Caches buffer nodes that are candidates for eviction; should contain every node that has a buffer.
         self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,))
+        self.match_count = 0
+        self.log_interval = 1000
+        self.match_len = 0
+        self.hit_len = 0
 
     def free_radix_cache_to_get_enough_buffer(self, need_buffer_num):
         if need_buffer_num > self.mem_manager.get_buffer_can_use_size():
@@ -47,14 +54,14 @@ def release_buffer(buffer_idx):
             release_buffers.append(buffer_idx)
             return
 
-        self.evict_buffer(need_evict_buffer_num, release_buffer, release_mem)
+        self._evict_buffer(need_evict_buffer_num, release_buffer, release_mem)
         self.mem_manager.free_buffer(release_buffers)
         if len(release_mems) > 0:
             mem_index = torch.concat(release_mems)
             self.mem_manager.free(mem_index)
         return
 
-    def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback):
+    def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback):
         while need_evict_buffer_num > 0:
             node = self.evict_buffer_set.pop(0)
             assert node.buffer_idx is not None
@@ -78,6 +85,7 @@ def insert_for_hybrid_radix_cache(self, reqs):
 
         self.free_radix_cache_to_get_enough_buffer(len(reqs))
         new_buffer_indexes = self.mem_manager.alloc_buffer(len(reqs))
+        # req_ids_gpu = req_ids.cuda()
 
         for i, req in enumerate(reqs):
             input_token_ids = req.get_input_token_ids()
@@ -88,16 +96,22 @@ def insert_for_hybrid_radix_cache(self, reqs):
             # Allocate a new buffer and copy the contents of the current buffer.
             self.mem_manager.copy_buffer(cur_buffer_idx, new_buffer_indexes[i])
 
-            _, new_shared_kv_node = super().insert(key, value)
+            prefix_len, new_shared_kv_node = super().insert(key, value)
             self.dec_node_ref_counter(req.shared_kv_node)
             self.add_node_ref_counter(new_shared_kv_node)
-            self.set_node_buffer_idx(new_shared_kv_node, new_buffer_indexes[i].item())
+            self.add_buffer_idx_to_node(new_shared_kv_node, new_buffer_indexes[i].item())
             req.shared_kv_node = new_shared_kv_node
+            # Update prompt_cache_len so that free_a_req_mem does not free tokens that already belong to the tree.
+            # free_a_req_mem frees the range [prompt_cache_len:prefix_len]; after this update that range is empty.
+            req.shm_req.prompt_cache_len = req.cur_kv_len
 
     def match_prefix(self, key, update_refs=False):
         assert len(key) != 0
+        self.match_count = (self.match_count + 1) % self.log_interval
+        self.match_len += len(key)
         ans_value_list = []
         tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+        origin_ans_len = sum(len(v) for v in ans_value_list)
         evict_token_list = []
         while tree_node != self.root_node and tree_node.buffer_idx is None:
             if tree_node.is_leaf():
@@ -126,7 +140,7 @@ def match_prefix(self, key, update_refs=False):
             self.mem_manager.free(evict_token_value)
 
         if tree_node == self.root_node:
-            return None, 0, None
+            return None, origin_ans_len, None
 
         update_node = tree_node
         while update_node != self.root_node:
@@ -137,14 +151,31 @@ def match_prefix(self, key, update_refs=False):
             update_node = update_node.parent
 
         value = torch.concat(ans_value_list)
-        return tree_node, len(value), value
-
-    def set_node_buffer_idx(self, node: TreeNode, buffer_idx: int):
+        # logger.info("HybridRadixCache match_prefix hit tokens: {}".format(len(value)))
+        self.hit_len += len(value)
+        if self.match_count == 0:
+            logger.info(
+                f"HybridRadixCache match_prefix avg hit rate: {self.hit_len / self.match_len:.4f} "
+                f"({self.hit_len}/{self.match_len}) over last {self.log_interval} matches"
+            )
+            self.match_len = 0
+            self.hit_len = 0
+
+        return tree_node, origin_ans_len, value
+
+    def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int):
         """Set buffer_idx for a node and add it to evict_buffer_set."""
-        node.buffer_idx = buffer_idx
         self.evict_buffer_set.discard(node)
+        if node.is_leaf():
+            self.evict_tree_set.discard(node)
+        if node.buffer_idx is not None:
+            self.mem_manager.free_buffer([node.buffer_idx])
+        node.buffer_idx = buffer_idx
         node.update_buffer_time()
         self.evict_buffer_set.add(node)
+        if node.is_leaf():
+            self.evict_tree_set.add(node)
+        return
 
     def free_radix_cache_to_get_enough_token(self, need_token_num):
         assert self.mem_manager is not None
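Note on the new counters: match_count / log_interval / match_len / hit_len implement a simple windowed hit-rate log. Every match_prefix call bumps the counter and accumulates requested and hit token lengths; once every log_interval calls the windowed average is logged and the accumulators are reset. A minimal standalone sketch of that pattern (the HitRateMeter class and the logging setup are illustrative, not part of this commit):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hit_rate_sketch")


class HitRateMeter:
    def __init__(self, log_interval: int = 1000):
        self.log_interval = log_interval
        self.match_count = 0
        self.match_len = 0  # total tokens requested in the current window
        self.hit_len = 0    # total tokens served from cache in the current window

    def record(self, requested_len: int, hit_len: int) -> None:
        # Count the call and accumulate requested vs. hit token lengths
        # (requested_len is assumed > 0, mirroring the assert len(key) != 0 in match_prefix).
        self.match_count = (self.match_count + 1) % self.log_interval
        self.match_len += requested_len
        self.hit_len += hit_len
        # Every log_interval calls, emit the windowed average hit rate and reset.
        if self.match_count == 0:
            logger.info(
                f"avg hit rate: {self.hit_len / self.match_len:.4f} "
                f"({self.hit_len}/{self.match_len}) over last {self.log_interval} matches"
            )
            self.match_len = 0
            self.hit_len = 0

The reworked add_buffer_idx_to_node also illustrates why the node is discarded from evict_buffer_set (and, for leaves, evict_tree_set) before buffer_idx and buffer_time are changed, and only re-added afterwards: a key-ordered SortedSet cannot re-sort an element whose sort key is mutated while it is still inside the set. A small sketch of that discard/mutate/re-add pattern, assuming sortedcontainers and a stand-in Node class:

import time
from sortedcontainers import SortedSet


class Node:
    def __init__(self, name: str):
        self.name = name
        self.buffer_idx = None
        self.buffer_time = 0.0


evict_buffer_set = SortedSet(key=lambda n: (n.buffer_time,))


def set_buffer(node: Node, buffer_idx: int) -> None:
    # Remove first: the sort key (buffer_time) is about to change, and mutating it
    # in place would leave the SortedSet's internal ordering inconsistent.
    evict_buffer_set.discard(node)
    node.buffer_idx = buffer_idx
    node.buffer_time = time.time()  # stand-in for node.update_buffer_time()
    evict_buffer_set.add(node)


a, b = Node("a"), Node("b")
set_buffer(a, 0)
set_buffer(b, 1)
# Iteration yields nodes ordered by buffer_time (oldest first), i.e. the eviction order.
print([n.name for n in evict_buffer_set])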

lightllm/server/router/model_infer/infer_batch.py (29 additions, 35 deletions)
@@ -71,39 +71,31 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream:
             self.cpu_kv_cache_stream = torch.cuda.Stream()
         return self.cpu_kv_cache_stream
 
-    def _maybe_alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None:
-        """
-        For hybrid/linear-attention models (e.g. Qwen3-Next) we allocate a fixed-size buffer per request.
-        If radix cache hits and the matched node has a buffer, copy that buffer content to the newly
-        allocated buffer for this request.
-        """
-        if not self.use_buffer_manager or not req_objs:
-            return
-
+    def _alloc_and_copy_req_buffers(self, req_objs: List["InferReq"]) -> None:
+        # Allocate a buffer for each request; if shared_kv_node is not None, copy its buffer from the radix cache.
         if self.radix_cache is not None:
-            # Ensure enough buffer capacity by evicting radix cache buffers if needed.
             self.radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs))
 
-        req_idxs = np.array([r.req_idx for r in req_objs], dtype=np.int64)
-        request_indices_gpu = torch.from_numpy(req_idxs).to(device="cuda", dtype=torch.int64)
+        req_idxs = []
+        copy_indices = []
+        copy_buffers = []
+
+        for r in req_objs:
+            req_idxs.append(r.req_idx)
+            if r.shared_kv_node is not None:
+                copy_indices.append(r.req_idx)
+                copy_buffers.append(r.shared_kv_node.buffer_idx)
+
+        request_indices_gpu = torch.tensor(req_idxs, device="cuda", dtype=torch.int64)
         self.req_manager.alloc_buffer_for_req(request_indices_gpu)
 
         if self.radix_cache is None:
             return
 
-        # `shared_kv_node` may be None on cache miss; treat it as "no buffer to copy".
-        buffer_idxs = np.array(
-            [None if r.shared_kv_node is None else r.shared_kv_node.buffer_idx for r in req_objs], dtype=object
-        )
-        mask = buffer_idxs == None  # noqa: E711 (intentional elementwise comparison against None)
-        copy_indices = req_idxs[~mask].tolist()
-        if not copy_indices:
-            return
-
-        copy_buffers = buffer_idxs[~mask].tolist()
-        copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64)
-        copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64)
-        self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor)
+        if copy_indices:
+            copy_indices_tensor = torch.tensor(copy_indices, device="cuda", dtype=torch.int64)
+            copy_buffers_tensor = torch.tensor(copy_buffers, device="cuda", dtype=torch.int64)
+            self.req_manager.copy_buffer_from_another_buffer(copy_buffers_tensor, copy_indices_tensor)
 
     def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]:
         req_objs = []
@@ -143,8 +135,8 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
             slave_req: InferReq = slave_req
             slave_req.related_master_req = master_req
 
-        # Hybrid/linear-attention models
-        self._maybe_alloc_and_copy_req_buffers(req_objs)
+        if self.use_buffer_manager and len(req_objs) > 0:
+            self._alloc_and_copy_req_buffers(req_objs)
 
         return req_objs
 
@@ -169,11 +161,11 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq", free_buffer_in
             if self.use_buffer_manager:
                 buffer_idx = self.req_manager.req_to_buffer_index[req.req_idx].item()
                 if node.buffer_idx is None:
-                    self.radix_cache.set_node_buffer_idx(node, buffer_idx)
+                    self.radix_cache.add_buffer_idx_to_node(node, buffer_idx)
                 else:
                     free_buffer_index.append(buffer_idx)
 
-            old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+            old_prefix_len = req.shm_req.prompt_cache_len
             free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
             if req.shared_kv_node is not None:
                 assert req.shared_kv_node.node_prefix_total_len <= prefix_len
@@ -218,7 +210,6 @@ def _filter(self, finished_request_ids: List[int]):
         self.req_manager.free(free_req_index, free_token_index)
 
         if self.use_buffer_manager and len(free_buffer_index) != 0:
-            free_buffer_index = torch.tensor(free_buffer_index, dtype=torch.int64, device="cpu")
            self.req_manager.free_buffer(free_buffer_index)
 
         finished_req_ids_set = set(finished_request_ids)
@@ -278,6 +269,7 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
     def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
         if paused_reqs:
             g_infer_state_lock.acquire()
+            revovered_reqs = []
             for req in paused_reqs:
                 prefill_need_token_num = req.get_cur_total_len()
                 if prefill_need_token_num > can_alloc_token_num:
@@ -288,8 +280,10 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo
                 if is_master_in_dp:
                     req.shm_req.is_paused = False
                 can_alloc_token_num -= prefill_need_token_num
+                revovered_reqs.append(req)
 
-            self._maybe_alloc_and_copy_req_buffers(paused_reqs)
+            self._alloc_and_copy_req_buffers(revovered_reqs)
+            g_infer_state_lock.release()
         return
 
     def get_can_alloc_token_num(self):
@@ -413,14 +407,13 @@ def __init__(
         self.nixl_pd_task_failed_num: int = 0
         self.nixl_trans_device_id: int = -1
 
+        # When the radix cache is enabled, records the prefix hit length; used by the insert algorithm.
+        self.mamba_model_match_len = 0
+
         # When enable_cpu_cache is enabled, a finished request's kv cache is offloaded
         # to the cpu cache; this flag tracks the status of that offload task.
         self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED
 
-        # Fixed-size buffer index managed over the whole lifetime of this request; None means not allocated.
-        # Used by linear-attention models such as Qwen3-Next.
-        self.buffer_idx: int = None
-
         # mtp_step records the number of tokens the draft model must generate per step for a request.
         # In normal mode this value is 0; in mtp mode it is the number of tokens the draft model generates per step.
         self.mtp_step: int = get_env_start_args().mtp_step
@@ -469,6 +462,7 @@ def _match_radix_cache(self):
         key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
         key = key[0 : len(key) - 1]  # The last token is not needed: an extra token is kept so prefill outputs the next token's value.
         share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
+        self.mamba_model_match_len = kv_len
         if share_node is not None:
             self.shared_kv_node = share_node
             ready_cache_len = share_node.node_prefix_total_len