@@ -17,6 +17,9 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
1717 logger .info ("Initializing HiRadixCache" )
1818 self .rank_in_node = rank_in_node
1919 try :
20+ # TODO: determine by model type && dp, tp
21+ store_once = True # Deepseek -> True, Llama -> False
22+ self .do_store = store_once and self .rank_in_node == 0
2023 self .is_hi_radix_cache = True
2124 all_buffers = self .mem_manager .kv_buffer
2225 all_buffers = all_buffers .view (all_buffers .shape [0 ], all_buffers .shape [1 ], - 1 )
@@ -37,83 +40,111 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager, max_
3740 # then when the decode finishes, synchronize to see whether this can be freed
3841 # no buffer, parallel insert inputs
def insert_disk(self, req_id, key, value):
    """Start an asynchronous disk-store task for this request's KV cache.

    No-op on ranks not designated to store (``self.do_store`` is False).
    If a store task for the same request is already in flight it is
    aborted first, so the newest key/value snapshot wins.
    """
    if not self.do_store:
        return
    # Keep at most one outstanding store task per request id.
    if req_id in self.working_tasks:
        self.abort_req_store_task(req_id)
    task = self.py_cache_service.create(tokens=key, kv_page_indexer=value, mode="w")
    self.working_tasks[req_id] = task
    logger.info(f"Created store task for req {req_id}.")
4449
def abort_req_store_task(self, req_id):
    """Abort the in-flight disk-store task for ``req_id``, if any.

    No-op when this rank does not store, when no store task was ever
    created for the request, or when the task has already finished.
    """
    if not self.do_store:
        return
    # Guard: callers may abort a request that never had a store task
    # created (e.g. do_store toggled, or insert_disk was never called).
    # Bare ``self.working_tasks[req_id]`` would raise KeyError here.
    task = self.working_tasks.get(req_id)
    if task is None:
        return
    if task.ready():
        logger.info(f"Calling abort for req {req_id}, but is finished.")
        return
    logger.info(f"Aborting req {req_id} unfinished.")
    # az5: the cache service's cancel call for a pending task —
    # presumably kills the in-flight write; confirm with service API.
    self.py_cache_service.az5(task)
58+
# TODO: finish this function to only update new ones
def _reinsert_helper(self, node: TreeNode, key, value, ans_value_list: list, update_refs=False):
    """Re-insert ``key``/``value`` (tokens pulled back from disk) under ``node``.

    Walks/extends the radix tree, appending each matched or newly created
    node's ``token_mem_index_value`` to ``ans_value_list``.

    NOTE(review): the return value is inconsistent — ``node`` (TreeNode) for
    an empty key, an ``int`` prefix length on the match branches, and the new
    TreeNode when a fresh child is added. Callers must not rely on a single
    return type; TODO confirm the intended contract.
    """
    # Temporarily drop the node from the eviction set while it is mutated;
    # it is re-added in the ``finally`` block below.
    if node.is_leaf():
        self.evict_tree_set.discard(node)

    if update_refs:
        node.ref_counter += 1
        # from 0 to 1 need update refs token num
        if node.ref_counter == 1:
            self.refed_tokens_num.arr[0] += len(node.token_mem_index_value)

    try:
        if len(key) == 0:
            return node

        # ``key`` appears to be a tensor of token ids (``.item()`` below);
        # children are keyed by the first token id — TODO confirm dtype.
        first_key_id = key[0].item()
        if first_key_id in node.children.keys():
            child: TreeNode = node.children[first_key_id]
            prefix_len = match(key, child.token_id_key)
            if prefix_len == len(key):
                # Whole key already present in this child: refresh it and
                # record its value; no tree mutation needed.
                if child.is_leaf():
                    self.evict_tree_set.discard(child)
                child.update_time()
                ans_value_list.append(child.token_mem_index_value)
                if child.is_leaf():
                    self.evict_tree_set.add(child)
                return prefix_len

            elif prefix_len < len(key) and prefix_len < len(child.token_id_key):
                # Key diverges inside the child: split the child at the
                # match point and attach the remainder as a new branch.
                if child.is_leaf():
                    self.evict_tree_set.discard(child)

                key = key[prefix_len:]
                value = value[prefix_len:]
                split_parent_node = child.split_node(prefix_len)
                new_node = split_parent_node.add_and_return_new_child(key, value)
                # update total token num
                self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)

                if split_parent_node.is_leaf():
                    self.evict_tree_set.add(split_parent_node)
                if new_node.is_leaf():
                    self.evict_tree_set.add(new_node)

                if child.is_leaf():
                    self.evict_tree_set.add(child)
                return prefix_len
            elif prefix_len < len(key) and prefix_len == len(child.token_id_key):
                # Child fully consumed: recurse for the remaining suffix.
                # NOTE(review): this calls ``_insert_helper`` rather than
                # ``_reinsert_helper`` — ``ans_value_list``/``update_refs``
                # are not propagated down this path; confirm intended.
                return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:])
            else:
                assert False, "can not run to here"

        else:
            # No child shares the first token id: create a fresh branch.
            new_node = node.add_and_return_new_child(key, value)
            # update total token num
            self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value)
            ans_value_list.append(new_node.token_mem_index_value)
            if update_refs:
                new_node.ref_counter += 1
                if new_node.ref_counter == 1:
                    self.refed_tokens_num.arr[0] += len(new_node.token_mem_index_value)
            if new_node.is_leaf():
                self.evict_tree_set.add(new_node)
            return new_node
    finally:
        # Always refresh the timestamp and restore eviction-set membership,
        # regardless of which branch returned above.
        node.update_time()
        if node.is_leaf():
            self.evict_tree_set.add(node)
