Skip to content

Commit d42f5ee

Browse files
committed
feat: add radix prefix hit rate log
1 parent 256b1fe commit d42f5ee

File tree

2 files changed

+38
-17
lines changed

2 files changed

+38
-17
lines changed

lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
3333
super().__init__(unique_name, total_token_num, rank_in_node, mem_manager)
3434
# 用于缓存需要被驱逐的buffer节点, 应该包含所有有buffer的节点
3535
self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,))
36-
self.match_count = 0
37-
self.log_interval = 1000
38-
self.match_len = 0
39-
self.hit_len = 0
4036

4137
def free_radix_cache_to_get_enough_buffer(self, need_buffer_num):
4238
if need_buffer_num > self.mem_manager.get_buffer_can_use_size():
@@ -112,8 +108,6 @@ def insert_for_hybrid_radix_cache(self, reqs):
112108

113109
def match_prefix(self, key, update_refs=False):
114110
assert len(key) != 0
115-
self.match_count = (self.match_count + 1) % self.log_interval
116-
self.match_len += len(key)
117111
ans_value_list = []
118112
tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
119113
origin_ans_len = sum(len(v) for v in ans_value_list)
@@ -145,6 +139,7 @@ def match_prefix(self, key, update_refs=False):
145139
self.mem_manager.free(evict_token_value)
146140

147141
if tree_node == self.root_node:
142+
self._inc_hit_rate(len(key), 0)
148143
return None, origin_ans_len, None
149144

150145
update_node = tree_node
@@ -156,16 +151,7 @@ def match_prefix(self, key, update_refs=False):
156151
update_node = update_node.parent
157152

158153
value = torch.concat(ans_value_list)
159-
# logger.info("HybridRadixCache match_prefix hit tokens: {}".format(len(value)))
160-
self.hit_len += len(value)
161-
if self.match_count == 0:
162-
logger.info(
163-
f"HybridRadixCache match_prefix avg hit rate: {self.hit_len / self.match_len:.4f} "
164-
f"({self.hit_len}/{self.match_len}) over last {self.log_interval} matches"
165-
)
166-
self.match_len = 0
167-
self.hit_len = 0
168-
154+
self._inc_hit_rate(len(key), len(value))
169155
return tree_node, origin_ans_len, value
170156

171157
def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int):

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from typing import Tuple, Dict, Set, List, Optional, Union
66
from sortedcontainers import SortedSet
77
from .shared_arr import SharedArray
8+
from lightllm.utils.log_utils import init_logger, log_time_ready
9+
10+
logger = init_logger(__name__)
811

912

1013
class UniqueTimeIdGenerator:
@@ -135,6 +138,34 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None)
135138
)
136139
self.tree_total_tokens_num.arr[0] = 0
137140

141+
self.total_query_tokens = SharedArray(f"{unique_name}_total_query_tokens_{rank_in_node}", (1,), dtype=np.int64)
142+
self.total_query_tokens.arr[0] = 0
143+
self.total_hit_tokens = SharedArray(f"{unique_name}_total_hit_tokens_{rank_in_node}", (1,), dtype=np.int64)
144+
self.total_hit_tokens.arr[0] = 0
145+
self.last_log_query_tokens = 0
146+
self.last_log_hit_tokens = 0
147+
148+
def _inc_hit_rate(self, query_len, hit_len):
    """Accumulate prefix-match counters and occasionally log the hit rate.

    Adds ``query_len`` and ``hit_len`` to the running totals kept in
    ``self.total_query_tokens`` and ``self.total_hit_tokens``. Whenever
    ``log_time_ready`` reports that the "radix_cache_hit_rate" mark is due,
    one info line is emitted with two ratios:

    * window: hits / queries accumulated since the previous log line
      emitted through this instance;
    * cumulative: hits / queries over the whole lifetime of the counters.

    Args:
        query_len: number of tokens in the key that was looked up.
        hit_len: number of those tokens served from the cache (0 on a miss).
    """
    self.total_query_tokens.arr[0] += query_len
    self.total_hit_tokens.arr[0] += hit_len

    # Reporting is throttled by log_time_ready; nothing more to do until the
    # mark is due.
    # NOTE(review): time_count=30 is presumably an interval in seconds --
    # confirm against log_time_ready.
    if not log_time_ready("radix_cache_hit_rate", time_count=30):
        return

    # Snapshot the totals once so the ratios and the saved baselines agree.
    queried_total = self.total_query_tokens.arr[0]
    hit_total = self.total_hit_tokens.arr[0]

    # Window = growth of each counter since the last emitted log line.
    queried_delta = queried_total - self.last_log_query_tokens
    hit_delta = hit_total - self.last_log_hit_tokens

    # Guard both divisions against an empty denominator.
    if queried_delta > 0:
        recent_rate = hit_delta / queried_delta
    else:
        recent_rate = 0.0

    if queried_total > 0:
        overall_rate = hit_total / queried_total
    else:
        overall_rate = 0.0

    # Use the concrete class name so subclasses are identifiable in the log.
    cache_name = self.__class__.__name__
    logger.info(
        f"{cache_name} Hit Rate: "
        f"Window {recent_rate:.2%} ({hit_delta}/{queried_delta}), "
        f"Cumulative {overall_rate:.2%} ({hit_total}/{queried_total})"
    )

    # The current totals become the baseline for the next window.
    self.last_log_query_tokens = queried_total
    self.last_log_hit_tokens = hit_total
168+
138169
def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]:
139170
if value is None:
140171
value = key
@@ -248,9 +279,13 @@ def match_prefix(self, key, update_refs=False):
248279
value = torch.concat(ans_value_list)
249280
else:
250281
value = torch.zeros((0,), device="cpu", dtype=self._value_dtype)
251-
return tree_node, len(value), value
282+
283+
matched_len = len(value)
284+
self._inc_hit_rate(len(key), matched_len)
285+
return tree_node, matched_len, value
252286
else:
253287
self.dec_node_ref_counter(self.root_node)
288+
self._inc_hit_rate(len(key), 0)
254289
return None, 0, None
255290

256291
def _match_prefix_helper(

0 commit comments

Comments
 (0)