import torch
import torch.distributed as dist
import rpyc
import time
from typing import Dict, List, Tuple, Optional, Union
from rpyc.utils.classic import obtain
from .decode_impl import DecodeNode
from lightllm.common.basemodel.infer_lock import acquire_lock_until_ready, release_acquired_lock, g_router_lock
from .decode_task_cache import g_kv_move_task_cache, g_success_kv_move_task_cache
from lightllm.server.pd_io_struct import KVMoveTask
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class PDDecodeInferRpcServer(rpyc.Service):
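    """
    RPyC service run by a decode-side inference worker in the prefill/decode (PD)
    disaggregation setup. Its exposed_* methods are meant to be called remotely
    (e.g. by the decode node's KV-move coordination process) to reserve (freeze)
    token budget and KV cache slots for incoming KVMoveTasks, commit received KV
    into the radix cache, roll back failed transfers, and release reservations
    that have timed out.
    """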
    def __init__(self, backend: DecodeNode) -> None:
        super().__init__()
        self.backend = backend
        self.device_id = self.backend.current_device_id
        self.dp_rank_in_node = self.backend.dp_rank_in_node
        self.is_master_in_dp = self.backend.is_master_in_dp
        return

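    # Pin all subsequent CUDA work triggered by this RPC connection to the GPU
    # assigned to this worker.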
    def on_connect(self, conn):
        torch.cuda.set_device(f"cuda:{self.device_id}")
        return

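    # Check whether key_len + max_new_token additional tokens fit under the
    # backend's total token capacity; if so, freeze them in the shared token load
    # accounting. When dp_world_size > 1, only the master rank evaluates the check
    # and broadcasts the verdict to the other ranks.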
    def judge_token_is_ok(self, key_len, max_new_token):
        # In multi-dp single-GPU mode each dp handles its own requests, so no
        # cross-rank synchronization is needed.
        if self.backend.dp_world_size == 1:
            with g_router_lock.obj:
                shared_token_load = self.backend.shared_token_load
                peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node)
                peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node)
                peak_num += key_len + max_new_token

                if peak_num < self.backend.get_max_total_token_num():
                    object_list = [True]
                    shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node)
                else:
                    object_list = [False]
            return object_list[0]

        # In normal single-dp mode, only the master rank does the bookkeeping and
        # broadcasts the result to the other ranks.
        if self.is_master_in_dp:
            with g_router_lock.obj:
                shared_token_load = self.backend.shared_token_load
                peak_num = shared_token_load.get_estimated_peak_token_count(self.dp_rank_in_node)
                peak_num += shared_token_load.get_frozened_token_count(self.dp_rank_in_node)
                peak_num += key_len + max_new_token

                if peak_num < self.backend.get_max_total_token_num():
                    object_list = [True]
                    shared_token_load.add_frozened_token_count(key_len + max_new_token, self.dp_rank_in_node)
                else:
                    object_list = [False]
            dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group)
        else:
            object_list = [None]
            dist.broadcast_object_list(object_list, src=0, group=self.backend.lock_nccl_group)
        return object_list[0]

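    # Give back a previously frozen token budget (key_len + max_new_token), e.g.
    # when allocation fails or a transfer is rolled back.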
    def recover_frozen_token(self, key_len, max_new_token):
        if self.is_master_in_dp:
            with g_router_lock.obj:
                shared_token_load = self.backend.shared_token_load
                shared_token_load.add_frozened_token_count(-(key_len + max_new_token), self.dp_rank_in_node)
        return

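    # Handle a single KVMoveTask: freeze its token budget, reuse whatever prefix
    # is already in the radix cache, and allocate KV cache slots for the rest.
    # Returns the allocated token indexes, or None if the budget check or the
    # allocation fails (anything frozen so far is released again in that case).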
    def _alloc_to_frozen_some_tokens(self, move_task: KVMoveTask):
        is_ok = self.judge_token_is_ok(len(move_task.input_tokens), move_task.decode_node.max_new_tokens)
        if not is_ok:
            if self.is_master_in_dp:
                logger.info(f"req_id: {move_task.to_decode_log_info()} alloc token failed")
                shared_token_load = self.backend.shared_token_load
                dp_rank = self.dp_rank_in_node
                frozen_token_num = shared_token_load.get_frozened_token_count(dp_rank)
                estimated_peak_token_num = shared_token_load.get_estimated_peak_token_count(dp_rank)
                logger.debug(
                    f"radix refed token num {self.backend.radix_cache.get_refed_tokens_num()}\n"
                    f"radix hold token num {self.backend.radix_cache.get_tree_total_tokens_num()}\n"
                    f"mem manager can alloc token num {self.backend.model.mem_manager.can_use_mem_size}\n"
                    f"mem manager total size {self.backend.model.mem_manager.size}\n"
                    f"frozened token num {frozen_token_num}\n"
                    f"estimated peak token num {estimated_peak_token_num}\n"
                )
            return None

        key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu")
        tree_node, kv_len, fused_token_indexes = self.backend.radix_cache.match_prefix(key, update_refs=True)
        # If nothing matched, the matched length is 0 and fused_token_indexes is None,
        # so normalize it to a plain Python list.
        fused_token_indexes = [] if fused_token_indexes is None else fused_token_indexes.tolist()
        need_len = len(move_task.input_tokens) - kv_len
        if need_len == 0:
            alloc_token_indexes = []
        else:
            self.backend.radix_cache.free_radix_cache_to_get_enough_token(need_len)
            alloc_token_indexes = self.backend.model.mem_manager.alloc(need_len)
            if alloc_token_indexes is not None:
                alloc_token_indexes = alloc_token_indexes.tolist()

        if alloc_token_indexes is None:
            self.backend.radix_cache.dec_node_ref_counter(tree_node)
            self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens)
            return None

        move_task.decode_token_indexes = alloc_token_indexes
        move_task.move_kv_len = need_len

        g_kv_move_task_cache[move_task.group_request_id] = (move_task, tree_node, fused_token_indexes)
        return move_task.decode_token_indexes

    # Returning None means the service is busy and cannot admit new requests.
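    # Note: arguments arrive as rpyc netrefs; obtain() copies them into local
    # objects before they are used.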
    def exposed_alloc_to_frozen_some_tokens(self, move_tasks: List[KVMoveTask]) -> List[Optional[List[int]]]:
        move_tasks = obtain(move_tasks)
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        try:
            ans_list = []
            for move_task in move_tasks:
                ans_list.append(self._alloc_to_frozen_some_tokens(move_task))
            return ans_list
        except BaseException as e:
            logger.exception(str(e))
            return None
        finally:
            release_acquired_lock()

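    # Once the KV for a request has been fully received, insert it into the radix
    # cache, free the slots that overlap an already-cached prefix, and re-match the
    # key with update_refs=True so the entry stays pinned until the request is
    # scheduled (tracked in g_success_kv_move_task_cache with a timestamp).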
    def _put_kv_received_to_radix_cache(self, group_req_id: int):
        move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id)
        radix_cache = self.backend.radix_cache
        key = torch.tensor(move_task.input_tokens, dtype=torch.int64, device="cpu")
        value = torch.tensor(fused_token_indexes + move_task.decode_token_indexes, dtype=torch.int64, device="cpu")
        prefix_len = radix_cache.insert(key, value)
        assert len(fused_token_indexes) <= prefix_len
        self.backend.model.mem_manager.free(value[len(fused_token_indexes) : prefix_len])
        self.backend.radix_cache.dec_node_ref_counter(tree_node)

        # Re-acquire a reference on the key to pin it in the radix cache so it cannot
        # be evicted in extreme cases; the decode side later corrects this by
        # decrementing the ref counter twice.
        tree_node, kv_len, _ = self.backend.radix_cache.match_prefix(key, update_refs=True)
        assert len(key) == kv_len
        g_success_kv_move_task_cache[group_req_id] = (move_task, tree_node, time.time())
        return

    def exposed_put_kv_received_to_radix_cache(self, group_req_ids: List[int]):
        group_req_ids = obtain(group_req_ids)
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        for group_req_id in group_req_ids:
            self._put_kv_received_to_radix_cache(group_req_id)
        release_acquired_lock()
        return

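    # Roll back a failed KV transfer: free the KV cache slots that were allocated
    # for it, drop the radix cache reference taken at allocation time, and return
    # the frozen token budget.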
    def _fail_to_realese_forzen_tokens(self, group_req_id: int):
        move_task, tree_node, fused_token_indexes = g_kv_move_task_cache.pop(group_req_id)
        value = torch.tensor(move_task.decode_token_indexes, dtype=torch.int64, device="cpu")
        self.backend.model.mem_manager.free(value)
        self.backend.radix_cache.dec_node_ref_counter(tree_node)
        self.recover_frozen_token(len(move_task.input_tokens), move_task.decode_node.max_new_tokens)
        return

    def exposed_fail_to_realese_forzen_tokens(self, group_req_ids: List[int]):
        group_req_ids = obtain(group_req_ids)
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        for group_req_id in group_req_ids:
            self._fail_to_realese_forzen_tokens(group_req_id)
        release_acquired_lock()
        return

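    # Hand the model's mem manager to the backend's mem_queue so that another
    # process (e.g. the KV-move worker) can obtain a handle to the KV cache memory.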
    def exposed_put_mem_manager_to_mem_queue(self):
        self.backend.mem_queue.put(self.backend.model.mem_manager)
        logger.info("put mem manager to mem_queue ok")
        return

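    # Release reservations whose KV arrived but which were never scheduled within
    # the timeout. When dp_world_size > 1, the master rank decides which requests
    # timed out and broadcasts the list so every rank drops the same radix cache
    # references; only the master adjusts the shared token load.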
    def exposed_unfrozen_time_out_reqs_tokens(self):
        acquire_lock_until_ready(self.backend.lock_nccl_group)
        if self.backend.dp_world_size == 1:
            need_release_reqs = self._get_time_out_reqs()
            logger.info(f"kv time out reqs: {need_release_reqs}")
            remove_tokens = self._remove_time_out_reqs(need_release_reqs)
            if remove_tokens != 0:
                with g_router_lock.obj:
                    self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node)
        else:
            if self.is_master_in_dp:
                need_release_reqs = self._get_time_out_reqs()
                logger.info(f"kv time out reqs: {need_release_reqs}")
                dist.broadcast_object_list([need_release_reqs], src=0, group=self.backend.lock_nccl_group)
            else:
                receive_objs = [None]
                dist.broadcast_object_list(receive_objs, src=0, group=self.backend.lock_nccl_group)
                need_release_reqs = receive_objs[0]
            remove_tokens = self._remove_time_out_reqs(need_release_reqs)
            if self.is_master_in_dp and remove_tokens != 0:
                with g_router_lock.obj:
                    self.backend.shared_token_load.add_frozened_token_count(-remove_tokens, self.dp_rank_in_node)

        release_acquired_lock()
        return

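    # Collect the group request ids whose received KV has been waiting to be
    # scheduled for longer than the timeout.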
    def _get_time_out_reqs(self):
        need_release_reqs = []
        for req_id, (_, _, time_mark) in g_success_kv_move_task_cache.items():
            # If the request has not been scheduled within 6s, its reservation is
            # proactively released and the tokens it pinned are freed.
            if time.time() - time_mark > 6:
                need_release_reqs.append(req_id)
        return need_release_reqs

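    # Drop the timed-out entries from the success cache, release their radix cache
    # references, and report how many frozen tokens should be given back.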
    def _remove_time_out_reqs(self, need_release_reqs: List[int]) -> int:
        remove_tokens = 0
        for req_id in need_release_reqs:
            task, tree_node, _ = g_success_kv_move_task_cache.pop(req_id)
            self.backend.radix_cache.dec_node_ref_counter(tree_node)
            remove_tokens += len(task.input_tokens) + task.decode_node.max_new_tokens
        return remove_tokens