@@ -28,7 +28,7 @@ class HybridRadixCache(RadixCache):
2828 def __init__ (self , unique_name , total_token_num , rank_in_node , mem_manager = None ):
2929 self .mem_manager : HybridMemManager = mem_manager
3030 super ().__init__ (unique_name , total_token_num , rank_in_node , mem_manager )
31- self .evict_buffer_set : Set [TreeNode ] = SortedSet (key = lambda x : x .time_id )
31+ self .evict_buffer_set : Set [TreeNode ] = SortedSet (key = lambda x : ( x .time_id ,) )
3232
3333 def free_radix_cache_to_get_enough_buffer (self , need_buffer_num ):
3434 if need_buffer_num > self .mem_manager .get_buffer_can_use_size ():
@@ -85,7 +85,8 @@ def insert_for_hybrid_radix_cache(self, reqs):
8585 new_shared_kv_node .buffer_idx = new_buffer_indexes [i ]
8686 self .dec_node_ref_counter (req .shared_kv_node )
8787 self .add_node_ref_counter (new_shared_kv_node )
88- self .evict_buffer_set .add (req .shared_kv_node )
88+ if req .shared_kv_node is not None and req .shared_kv_node .buffer_idx is not None :
89+ self .update_buffer_evict_set (req .shared_kv_node )
8990 req .shared_kv_node = new_shared_kv_node
9091
9192 def match_prefix (self , key , update_refs = False ):
@@ -105,6 +106,7 @@ def match_prefix(self, key, update_refs=False):
105106 return None , 0 , None
106107
107108 value = torch .concat (ans_value_list )
109+ self .update_buffer_evict_set (tree_node )
108110 return tree_node , len (value ), value
109111
110112 def _remove_leaf_node (self , node : TreeNode ):
@@ -116,10 +118,26 @@ def _remove_leaf_node(self, node: TreeNode):
116118 if parent_node .is_leaf ():
117119 self .evict_tree_set .add (parent_node )
118120 if parent_node .buffer_idx is not None :
119- self .evict_buffer_set . add (parent_node )
121+ self .update_buffer_evict_set (parent_node )
120122 return parent_node
121123
122124 def insert (self , key , value = None ) -> Tuple [int , Optional [TreeNode ]]:
123125 prefix_len , node = super ().insert (key , value )
126+ if node is not None :
127+ node .update_buffer_time ()
124128 self .evict_buffer_set .add (node )
125129 return prefix_len , node
130+
131+ def update_buffer_evict_set (self , node : TreeNode ):
132+ if node is None or node .buffer_idx is None :
133+ return
134+
135+ if node not in self .evict_buffer_set :
136+ self .evict_buffer_set .add (node )
137+ return
138+
139+ self .evict_buffer_set .discard (node )
140+ node .update_buffer_time ()
141+ self .evict_buffer_set .add (node )
142+
143+ self .update_buffer_evict_set (node .parent )
0 commit comments