98127
99128 def match_prefix (self , key , update_refs = False ):
100129 st_time = time .time ()
101130 assert len (key ) != 0
102131 ans_value_list = []
103- tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
132+ tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = False )
104133 # add a parameter if get long enough (>50%)
105134 first_query_time = time .time ()
106135 logger .info (f"HiCache of [{ self .rank_in_node } ]: No.1 First GPU query took { first_query_time - st_time } " )
107136 max_len = self ._query_hi_cache (key ) # x64
108137 hi_cache_query_time = time .time ()
109138 logger .info (f"HiCache of [{ self .rank_in_node } ]: No.2 Disk query took { hi_cache_query_time - first_query_time } " )
110- logger .info (f"Matched { len (ans_value_list )} from gpu and { max_len } from disk." )
139+ logger .info (f"Matched { sum ( len (s ) for s in ans_value_list )} from gpu and { max_len } from disk." )
111140 pull_hi_cache = False
112- if max_len > len (ans_value_list ):
141+ if max_len > sum ( len (s ) for s in ans_value_list ):
113142 pull_hi_cache = True
114143 try :
115144 self .free_radix_cache_to_get_enough_token (max_len )
116145 except :
146+ if update_refs :
147+ tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
117148 pull_hi_cache = False
118149 if pull_hi_cache :
119150 buffers = self .mem_manager .alloc (max_len )
@@ -133,7 +164,10 @@ def match_prefix(self, key, update_refs=False):
133164 logger .info (f"HiCache of [{ self .rank_in_node } ]: No.4 Reinsert took { insert_time - hicache_pull_time } " )
134165 ans_value_list = []
135166 tree_node = self ._match_prefix_helper (self .root_node , key , ans_value_list , update_refs = update_refs )
136- logger .info (f"HiCache of [{ self .rank_in_node } ]: No.5 Re match prefix took { time .time () - insert_time } " )
167+ logger .info (
168+ f"HiCache of [{ self .rank_in_node } ]: No.5 Re match prefix took { time .time () - insert_time } "
169+ + f" matched { sum (len (s ) for s in ans_value_list )} tokens"
170+ )
137171 if tree_node != self .root_node :
138172 if len (ans_value_list ) != 0 :
139173 value = torch .concat (ans_value_list )
0 commit comments