Skip to content

Commit 71f15e1

Browse files
committed
delete useless codes
1 parent aaa8946 commit 71f15e1

File tree

1 file changed

+0
-94
lines changed

1 file changed

+0
-94
lines changed

lightllm/server/router/dynamic_prompt/hiradix_cache.py

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
4141
self.hi_cache_kv_buffer = None
4242
self.is_hi_radix_cache = False
4343

44-
# write a new function, only insert input(after prefill), call after prefill,
45-
# then when the decode finishes, synchronize to see whether this can be freed
46-
# no buffer, parallel insert inputs
4744
def insert_disk(self, req_id, key, value):
4845
if not self.do_store:
4946
return
@@ -61,95 +58,17 @@ def abort_req_store_task(self, req_id):
6158
logger.info(f"Aborting req {req_id} unfinished.")
6259
self.py_cache_service.az5(self.working_tasks[req_id])
6360

64-
# TODO: finish this function to only update new ones
65-
def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, update_refs=False):
66-
if node.is_leaf():
67-
self.evict_tree_set.discard(node)
68-
69-
if update_refs:
70-
node.ref_counter += 1
71-
# when ref_counter goes from 0 to 1, the referenced-token count must be updated
72-
if node.ref_counter == 1:
73-
self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)
74-
75-
try:
76-
if len(key) == 0:
77-
return node
78-
79-
first_key_id = key[0].item()
80-
if first_key_id in node.children.keys():
81-
child: TreeNode = node.children[first_key_id]
82-
prefix_len = match(key, child.token_id_key)
83-
if prefix_len == len(key):
84-
if child.is_leaf():
85-
self.evict_tree_set.discard(child)
86-
child.update_time()
87-
ans_value_list.append(child.token_mem_index_value)
88-
if child.is_leaf():
89-
self.evict_tree_set.add(child)
90-
return prefix_len
91-
92-
elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
93-
if child.is_leaf():
94-
self.evict_tree_set.discard(child)
95-
96-
key = key[prefix_len:]
97-
value = value[prefix_len:]
98-
split_parent_node = child.split_node(prefix_len)
99-
new_node = split_parent_node.add_and_return_new_child(key, value)
100-
# update total token num
101-
self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
102-
103-
if split_parent_node.is_leaf():
104-
self.evict_tree_set.add(split_parent_node)
105-
if new_node.is_leaf():
106-
self.evict_tree_set.add(new_node)
107-
108-
if child.is_leaf():
109-
self.evict_tree_set.add(child)
110-
return prefix_len
111-
elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
112-
return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:])
113-
else:
114-
assert False, "can not run to here"
115-
116-
else:
117-
new_node = node.add_and_return_new_child(key, value)
118-
# update total token num
119-
self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
120-
ans_value_list.append(new_node.token_mem_index_value)
121-
if update_refs:
122-
new_node.ref_counter += 1
123-
if new_node.ref_counter == 1:
124-
self.refed_tokens_num.arr[0] += len(new_node.token_mem_index_value)
125-
if new_node.is_leaf():
126-
self.evict_tree_set.add(new_node)
127-
return new_node
128-
finally:
129-
node.update_time()
130-
if node.is_leaf():
131-
self.evict_tree_set.add(node)
132-
13361
def match_prefix(self, key, update_refs=False):
13462
assert len(key) != 0
13563
ans_value_list = []
13664
pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node)
13765
if self.do_store:
138-
# st_time = time.time()
13966
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
140-
# add a parameter if get long enough (>50%)
141-
# first_query_time = time.time()
142-
# logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}s")
14367
max_len = self._query_hi_cache(key) # x64
144-
# hi_cache_q_time = time.time()
145-
# logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query {hi_cache_q_time - first_query_time}s")
14668
logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
14769
pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0
148-
# hi_cache_q_time = time.time()
14970
dist.broadcast(pull_hi_cache_tensor, src=0)
150-
# logger.info(f"After broadcast on rank {self.rank_in_node}, tensor={pull_hi_cache_tensor}")
15171
pull_hi_cache = False
152-
# logger.info(f"Rank {self.rank_in_node}, {pull_hi_cache=} {pull_hi_cache_tensor=}")
15372

15473
if pull_hi_cache_tensor[0] == 0 and not self.do_store:
15574
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
@@ -166,28 +85,15 @@ def match_prefix(self, key, update_refs=False):
16685
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
16786
if pull_hi_cache:
16887
buffers = self.mem_manager.alloc(max_len)
169-
# before_pull_time = time.time()
170-
# logger.info(
171-
# f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_q_time}"
172-
# )
17388
if self.do_store:
17489
read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
17590
while not read_task.ready():
17691
time.sleep(0.05)
17792
dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
178-
# hicache_pull_time = time.time()
179-
# logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull {hicache_pull_time - before_pull_time}s")
18093
logger.info(f"HiCache pulled one cache with len = {max_len}")
181-
# maybe try: add a function to only insert middle part of kv cache
18294
self._insert_helper(self.root_node, key, buffers)
183-
# insert_time = time.time()
184-
# logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
18595
ans_value_list = []
18696
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
187-
# logger.info(
188-
# f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
189-
# + f" matched {sum(len(s) for s in ans_value_list)} tokens"
190-
# )
19197
if tree_node != self.root_node:
19298
if len(ans_value_list) != 0:
19399
value = torch.concat(ans_value_list)

0 commit comments

Comments
 (0)