@@ -41,9 +41,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
        self.hi_cache_kv_buffer = None
        self.is_hi_radix_cache = False

-    # write a new function, only insert input(after prefill), call after prefill,
-    # then when the decode finishes, do syncronize to see whether this can be free
-    # no buffer, parallel insert inputs
    def insert_disk(self, req_id, key, value):
        if not self.do_store:
            return
@@ -61,95 +58,17 @@ def abort_req_store_task(self, req_id):
        logger.info(f"Aborting req {req_id} unfinished.")
        self.py_cache_service.az5(self.working_tasks[req_id])

-    # TODO: finish this function to only update new ones
-    def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, update_refs=False):
-        if node.is_leaf():
-            self.evict_tree_set.discard(node)
-
-        if update_refs:
-            node.ref_counter += 1
-            # from 0 to 1 need update refs token num
-            if node.ref_counter == 1:
-                self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)
-
-        try:
-            if len(key) == 0:
-                return node
-
-            first_key_id = key[0].item()
-            if first_key_id in node.children.keys():
-                child: TreeNode = node.children[first_key_id]
-                prefix_len = match(key, child.token_id_key)
-                if prefix_len == len(key):
-                    if child.is_leaf():
-                        self.evict_tree_set.discard(child)
-                    child.update_time()
-                    ans_value_list.append(child.token_mem_index_value)
-                    if child.is_leaf():
-                        self.evict_tree_set.add(child)
-                    return prefix_len
-
-                elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
-                    if child.is_leaf():
-                        self.evict_tree_set.discard(child)
-
-                    key = key[prefix_len:]
-                    value = value[prefix_len:]
-                    split_parent_node = child.split_node(prefix_len)
-                    new_node = split_parent_node.add_and_return_new_child(key, value)
-                    # update total token num
-                    self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
-
-                    if split_parent_node.is_leaf():
-                        self.evict_tree_set.add(split_parent_node)
-                    if new_node.is_leaf():
-                        self.evict_tree_set.add(new_node)
-
-                    if child.is_leaf():
-                        self.evict_tree_set.add(child)
-                    return prefix_len
-                elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
-                    return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:])
-                else:
-                    assert False, "can not run to here"
-
-            else:
-                new_node = node.add_and_return_new_child(key, value)
-                # update total token num
-                self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
-                ans_value_list.append(new_node.token_mem_index_value)
-                if update_refs:
-                    new_node.ref_counter += 1
-                    if new_node.ref_counter == 1:
-                        self.refed_tokens_num.arr[0] += len(new_node.token_mem_index_value)
-                if new_node.is_leaf():
-                    self.evict_tree_set.add(new_node)
-                return new_node
-        finally:
-            node.update_time()
-            if node.is_leaf():
-                self.evict_tree_set.add(node)
-
    def match_prefix(self, key, update_refs=False):
        assert len(key) != 0
        ans_value_list = []
        pull_hi_cache_tensor = torch.tensor([0], dtype=torch.int64).cuda(self.rank_in_node)
        if self.do_store:
-            # st_time = time.time()
            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
-            # add a parameter if get long enough (>50%)
-            # first_query_time = time.time()
-            # logger.info(f"HiCache of [{self.rank_in_node}]: No.1 First GPU query took {first_query_time - st_time}s")
            max_len = self._query_hi_cache(key)  # x64
-            # hi_cache_q_time = time.time()
-            # logger.info(f"HiCache of [{self.rank_in_node}]: No.2 Disk query {hi_cache_q_time - first_query_time}s")
            logger.info(f"Matched {sum(len(s) for s in ans_value_list)} from gpu and {max_len} from disk.")
            pull_hi_cache_tensor[0] = max_len if (max_len > sum(len(s) for s in ans_value_list)) else 0
-            # hi_cache_q_time = time.time()
        dist.broadcast(pull_hi_cache_tensor, src=0)
-        # logger.info(f"After broadcast on rank {self.rank_in_node}, tensor={pull_hi_cache_tensor}")
        pull_hi_cache = False
-        # logger.info(f"Rank {self.rank_in_node}, {pull_hi_cache=} {pull_hi_cache_tensor=}")

        if pull_hi_cache_tensor[0] == 0 and not self.do_store:
            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=False)
@@ -166,28 +85,15 @@ def match_prefix(self, key, update_refs=False):
            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
        if pull_hi_cache:
            buffers = self.mem_manager.alloc(max_len)
-            # before_pull_time = time.time()
-            # logger.info(
-            #     f"HiCache of [{self.rank_in_node}]: No.2.5 Before pull took {before_pull_time - hi_cache_q_time}"
-            # )
            if self.do_store:
                read_task = self.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
                while not read_task.ready():
                    time.sleep(0.05)
            dist.broadcast(self.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
-            # hicache_pull_time = time.time()
-            # logger.info(f"HiCache of [{self.rank_in_node}]: No.3 Disk pull {hicache_pull_time - before_pull_time}s")
            logger.info(f"HiCache pulled one cache with len = {max_len}")
-            # maybe try: add a function to only insert middle part of kv cache
            self._insert_helper(self.root_node, key, buffers)
-            # insert_time = time.time()
-            # logger.info(f"HiCache of [{self.rank_in_node}]: No.4 Reinsert took {insert_time - hicache_pull_time}")
            ans_value_list = []
            tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
-            # logger.info(
-            #     f"HiCache of [{self.rank_in_node}]: No.5 Re match prefix took {time.time() - insert_time}"
-            #     + f" matched {sum(len(s) for s in ans_value_list)} tokens"
-            # )
        if tree_node != self.root_node:
            if len(ans_value_list) != 0:
                value = torch.concat(ans_value_list)
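
For reference, the disk-pull path that survives this diff follows one fixed sequence: the rank that owns the disk cache compares the GPU radix-tree hit against the disk hit, broadcasts the winning length, and only then allocates KV pages, reads them back from disk, replicates them to the other ranks, and reinserts them into the tree. The sketch below restates that flow as a standalone helper for readability; the wrapper name pull_from_disk_if_longer and its cache parameter are hypothetical, while the attributes it calls (do_store, _query_hi_cache, py_cache_service.create, mem_manager, _insert_helper) are the ones used in the code above.

import time

import torch
import torch.distributed as dist


def pull_from_disk_if_longer(cache, key, gpu_hit_len, rank):
    # Hypothetical standalone sketch of the disk-pull decision in match_prefix above.
    # `cache` stands in for the radix-cache object from the diff; only the storing
    # rank is assumed to have do_store=True.
    decision = torch.tensor([0], dtype=torch.int64).cuda(rank)
    if cache.do_store:
        disk_len = cache._query_hi_cache(key)  # longest prefix found on disk
        decision[0] = disk_len if disk_len > gpu_hit_len else 0
    dist.broadcast(decision, src=0)  # every rank learns the decision

    max_len = int(decision[0])
    if max_len == 0:
        return 0  # the GPU tree already covers at least as much

    buffers = cache.mem_manager.alloc(max_len)  # KV pages to receive the pulled cache
    if cache.do_store:
        # blocking read of the first max_len tokens into the allocated pages
        read_task = cache.py_cache_service.create(tokens=key[:max_len], kv_page_indexer=buffers, mode="r")
        while not read_task.ready():
            time.sleep(0.05)
    # replicate the pulled pages to the other ranks, then index them in the radix tree
    dist.broadcast(cache.mem_manager.get_index_kv_buffer(buffers)["kv_buffer"], src=0)
    cache._insert_helper(cache.root_node, key, buffers)
    return max_len

Both dist.broadcast calls are collectives, so the sketch keeps them outside the do_store guard so that every rank participates even though only the storing rank touches the disk.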