|
5 | 5 | from typing import Tuple, Dict, Set, List, Optional, Union |
6 | 6 | from sortedcontainers import SortedSet |
7 | 7 | from .shared_arr import SharedArray |
| 8 | +from lightllm.utils.log_utils import init_logger, log_time_ready |
| 9 | + |
| 10 | +logger = init_logger(__name__) |
8 | 11 |
|
9 | 12 |
|
10 | 13 | class UniqueTimeIdGenerator: |
@@ -135,6 +138,34 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) |
135 | 138 | ) |
136 | 139 | self.tree_total_tokens_num.arr[0] = 0 |
137 | 140 |
|
| 141 | + self.total_query_tokens = SharedArray(f"{unique_name}_total_query_tokens_{rank_in_node}", (1,), dtype=np.int64) |
| 142 | + self.total_query_tokens.arr[0] = 0 |
| 143 | + self.total_hit_tokens = SharedArray(f"{unique_name}_total_hit_tokens_{rank_in_node}", (1,), dtype=np.int64) |
| 144 | + self.total_hit_tokens.arr[0] = 0 |
| 145 | + self.last_log_query_tokens = 0 |
| 146 | + self.last_log_hit_tokens = 0 |
| 147 | + |
def _inc_hit_rate(self, query_len, hit_len):
    """Record one prefix lookup and periodically log the cache hit rate.

    Adds ``query_len`` and ``hit_len`` to the running totals kept in
    ``self.total_query_tokens`` / ``self.total_hit_tokens``. When
    ``log_time_ready`` returns a truthy value, logs two ratios: one over
    the tokens seen since the previous log line (the "window") and one
    over all tokens seen so far (cumulative).

    Args:
        query_len: number of tokens in the looked-up key.
        hit_len: number of those tokens that were matched.
    """
    self.total_query_tokens.arr[0] += query_len
    self.total_hit_tokens.arr[0] += hit_len

    # NOTE(review): presumably `time_count=30` limits the log line to once
    # every 30 seconds per key -- confirm against `log_time_ready`.
    if not log_time_ready("radix_cache_hit_rate", time_count=30):
        return

    current_total_query = self.total_query_tokens.arr[0]
    current_total_hit = self.total_hit_tokens.arr[0]

    # Window figures are measured against the totals at the last log line.
    window_query = current_total_query - self.last_log_query_tokens
    window_hit = current_total_hit - self.last_log_hit_tokens

    # Default to 0.0 so an empty denominator never divides by zero.
    window_hit_rate = 0.0
    if window_query > 0:
        window_hit_rate = window_hit / window_query

    cumulative_hit_rate = 0.0
    if current_total_query > 0:
        cumulative_hit_rate = current_total_hit / current_total_query

    label = self.__class__.__name__
    logger.info(
        f"{label} Hit Rate: "
        f"Window {window_hit_rate:.2%} ({window_hit}/{window_query}), "
        f"Cumulative {cumulative_hit_rate:.2%} ({current_total_hit}/{current_total_query})"
    )

    # The current totals become the baseline for the next window.
    self.last_log_query_tokens, self.last_log_hit_tokens = (
        current_total_query,
        current_total_hit,
    )
| 168 | + |
138 | 169 | def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: |
139 | 170 | if value is None: |
140 | 171 | value = key |
@@ -248,9 +279,13 @@ def match_prefix(self, key, update_refs=False): |
248 | 279 | value = torch.concat(ans_value_list) |
249 | 280 | else: |
250 | 281 | value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) |
251 | | - return tree_node, len(value), value |
| 282 | + |
| 283 | + matched_len = len(value) |
| 284 | + self._inc_hit_rate(len(key), matched_len) |
| 285 | + return tree_node, matched_len, value |
252 | 286 | else: |
253 | 287 | self.dec_node_ref_counter(self.root_node) |
| 288 | + self._inc_hit_rate(len(key), 0) |
254 | 289 | return None, 0, None |
255 | 290 |
|
256 | 291 | def _match_prefix_helper( |
|
0 commit comments