11# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py
22import torch
33import numpy as np
4- from typing import Tuple , Dict , Set , List , Optional
4+ import collections
5+ from typing import Tuple , Dict , Set , List , Optional , Union
56from sortedcontainers import SortedSet
67from .shared_arr import SharedArray
7- from lightllm .common .mem_manager import MemoryManager
88
99
1010class UniqueTimeIdGenerator :
@@ -103,8 +103,10 @@ class RadixCache:
103103 unique_name 主要用于解决单机,多实列部署时的shm冲突
104104 """
105105
106- def __init__ (self , unique_name , total_token_num , rank_in_node , mem_manager : MemoryManager = None ):
107- self .mem_manager = mem_manager
106+ def __init__ (self , unique_name , total_token_num , rank_in_node , mem_manager = None ):
107+ from lightllm .common .mem_manager import MemoryManager
108+
109+ self .mem_manager : MemoryManager = mem_manager
108110 self ._key_dtype = torch .int64
109111 self ._value_dtype = torch .int64
110112
@@ -133,58 +135,100 @@ def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]:
133135 return self ._insert_helper (self .root_node , key , value )
134136
135137 def _insert_helper (self , node : TreeNode , key , value ) -> Tuple [int , Optional [TreeNode ]]:
138+ handle_stack = collections .deque ()
139+ update_list = collections .deque ()
140+ handle_stack .append ((node , key , value ))
141+
142+ ans_prefix_len = 0
143+ ans_node = None
144+
145+ while len (handle_stack ) != 0 :
146+ node , key , value = handle_stack .popleft ()
147+ ans_tuple = self ._insert_helper_no_recursion (node = node , key = key , value = value )
148+ if len (ans_tuple ) == 4 :
149+ (_prefix_len , new_node , new_key , new_value ) = ans_tuple
150+ ans_prefix_len += _prefix_len
151+ handle_stack .append ((new_node , new_key , new_value ))
152+ else :
153+ _prefix_len , ans_node = ans_tuple
154+ ans_prefix_len += _prefix_len
155+
156+ update_list .append (node )
157+
158+ while len (update_list ) != 0 :
159+ cur_node : TreeNode = update_list .pop ()
160+ cur_node .update_time ()
161+ if cur_node .is_leaf ():
162+ self .evict_tree_set .add (cur_node )
163+
164+ assert ans_node is not None
165+
166+ return ans_prefix_len , ans_node
167+
168+ def _insert_helper_no_recursion (
169+ self , node : TreeNode , key : torch .Tensor , value : torch .Tensor
170+ ) -> Union [Tuple [int , Optional [TreeNode ]], Tuple [int , TreeNode , torch .Tensor , torch .Tensor ]]:
136171 if node .is_leaf ():
137172 self .evict_tree_set .discard (node )
138173
139- try :
140- first_key_id = key [ 0 ]. item ()
141- if first_key_id in node .children . keys ():
142- child : TreeNode = node . children [ first_key_id ]
143- prefix_len = match (key , child . token_id_key )
144- if prefix_len == len (key ):
174+ first_key_id = key [ 0 ]. item ()
175+ if first_key_id in node . children . keys ():
176+ child : TreeNode = node .children [ first_key_id ]
177+ prefix_len = match ( key , child . token_id_key )
178+ if prefix_len == len (key ):
179+ if prefix_len == len (child . token_id_key ):
145180 if child .is_leaf ():
146181 self .evict_tree_set .discard (child )
147182 child .update_time ()
148183 if child .is_leaf ():
149184 self .evict_tree_set .add (child )
150185 return prefix_len , child
151-
152- elif prefix_len < len (key ) and prefix_len < len (child .token_id_key ):
186+ elif prefix_len < len (child .token_id_key ):
153187 if child .is_leaf ():
154188 self .evict_tree_set .discard (child )
155189
156- key = key [prefix_len :]
157- value = value [prefix_len :]
158190 split_parent_node = child .split_node (prefix_len )
159- new_node = split_parent_node .add_and_return_new_child (key , value )
160- # update total token num
161- self .tree_total_tokens_num .arr [0 ] += len (new_node .token_mem_index_value )
162191
163192 if split_parent_node .is_leaf ():
164193 self .evict_tree_set .add (split_parent_node )
165- if new_node .is_leaf ():
166- self .evict_tree_set .add (new_node )
167-
168194 if child .is_leaf ():
169195 self .evict_tree_set .add (child )
170- return prefix_len , new_node
171- elif prefix_len < len (key ) and prefix_len == len (child .token_id_key ):
172- _prefix_len , ans_node = self ._insert_helper (child , key [prefix_len :], value [prefix_len :])
173- return prefix_len + _prefix_len , ans_node
196+
197+ return prefix_len , split_parent_node
174198 else :
175199 assert False , "can not run to here"
176200
177- else :
178- new_node = node .add_and_return_new_child (key , value )
201+ elif prefix_len < len (key ) and prefix_len < len (child .token_id_key ):
202+ if child .is_leaf ():
203+ self .evict_tree_set .discard (child )
204+
205+ key = key [prefix_len :]
206+ value = value [prefix_len :]
207+ split_parent_node = child .split_node (prefix_len )
208+ new_node = split_parent_node .add_and_return_new_child (key , value )
179209 # update total token num
180210 self .tree_total_tokens_num .arr [0 ] += len (new_node .token_mem_index_value )
211+
212+ if split_parent_node .is_leaf ():
213+ self .evict_tree_set .add (split_parent_node )
181214 if new_node .is_leaf ():
182215 self .evict_tree_set .add (new_node )
183- return 0 , new_node
184- finally :
185- node .update_time ()
186- if node .is_leaf ():
187- self .evict_tree_set .add (node )
216+
217+ if child .is_leaf ():
218+ self .evict_tree_set .add (child )
219+ return prefix_len , new_node
220+ elif prefix_len < len (key ) and prefix_len == len (child .token_id_key ):
221+ return (prefix_len , child , key [prefix_len :], value [prefix_len :])
222+ else :
223+ assert False , "can not run to here"
224+
225+ else :
226+ new_node = node .add_and_return_new_child (key , value )
227+ # update total token num
228+ self .tree_total_tokens_num .arr [0 ] += len (new_node .token_mem_index_value )
229+ if new_node .is_leaf ():
230+ self .evict_tree_set .add (new_node )
231+ return 0 , new_node
188232
189233 def match_prefix (self , key , update_refs = False ):
190234 assert len (key ) != 0
@@ -200,7 +244,39 @@ def match_prefix(self, key, update_refs=False):
200244 self .dec_node_ref_counter (self .root_node )
201245 return None , 0 , None
202246
203- def _match_prefix_helper (self , node : TreeNode , key , ans_value_list : list , update_refs = False ) -> TreeNode :
247+ def _match_prefix_helper (
248+ self , node : TreeNode , key : torch .Tensor , ans_value_list : list , update_refs = False
249+ ) -> TreeNode :
250+ handle_stack = collections .deque ()
251+ update_list = collections .deque ()
252+ handle_stack .append ((node , key ))
253+
254+ ans_node = None
255+
256+ while len (handle_stack ) != 0 :
257+ node , key = handle_stack .popleft ()
258+ ans_tuple = self ._match_prefix_helper_no_recursion (
259+ node = node , key = key , ans_value_list = ans_value_list , update_refs = update_refs
260+ )
261+ if isinstance (ans_tuple , tuple ):
262+ new_node , new_key = ans_tuple
263+ handle_stack .append ((new_node , new_key ))
264+ else :
265+ ans_node = ans_tuple
266+
267+ update_list .append (node )
268+
269+ while len (update_list ) != 0 :
270+ cur_node : TreeNode = update_list .pop ()
271+ cur_node .update_time ()
272+ if cur_node .is_leaf ():
273+ self .evict_tree_set .add (cur_node )
274+
275+ return ans_node
276+
277+ def _match_prefix_helper_no_recursion (
278+ self , node : TreeNode , key : torch .Tensor , ans_value_list : list , update_refs = False
279+ ) -> TreeNode :
204280 if node .is_leaf ():
205281 self .evict_tree_set .discard (node )
206282
@@ -210,44 +286,39 @@ def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update
210286 if node .ref_counter == 1 :
211287 self .refed_tokens_num .arr [0 ] += len (node .token_mem_index_value )
212288
213- try :
214- if len (key ) == 0 :
215- return node
289+ if len (key ) == 0 :
290+ return node
216291
217- first_key_id = key [0 ].item ()
218- if first_key_id not in node .children .keys ():
219- return node
292+ first_key_id = key [0 ].item ()
293+ if first_key_id not in node .children .keys ():
294+ return node
295+ else :
296+ child = node .children [first_key_id ]
297+ prefix_len = match (key , child .token_id_key )
298+ if prefix_len == len (child .token_id_key ):
299+ ans_value_list .append (child .token_mem_index_value )
300+ return (child , key [prefix_len :])
301+ elif prefix_len < len (child .token_id_key ):
302+ if child .is_leaf ():
303+ self .evict_tree_set .discard (child )
304+
305+ split_parent_node = child .split_node (prefix_len )
306+ ans_value_list .append (split_parent_node .token_mem_index_value )
307+
308+ if update_refs :
309+ split_parent_node .ref_counter += 1
310+ # from 0 to 1 need update refs token num
311+ if split_parent_node .ref_counter == 1 :
312+ self .refed_tokens_num .arr [0 ] += len (split_parent_node .token_mem_index_value )
313+
314+ if child .is_leaf ():
315+ self .evict_tree_set .add (child )
316+ if split_parent_node .is_leaf ():
317+ self .evict_tree_set .add (split_parent_node )
318+
319+ return split_parent_node
220320 else :
221- child = node .children [first_key_id ]
222- prefix_len = match (key , child .token_id_key )
223- if prefix_len == len (child .token_id_key ):
224- ans_value_list .append (child .token_mem_index_value )
225- return self ._match_prefix_helper (child , key [prefix_len :], ans_value_list , update_refs = update_refs )
226- elif prefix_len < len (child .token_id_key ):
227- if child .is_leaf ():
228- self .evict_tree_set .discard (child )
229-
230- split_parent_node = child .split_node (prefix_len )
231- ans_value_list .append (split_parent_node .token_mem_index_value )
232-
233- if update_refs :
234- split_parent_node .ref_counter += 1
235- # from 0 to 1 need update refs token num
236- if split_parent_node .ref_counter == 1 :
237- self .refed_tokens_num .arr [0 ] += len (split_parent_node .token_mem_index_value )
238-
239- if child .is_leaf ():
240- self .evict_tree_set .add (child )
241- if split_parent_node .is_leaf ():
242- self .evict_tree_set .add (split_parent_node )
243-
244- return split_parent_node
245- else :
246- assert False , "error state"
247- finally :
248- node .update_time ()
249- if node .is_leaf ():
250- self .evict_tree_set .add (node )
321+ assert False , "error state"
251322
252323 def evict (self , need_remove_tokens , evict_callback ):
253324 if self .tree_total_tokens_num .arr [0 ] - self .refed_tokens_num .arr [0 ] < need_remove_tokens :
@@ -417,3 +488,7 @@ def get_tree_total_tokens_num(self, dp_rank_in_node):
417488
418489 def get_unrefed_tokens_num (self , dp_rank_in_node ):
419490 return self .dp_rank_clients [dp_rank_in_node ].get_unrefed_tokens_num ()
491+
492+
493+ class _RecursionParams :
494+ pass
0 commit comments