2626from lightllm .utils .dist_utils import get_dp_world_size , get_global_dp_rank , get_current_rank_in_dp
2727from lightllm .utils .dist_utils import get_current_device_id , get_current_rank_in_node , get_node_world_size
2828from lightllm .utils .dist_utils import get_dp_rank_in_node , create_new_group_for_current_node
29- from lightllm .utils .envs_utils import get_env_start_args
29+ from lightllm .utils .envs_utils import (
30+ get_env_start_args ,
31+ enable_radix_tree_timer_merge ,
32+ get_radix_tree_merge_update_delta ,
33+ )
3034from lightllm .distributed import dist_group_manager
3135from lightllm .server .core .objs .shm_objs_io_buffer import ShmObjsIOBuffer
3236from lightllm .server .router .model_infer .mode_backend .overlap_events import OverlapEventManager , OverlapEventPack
@@ -61,6 +65,11 @@ def __init__(self) -> None:
6165
6266 # nixl pd mode callback func
6367 self .nixl_prefill_chuncked_handle_func : Optional [Callable [[InferReq , int , float , int ], None ]] = None
68+
69+ # counter
70+ self ._radix_tree_merge_counter : int = 0
71+ self ._enable_radix_tree_timer_merge : bool = enable_radix_tree_timer_merge ()
72+ self ._radix_tree_merge_update_delta : int = get_radix_tree_merge_update_delta ()
6473 pass
6574
6675 def init_model (self , kvargs ):
@@ -439,6 +448,20 @@ def _filter_not_ready_reqs(self, req_ids: List[int]) -> List[InferReq]:
439448 """
440449 return [g_infer_context .requests_mapping [request_id ] for request_id in req_ids ]
441450
451+ def _timer_merge_radix_tree (self ):
452+ self ._radix_tree_merge_counter += 1
453+ if (
454+ self ._enable_radix_tree_timer_merge
455+ and (self ._radix_tree_merge_counter % self ._radix_tree_merge_update_delta == 0 )
456+ and self .radix_cache is not None
457+ ):
458+ g_infer_state_lock .acquire ()
459+ start = time .time ()
460+ self .radix_cache .merge_unreferenced_nodes ()
461+ self .logger .info (f"radix tree merge_unreferenced_nodes cost time { time .time () - start } s" )
462+ g_infer_state_lock .release ()
463+ return
464+
442465 # 一些可以复用的通用功能函数
443466 def _get_classed_reqs (
444467 self ,
@@ -465,6 +488,9 @@ def _get_classed_reqs(
465488 4. prefill_reqs 需要进行prefill操作的请求
466489 5. decode_reqs 需要进行decode操作的请求
467490 """
491+ # 定期对 radix cache 进行 merge,防止查询插入的操作效率下降
492+ self ._timer_merge_radix_tree ()
493+
468494 if self .args .enable_cpu_cache and len (g_infer_context .infer_req_ids ) > 0 :
469495 self .multi_level_cache_module .update_cpu_cache_task_states ()
470496
0 commit comments