
Commit 3b51af8: fix radix cache

Author: wangzaijun
1 parent: be62703

4 files changed: +153, -71 lines


lightllm/distributed/custom_all_gather.py

Lines changed: 0 additions & 1 deletion
Lines changed: 0 additions & 1 deletion

```diff
@@ -28,7 +28,6 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.device_utils import has_nvlink
 from lightllm.utils.light_utils import light_ops
-from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager


 try:
```

lightllm/distributed/custom_all_reduce.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -29,7 +29,6 @@
 from lightllm.utils.device_utils import has_nvlink
 from lightllm.utils.sgl_utils import sgl_allreduce_ops
 from lightllm.utils.vllm_utils import vllm_ops
-from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

 logger = init_logger(__name__)

@@ -225,6 +224,9 @@ def all_reduce(self, inp: torch.Tensor, *, out: torch.Tensor = None, registered:
         buffer.
         """
         if out is None:
+            # fix circular import
+            from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
             out = g_cache_manager.alloc_tensor(inp.shape, inp.dtype, device=inp.device, is_graph_out=False)
         if registered:
             ops.all_reduce(self._ptr, inp, out, 0, 0)
@@ -243,6 +245,9 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
+                # fix circular import
+                from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
+
                 out = g_cache_manager.alloc_tensor(input.shape, input.dtype, device=input.device, is_graph_out=False)
                 return out
         else:
```
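Both hunks above apply the same fix: the module-scope import of `g_cache_manager` moves into the function bodies, so it only executes on first call, after every module in the cycle has finished initializing. A minimal sketch of the pattern, with hypothetical modules (`cache.py` and `config.py` are illustrative, not part of lightllm):

```python
# cache.py (hypothetical)
def alloc():
    # Deferred import: by the time alloc() is first called, config.py has
    # already executed far enough to bind DEFAULT_SIZE, so the cycle
    # config -> cache -> config no longer fails at import time.
    from config import DEFAULT_SIZE

    return bytearray(DEFAULT_SIZE)


# config.py (hypothetical)
import cache  # safe: cache.py no longer imports config at module scope

DEFAULT_SIZE = 1024
buf = cache.alloc()  # works: DEFAULT_SIZE is already bound above
```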

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 143 additions & 68 deletions
```diff
@@ -1,10 +1,10 @@
 # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py
 import torch
 import numpy as np
-from typing import Tuple, Dict, Set, List, Optional
+import collections
+from typing import Tuple, Dict, Set, List, Optional, Union
 from sortedcontainers import SortedSet
 from .shared_arr import SharedArray
-from lightllm.common.mem_manager import MemoryManager


 class UniqueTimeIdGenerator:
```
class UniqueTimeIdGenerator:
```diff
@@ -103,8 +103,10 @@ class RadixCache:
     unique_name is mainly used to avoid shm conflicts when multiple instances are deployed on a single machine
     """

-    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: MemoryManager = None):
-        self.mem_manager = mem_manager
+    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
+        from lightllm.common.mem_manager import MemoryManager
+
+        self.mem_manager: MemoryManager = mem_manager
         self._key_dtype = torch.int64
         self._value_dtype = torch.int64

```
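Deferring the `MemoryManager` import into `__init__` keeps the annotation while breaking the cycle. An equivalent idiom, shown here only as an alternative and not what this commit does, is `typing.TYPE_CHECKING`, which gives type checkers the import without ever executing it at runtime:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime,
    # so it cannot participate in an import cycle.
    from lightllm.common.mem_manager import MemoryManager


class RadixCache:
    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
        # The string annotation resolves for the checker; no runtime import needed.
        self.mem_manager: "MemoryManager" = mem_manager
```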
```diff
@@ -133,58 +135,100 @@ def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]:
         return self._insert_helper(self.root_node, key, value)

     def _insert_helper(self, node: TreeNode, key, value) -> Tuple[int, Optional[TreeNode]]:
+        handle_stack = collections.deque()
+        update_list = collections.deque()
+        handle_stack.append((node, key, value))
+
+        ans_prefix_len = 0
+        ans_node = None
+
+        while len(handle_stack) != 0:
+            node, key, value = handle_stack.popleft()
+            ans_tuple = self._insert_helper_no_recursion(node=node, key=key, value=value)
+            if len(ans_tuple) == 4:
+                (_prefix_len, new_node, new_key, new_value) = ans_tuple
+                ans_prefix_len += _prefix_len
+                handle_stack.append((new_node, new_key, new_value))
+            else:
+                _prefix_len, ans_node = ans_tuple
+                ans_prefix_len += _prefix_len
+
+            update_list.append(node)
+
+        while len(update_list) != 0:
+            cur_node: TreeNode = update_list.pop()
+            cur_node.update_time()
+            if cur_node.is_leaf():
+                self.evict_tree_set.add(cur_node)
+
+        assert ans_node is not None
+
+        return ans_prefix_len, ans_node
+
+    def _insert_helper_no_recursion(
+        self, node: TreeNode, key: torch.Tensor, value: torch.Tensor
+    ) -> Union[Tuple[int, Optional[TreeNode]], Tuple[int, TreeNode, torch.Tensor, torch.Tensor]]:
         if node.is_leaf():
             self.evict_tree_set.discard(node)

-        try:
-            first_key_id = key[0].item()
-            if first_key_id in node.children.keys():
-                child: TreeNode = node.children[first_key_id]
-                prefix_len = match(key, child.token_id_key)
-                if prefix_len == len(key):
+        first_key_id = key[0].item()
+        if first_key_id in node.children.keys():
+            child: TreeNode = node.children[first_key_id]
+            prefix_len = match(key, child.token_id_key)
+            if prefix_len == len(key):
+                if prefix_len == len(child.token_id_key):
                     if child.is_leaf():
                         self.evict_tree_set.discard(child)
                     child.update_time()
                     if child.is_leaf():
                         self.evict_tree_set.add(child)
                     return prefix_len, child
-
-                elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
+                elif prefix_len < len(child.token_id_key):
                     if child.is_leaf():
                         self.evict_tree_set.discard(child)

-                    key = key[prefix_len:]
-                    value = value[prefix_len:]
                     split_parent_node = child.split_node(prefix_len)
-                    new_node = split_parent_node.add_and_return_new_child(key, value)
-                    # update total token num
-                    self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)

                     if split_parent_node.is_leaf():
                         self.evict_tree_set.add(split_parent_node)
-                    if new_node.is_leaf():
-                        self.evict_tree_set.add(new_node)
-
                     if child.is_leaf():
                         self.evict_tree_set.add(child)
-                    return prefix_len, new_node
-                elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
-                    _prefix_len, ans_node = self._insert_helper(child, key[prefix_len:], value[prefix_len:])
-                    return prefix_len + _prefix_len, ans_node
+
+                    return prefix_len, split_parent_node
                 else:
                     assert False, "can not run to here"

-            else:
-                new_node = node.add_and_return_new_child(key, value)
+            elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
+                if child.is_leaf():
+                    self.evict_tree_set.discard(child)
+
+                key = key[prefix_len:]
+                value = value[prefix_len:]
+                split_parent_node = child.split_node(prefix_len)
+                new_node = split_parent_node.add_and_return_new_child(key, value)
                 # update total token num
                 self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+
+                if split_parent_node.is_leaf():
+                    self.evict_tree_set.add(split_parent_node)
                 if new_node.is_leaf():
                     self.evict_tree_set.add(new_node)
-                return 0, new_node
-        finally:
-            node.update_time()
-            if node.is_leaf():
-                self.evict_tree_set.add(node)
+
+                if child.is_leaf():
+                    self.evict_tree_set.add(child)
+                return prefix_len, new_node
+            elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
+                return (prefix_len, child, key[prefix_len:], value[prefix_len:])
+            else:
+                assert False, "can not run to here"
+
+        else:
+            new_node = node.add_and_return_new_child(key, value)
+            # update total token num
+            self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
+            if new_node.is_leaf():
+                self.evict_tree_set.add(new_node)
+            return 0, new_node

     def match_prefix(self, key, update_refs=False):
         assert len(key) != 0
```
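This hunk is the core of the fix: the recursive `_insert_helper` is unrolled into a driver loop plus a single-step helper. Each call to `_insert_helper_no_recursion` returns either a terminal `(prefix_len, node)` pair or a 4-tuple naming the child to continue from, and the old `try/finally` bookkeeping becomes a second, deepest-first pass over the visited nodes. A toy sketch of the same trampoline shape (all names illustrative, not lightllm's):

```python
import collections
from typing import Tuple, Union


class Node:
    def __init__(self, children=None):
        self.children = children or {}


def step(node: Node, key: str) -> Union[Node, Tuple[Node, str]]:
    # Either a terminal answer, or a "continuation" telling the
    # driver which child to process next.
    if not key or key[0] not in node.children:
        return node
    return (node.children[key[0]], key[1:])


def descend(root: Node, key: str) -> Node:
    work = collections.deque([(root, key)])
    visited = collections.deque()  # plays the role of update_list
    result = None
    while work:
        node, key = work.popleft()
        out = step(node, key)
        if isinstance(out, tuple):
            work.append(out)  # continuation: keep descending
        else:
            result = out  # deepest matching node
        visited.append(node)
    # Post-pass runs deepest-first, like the update_time()/evict-set
    # maintenance that used to live in the try/finally.
    while visited:
        visited.pop()
    return result
```

Because each step pushes at most one continuation, the deque never holds more than one pending item; it is effectively a loop, but keeping the step function separate means no Python recursion depth is consumed no matter how deep the tree grows.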
```diff
@@ -200,7 +244,39 @@ def match_prefix(self, key, update_refs=False):
             self.dec_node_ref_counter(self.root_node)
             return None, 0, None

-    def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update_refs=False) -> TreeNode:
+    def _match_prefix_helper(
+        self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False
+    ) -> TreeNode:
+        handle_stack = collections.deque()
+        update_list = collections.deque()
+        handle_stack.append((node, key))
+
+        ans_node = None
+
+        while len(handle_stack) != 0:
+            node, key = handle_stack.popleft()
+            ans_tuple = self._match_prefix_helper_no_recursion(
+                node=node, key=key, ans_value_list=ans_value_list, update_refs=update_refs
+            )
+            if isinstance(ans_tuple, tuple):
+                new_node, new_key = ans_tuple
+                handle_stack.append((new_node, new_key))
+            else:
+                ans_node = ans_tuple
+
+            update_list.append(node)
+
+        while len(update_list) != 0:
+            cur_node: TreeNode = update_list.pop()
+            cur_node.update_time()
+            if cur_node.is_leaf():
+                self.evict_tree_set.add(cur_node)
+
+        return ans_node
+
+    def _match_prefix_helper_no_recursion(
+        self, node: TreeNode, key: torch.Tensor, ans_value_list: list, update_refs=False
+    ) -> TreeNode:
         if node.is_leaf():
             self.evict_tree_set.discard(node)

```
```diff
@@ -210,44 +286,39 @@ def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update
             if node.ref_counter == 1:
                 self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)

-        try:
-            if len(key) == 0:
-                return node
+        if len(key) == 0:
+            return node

-            first_key_id = key[0].item()
-            if first_key_id not in node.children.keys():
-                return node
+        first_key_id = key[0].item()
+        if first_key_id not in node.children.keys():
+            return node
+        else:
+            child = node.children[first_key_id]
+            prefix_len = match(key, child.token_id_key)
+            if prefix_len == len(child.token_id_key):
+                ans_value_list.append(child.token_mem_index_value)
+                return (child, key[prefix_len:])
+            elif prefix_len < len(child.token_id_key):
+                if child.is_leaf():
+                    self.evict_tree_set.discard(child)
+
+                split_parent_node = child.split_node(prefix_len)
+                ans_value_list.append(split_parent_node.token_mem_index_value)
+
+                if update_refs:
+                    split_parent_node.ref_counter += 1
+                    # from 0 to 1 need update refs token num
+                    if split_parent_node.ref_counter == 1:
+                        self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value)
+
+                if child.is_leaf():
+                    self.evict_tree_set.add(child)
+                if split_parent_node.is_leaf():
+                    self.evict_tree_set.add(split_parent_node)
+
+                return split_parent_node
             else:
-                child = node.children[first_key_id]
-                prefix_len = match(key, child.token_id_key)
-                if prefix_len == len(child.token_id_key):
-                    ans_value_list.append(child.token_mem_index_value)
-                    return self._match_prefix_helper(child, key[prefix_len:], ans_value_list, update_refs=update_refs)
-                elif prefix_len < len(child.token_id_key):
-                    if child.is_leaf():
-                        self.evict_tree_set.discard(child)
-
-                    split_parent_node = child.split_node(prefix_len)
-                    ans_value_list.append(split_parent_node.token_mem_index_value)
-
-                    if update_refs:
-                        split_parent_node.ref_counter += 1
-                        # from 0 to 1 need update refs token num
-                        if split_parent_node.ref_counter == 1:
-                            self.refed_tokens_num.arr[0] += len(split_parent_node.token_mem_index_value)
-
-                    if child.is_leaf():
-                        self.evict_tree_set.add(child)
-                    if split_parent_node.is_leaf():
-                        self.evict_tree_set.add(split_parent_node)
-
-                    return split_parent_node
-                else:
-                    assert False, "error state"
-        finally:
-            node.update_time()
-            if node.is_leaf():
-                self.evict_tree_set.add(node)
+                assert False, "error state"

     def evict(self, need_remove_tokens, evict_callback):
         if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens:
```
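`_match_prefix_helper` gets the same recursion-to-loop treatment, and both the insert and match paths rely on `split_node` when a query diverges in the middle of a cached run of tokens. A toy model of that split semantics (illustrative names, not lightllm's `TreeNode` API):

```python
# A cached edge holding tokens [1, 2, 3, 4, 5] is split at prefix_len=3
# when a query like [1, 2, 3, 9] diverges mid-edge.
class Edge:
    def __init__(self, tokens, children=None):
        self.tokens = tokens
        self.children = children if children is not None else {}

    def split(self, k):
        # The tail [4, 5] becomes a child; the head [1, 2, 3] stays in
        # place and is what the matcher returns as the new shared node.
        tail = Edge(self.tokens[k:], self.children)
        self.tokens = self.tokens[:k]
        self.children = {tail.tokens[0]: tail}
        return self


edge = Edge([1, 2, 3, 4, 5])
head = edge.split(3)
assert head.tokens == [1, 2, 3]
assert list(head.children.values())[0].tokens == [4, 5]
```

This mirrors why the matcher appends `split_parent_node.token_mem_index_value` to `ans_value_list`: after the split, the head node alone holds exactly the matched tokens.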
```diff
@@ -417,3 +488,7 @@ def get_tree_total_tokens_num(self, dp_rank_in_node):

     def get_unrefed_tokens_num(self, dp_rank_in_node):
         return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num()
+
+
+class _RecursionParams:
+    pass
```

lightllm/server/router/model_infer/mode_backend/diverse_backend/impl.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -165,6 +165,7 @@ def _master_req_to_radix_cache(self, master_req: InferReq):
         value = self.model.req_manager.req_to_token_indexs[master_req.req_idx][: master_req.cur_kv_len].detach().cpu()
         prefix_len, new_shared_kv_node = self.radix_cache.insert(key, value)
         old_prefix_len = 0 if master_req.shared_kv_node is None else master_req.shared_kv_node.node_prefix_total_len
+        assert old_prefix_len <= master_req.cur_kv_len
         self.model.mem_manager.free(
             self.model.req_manager.req_to_token_indexs[master_req.req_idx][old_prefix_len:prefix_len]
         )
@@ -173,7 +174,9 @@ def _master_req_to_radix_cache(self, master_req: InferReq):
         self.radix_cache.dec_node_ref_counter(master_req.shared_kv_node)
         self.radix_cache.add_node_ref_counter(new_shared_kv_node)
         master_req.shared_kv_node = new_shared_kv_node
-        assert new_shared_kv_node.node_prefix_total_len == master_req.cur_kv_len
+        assert (
+            new_shared_kv_node.node_prefix_total_len == master_req.cur_kv_len
+        ), f"shared len: {new_shared_kv_node.node_prefix_total_len} cur_kv_len {master_req.cur_kv_len}"

         share_node, kv_len, value = self.radix_cache.match_prefix(key, update_refs=False)
         assert share_node == new_shared_kv_node and kv_len == master_req.cur_kv_len
```
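The strengthened asserts guard the index bookkeeping around `insert`: token slots the radix tree now holds are released from the request's private allocation, so the old and new prefix lengths must stay consistent with `cur_kv_len`. A worked example with made-up numbers:

```python
# Worked example of the bookkeeping guarded by the new asserts
# (numbers are illustrative, not from the commit).
old_prefix_len = 4   # tokens already referenced via the old shared node
prefix_len = 10      # tokens held by the radix tree after insert()
cur_kv_len = 12      # tokens the request has computed KV cache for

assert old_prefix_len <= cur_kv_len  # the guard this commit adds

# Slots [4:10] are now owned by the tree, so the request frees them;
# slots [10:12] stay with the request until it finishes.
req_token_indexes = list(range(100, 112))  # stand-in memory indexes
freed = req_token_indexes[old_prefix_len:prefix_len]
assert freed == list(range(104, 110))
```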

0 commit comments

Comments
 (0)