Skip to content

Commit 5cb554c

Browse files
committed
format
1 parent 47f0248 commit 5cb554c

File tree

7 files changed

+101
-72
lines changed

7 files changed

+101
-72
lines changed

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_kv_move_manager.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
KV_MOVE_MAX_NUM = 16
3232

33+
3334
class DecodeKVMoveManager(rpyc.Service):
3435
def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
3536
super().__init__()
@@ -45,7 +46,7 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
4546
self.mem_queues = mem_queues
4647
self.infer_rpyc_lock = threading.Lock()
4748
self.infer_rpyc_objs: List[PDDecodeInferRpcServer] = []
48-
49+
4950
from .decode_trans_obj import KVTransConnectObj
5051

5152
self.connect_id_to_trans_obj: Dict[str, KVTransConnectObj] = {}
@@ -70,16 +71,16 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
7071

7172
# 在不使用p2p 复制kv 的方案时,需要全局的传输锁进行控制。这个时候kv传输的效率会下降。
7273
self.kv_trans_lock = threading.Lock()
73-
74+
7475
from .decode_trans_obj import KVTransProcess
75-
76+
7677
self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size
7778
for device_id in range(self.node_world_size):
7879
self.kv_trans_processes[device_id] = KVTransProcess()
7980
assert self.kv_trans_processes[device_id].init_all(device_id, self)
8081

8182
return
82-
83+
8384
# ==================================================================================
8485
# _dp_alloc_to_frozen_some_tokens
8586
# _put_kv_received_to_radix_cache
@@ -158,13 +159,13 @@ def _put_mem_manager_to_mem_queue(self) -> None:
158159
for obj in self.infer_rpyc_objs:
159160
obj.put_mem_manager_to_mem_queue()
160161
return
161-
162+
162163
# ==================================================================================
163164
# put_to_fail_release_task_queue 将因为一些原因失败,需要释放锁定的kv资源的请求放入到
164165
# 对应的处理队列中,handle_fail_release_task_loop 是一个循环的线程,专门处理这些失败的请求
165166
# 通过调用与推理进程交互的接口,释放掉申请锁定的 kv 资源。
166167
# ==================================================================================
167-
168+
168169
def put_to_fail_release_task_queue(self, task: Union[KVMoveTask, List[KVMoveTask]]):
169170
if isinstance(task, KVMoveTask):
170171
self.fail_to_release_queue.put(task)
@@ -182,9 +183,9 @@ def handle_fail_release_task_loop(self):
182183
else:
183184
self._fail_to_realese_forzen_tokens(handle_list)
184185
return
185-
186+
186187
# ==================================================================================
187-
# on_connect
188+
# on_connect
188189
# on_disconnect
189190
# exposed_check_alive
190191
# exposed_build_trans_process
@@ -278,12 +279,14 @@ def exposed_request_data_transfer(self, tasks: List[KVMoveTask]) -> List[Optiona
278279
self.remove_trans_obj(tasks[0].connect_id)
279280
logger.exception(str(e))
280281
raise e
281-
282+
282283
if alloc_tokened_tasks:
283-
trans_obj.ready_to_move_queue.put(alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue)
284+
trans_obj.ready_to_move_queue.put(
285+
alloc_tokened_tasks, error_handle_func=self.put_to_fail_release_task_queue
286+
)
284287

285288
return ans_list
286-
289+
287290
# ==================================================================================
288291
# 定时检测kv 传输成功,但是长时间没有pd master来触发推理的请求,
289292
# 释放这些超时请求占用的kv资源
@@ -308,24 +311,24 @@ def check_trans_process_loop(self):
308311
for device_id in range(self.node_world_size):
309312
if not self.kv_trans_processes[device_id].is_trans_process_health():
310313
raise Exception(f"device_id {device_id} kv process is unhealth")
311-
314+
312315
time.sleep(10.0)
313316
except (BaseException, RuntimeError) as e:
314317
logger.exception(str(e))
315-
318+
316319
for device_id in range(self.node_world_size):
317320
self.kv_trans_processes[device_id].killself()
318321

319322
# 杀掉当前进程的父进程(router), 触发全局崩溃
320323
os.kill(os.getppid(), signal.SIGKILL)
321324
os.kill(os.getpid(), signal.SIGKILL)
322325
raise e
323-
326+
324327
# ==================================================================================
325328
# 常用辅助功能函数
326329
# ==================================================================================
327330
def get_next_device_index(self):
328-
counts = [0 for _ in range(self.node_world_size)]
331+
counts = [0 for _ in range(self.node_world_size)]
329332
for obj in self.connect_id_to_trans_obj.values():
330333
counts[obj.device_index] += 1
331334
device_index = int(np.argmin(counts))

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_obj.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
KV_MOVE_MAX_NUM = 16
1818

19+
1920
@dataclass
2021
class KVTransConnectObj:
2122
connect_id: str = None
2223
prefill_node_id: int = None
23-
kv_trans_process: 'KVTransProcess' = None
24+
kv_trans_process: "KVTransProcess" = None
2425
pd_prefill_nccl_ip: str = None
2526
pd_prefill_nccl_port: int = None
2627
device_index: int = None
@@ -33,8 +34,13 @@ class KVTransConnectObj:
3334
timer_checker: TimeChecker = None
3435

3536
def create(
36-
self, connect_id: str, prefill_node_id: str, pd_prefill_nccl_ip: str, pd_prefill_nccl_port: int, manager: "DecodeKVMoveManager"
37-
):
37+
self,
38+
connect_id: str,
39+
prefill_node_id: str,
40+
pd_prefill_nccl_ip: str,
41+
pd_prefill_nccl_port: int,
42+
manager: "DecodeKVMoveManager",
43+
):
3844
self.connect_id = connect_id
3945
self.device_index = manager.get_next_device_index()
4046
self.kv_trans_process = manager.kv_trans_processes[self.device_index]
@@ -49,7 +55,7 @@ def create(
4955
pd_prefill_nccl_port=pd_prefill_nccl_port,
5056
decode_id=decode_node_id,
5157
decode_device_id=self.device_index,
52-
connect_id=self.connect_id
58+
connect_id=self.connect_id,
5359
)
5460
)
5561
assert self.kv_trans_process.task_out_queue.get(timeout=60) == "nccl_ok"
@@ -74,7 +80,7 @@ def create(
7480
self.put_to_radix_thread = threading.Thread(target=self.put_to_radix_loop, daemon=True)
7581
self.put_to_radix_thread.start()
7682
return
77-
83+
7884
# ==================================================================================
7985
# 处理接受所有进行 kv 传输的请求,完成后,将请求放入到 move_finished_queue 中
8086
# ==================================================================================
@@ -106,7 +112,7 @@ def kv_move_loop(self):
106112
logger.error(f"error get need 1, but get {len(move_tasks)}")
107113
assert False
108114

109-
move_tasks:List[KVMoveTask] = move_tasks[0]
115+
move_tasks: List[KVMoveTask] = move_tasks[0]
110116
for task in move_tasks:
111117
logger.info(f"{func_name} get task {task.to_decode_log_info()}")
112118

@@ -128,7 +134,7 @@ def kv_move_loop(self):
128134

129135
logger.error(f"{func_name} thread quit")
130136
return
131-
137+
132138
# ==================================================================================
133139
# 将传输完成的请求,放入到 radix cache 中进行管理。
134140
# ==================================================================================
@@ -168,11 +174,11 @@ def put_to_radix_loop(self):
168174

169175
logger.error(f"{func_name} thread quit, info: {self.to_log_info()}")
170176
return
171-
177+
172178
# ==================================================================================
173179
# 错误处理检测操作的一些通用函数
174180
# ==================================================================================
175-
181+
176182
def timer_to_check_status(self, raise_exception=True):
177183
if self.timer_checker.has_exceeded():
178184
try:
@@ -203,10 +209,10 @@ def set_has_error(self):
203209

204210
if self.ready_to_move_queue is not None:
205211
self.ready_to_move_queue.has_error = True
206-
212+
207213
if self.move_finished_queue is not None:
208214
self.move_finished_queue.has_error = True
209-
215+
210216
if self.manager is not None:
211217
self.manager.remove_trans_obj(self.connect_id)
212218
return
@@ -219,15 +225,19 @@ def __del__(self):
219225

220226
join_if_alive(self.kv_move_thread)
221227
join_if_alive(self.put_to_radix_thread)
222-
228+
223229
if self.connect_id is not None and self.kv_trans_process is not None:
224-
self.kv_trans_process.task_in_queue.put(PDTransLeaveInfo(decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id))
230+
self.kv_trans_process.task_in_queue.put(
231+
PDTransLeaveInfo(
232+
decode_id=self.decode_node_id, prefill_id=self.prefill_node_id, connect_id=self.connect_id
233+
)
234+
)
225235

226236
if self.ready_to_move_queue is not None:
227237
self.ready_to_move_queue.clear_tasks()
228238
if self.move_finished_queue is not None:
229239
self.move_finished_queue.clear_tasks()
230-
240+
231241
except BaseException as e:
232242
logger.exception(str(e))
233243

@@ -240,6 +250,7 @@ def to_log_info(self):
240250
log += f"device_index: {self.device_index} "
241251
return log
242252

253+
243254
@dataclass
244255
class KVTransProcess:
245256
process: mp.Process = None
@@ -249,7 +260,6 @@ class KVTransProcess:
249260
task_out_queue: mp.Queue = None
250261
device_id: int = None
251262

252-
253263
def init_all(self, device_id: int, manager: "DecodeKVMoveManager"):
254264
self.device_lock = threading.Lock()
255265
self.device_id = device_id
@@ -271,12 +281,12 @@ def init_all(self, device_id: int, manager: "DecodeKVMoveManager"):
271281
assert self.task_out_queue.get(timeout=60) == "get_mem_managers_ok"
272282

273283
return True
274-
284+
275285
except Exception as e:
276286
logger.warning(f"Failed start kv trans process for device {device_id}: {e}")
277287
logger.exception(str(e))
278288
return False
279-
289+
280290
def is_trans_process_health(self):
281291
try:
282292
process = psutil.Process(self.process.pid)
@@ -287,6 +297,6 @@ def is_trans_process_health(self):
287297
return True
288298
except:
289299
return False
290-
300+
291301
def killself(self):
292302
self.process.kill()

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_trans_process.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ def _init_env(args, device_id: int, task_in_queue: mp.Queue, task_out_queue: mp.
8383
while True:
8484
task: Union[KVMoveTaskGroup, PDTransJoinInfo, PDTransLeaveInfo] = task_in_queue.get()
8585
if isinstance(task, KVMoveTaskGroup):
86-
_handle_kvmove_task(task, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node)
86+
_handle_kvmove_task(
87+
task, task_out_queue, mem_managers, connect_id_to_comm, task.connect_id, dp_size_in_node
88+
)
8789
elif isinstance(task, PDTransJoinInfo):
8890
_handle_prefill_join(task, task_out_queue, connect_id_to_comm)
8991
elif isinstance(task, PDTransLeaveInfo):

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_kv_move_manager.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,14 @@ def __init__(self, args, info_queue: mp.Queue, mem_queues: List[mp.Queue]):
6666
self.release_tasks_thread.start()
6767

6868
from .prefill_trans_obj import KVTransProcess
69-
69+
7070
self.kv_trans_processes: List[KVTransProcess] = [None] * self.node_world_size
7171
for device_id in range(self.node_world_size):
7272
self.kv_trans_processes[device_id] = KVTransProcess()
7373
assert self.kv_trans_processes[device_id].init_all(device_id, self)
7474

7575
return
76-
76+
7777
# ==================================================================================
7878
# 主任务循环,接收需要进行kv传输的请求进行处理
7979
# ==================================================================================
@@ -95,7 +95,7 @@ def task_dispatcher_loop(self):
9595
except (BaseException, RuntimeError) as e:
9696
logger.exception(str(e))
9797
raise e
98-
98+
9999
# ==================================================================================
100100
# 请求出错或者完成kv传输后的处理队列和线程loop
101101
# ==================================================================================
@@ -117,7 +117,7 @@ def handle_release_task_loop(self):
117117
else:
118118
self._remove_req_refs_from_prompt_cache(handle_list)
119119
return
120-
120+
121121
# ==================================================================================
122122
# 定时检测传输进程的健康状态,出现问题拉崩整个系统触发重启
123123
# ==================================================================================
@@ -128,21 +128,21 @@ def check_trans_process_loop(self):
128128
for device_id in range(self.node_world_size):
129129
if not self.kv_trans_processes[device_id].is_trans_process_health():
130130
raise Exception(f"device_id {device_id} kv process is unhealth")
131-
131+
132132
time.sleep(10.0)
133133
except (BaseException, RuntimeError) as e:
134134
logger.exception(str(e))
135-
135+
136136
for device_id in range(self.node_world_size):
137137
self.kv_trans_processes[device_id].killself()
138138

139139
# 杀掉当前进程的父进程(router), 触发全局崩溃
140140
os.kill(os.getppid(), signal.SIGKILL)
141141
os.kill(os.getpid(), signal.SIGKILL)
142142
raise e
143-
143+
144144
# ==================================================================================
145-
# 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和
145+
# 与推理进程交互接口, _remove_req_refs_from_prompt_cache 和
146146
# _put_mem_manager_to_mem_queue 都是通过 rpyc 与推理进程进行交互的接口
147147
# ==================================================================================
148148

@@ -172,7 +172,7 @@ def _put_mem_manager_to_mem_queue(self):
172172
async def wait_all_future_finish(self, futures: List[AsyncResult]):
173173
await asyncio.gather(*[asyncio.to_thread(future.wait) for future in futures])
174174
return
175-
175+
176176
# ==================================================================================
177177
# 辅助功能接口
178178
# ==================================================================================
@@ -191,18 +191,18 @@ def remove_trans_obj(self, connect_id):
191191
trans_obj.set_has_error()
192192
logger.error(f"remove tran obj id {trans_obj.decode_node_id}")
193193
return
194-
194+
195195
def __get_trans_obj(self, task: KVMoveTask):
196196
self.__remove_dead_trans_obj()
197197
# 如果已经存在连接对象,直接返回
198198
for obj in self.connect_id_to_trans_obj.values():
199199
if obj.decode_node_id == task.decode_node.node_id:
200200
return obj
201-
201+
202202
# 如果不存在连接对象,创建新的连接对象
203203
gc.collect()
204204
from .prefill_trans_obj import KVTransConnectObj
205-
205+
206206
trans_obj = KVTransConnectObj()
207207
trans_obj.create(task.decode_node.node_id, task.decode_node.ip, task.decode_node.rpyc_port, self)
208208
self.connect_id_to_trans_obj[trans_obj.connect_id] = trans_obj
@@ -221,6 +221,7 @@ def __remove_dead_trans_obj(self):
221221
gc.collect()
222222
return
223223

224+
224225
def _init_env(args, info_queue: mp.Queue, mem_queues: List[mp.Queue], event: mp.Event):
225226
import lightllm.utils.rpyc_fix_utils as _
226227

0 commit comments

Comments
 (0)