Commit a9995b3

fix lint.

Author: Weichao Luo (committed)
1 parent 1cb82dc · commit a9995b3

File tree: 12 files changed, +238 −138 lines


lightllm/server/router/model_infer/mode_backend/base_backend.py
Lines changed: 1 addition & 1 deletion

@@ -261,7 +261,7 @@ def _post_handle(
         is_chuncked_mode: bool,
         do_filter_finished_reqs: bool,
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
-        call_post_handle_for_chunk: bool = False ,
+        call_post_handle_for_chunk: bool = False,
     ) -> List[int]:
         """
         extra_post_req_handle_func provides an extra post-processing step once a request's output is determined; it is mainly used for
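For context, a minimal sketch of a hook matching the Callable[[InferReq, int, float], None] signature that _post_handle accepts. The hook body is an illustrative assumption, not code from this commit:

    # Hypothetical hook: called once a request's next output token is determined.
    def log_token_hook(req, next_token_id: int, next_token_logprob: float) -> None:
        # group_req_id comes from the request's shared-memory object, as seen
        # elsewhere in this commit; the logging itself is only an example.
        print(f"req {req.shm_req.group_req_id}: token={next_token_id}, logprob={next_token_logprob:.4f}")

    # A backend would pass it through, and can now also opt in per chunk:
    # self._post_handle(..., extra_post_req_handle_func=log_token_hook, call_post_handle_for_chunk=True)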

lightllm/server/router/model_infer/mode_backend/continues_batch/impl.py
Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ def normal_prefill_reqs(
         ok_finished_reqs: List[InferReq],
         mask_func: Optional[Callable[[List[InferReq], torch.Tensor], None]] = None,
         extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
-        call_post_handle_for_chunk: bool = False
+        call_post_handle_for_chunk: bool = False,
     ):
         model_input, run_reqs = prepare_prefill_inputs(
             prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal
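The mask_func parameter above receives the batch of running requests plus a logits tensor; only that signature comes from the diff. A hedged sketch of a compatible callback (EOS_TOKEN_ID and the output-length attribute are assumptions for illustration):

    import torch

    EOS_TOKEN_ID = 2  # assumed token id, model-dependent

    # Matches Callable[[List[InferReq], torch.Tensor], None]: mutates logits in place.
    def ban_early_eos(run_reqs, logits: torch.Tensor) -> None:
        for i, req in enumerate(run_reqs):
            # cur_output_len is a hypothetical attribute used for illustration.
            if getattr(req, "cur_output_len", 0) < 4:
                logits[i, EOS_TOKEN_ID] = float("-inf")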

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py
Lines changed: 30 additions & 10 deletions

@@ -52,9 +52,15 @@ def decode(self):
         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         return

-    def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs,
-                            extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
-                            call_post_handle_for_chunk: bool = False):
+    def normal_prefill_reqs(
+        self,
+        prefill_reqs: List[InferReq],
+        max_prefill_num: int,
+        uninit_reqs,
+        ok_finished_reqs,
+        extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+        call_post_handle_for_chunk: bool = False,
+    ):
         model_input, run_reqs, padded_req_num = padded_prepare_prefill_inputs(
             prefill_reqs, is_multimodal=self.is_multimodal
         )
@@ -67,9 +73,13 @@ def normal_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int
         next_token_ids = next_token_ids.detach().cpu().numpy()
         next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
         self._post_handle(
-            run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
+            run_reqs,
+            next_token_ids,
+            next_token_logprobs,
+            is_chuncked_mode=True,
+            do_filter_finished_reqs=False,
             extra_post_req_handle_func=extra_post_req_handle_func,
-            call_post_handle_for_chunk=call_post_handle_for_chunk
+            call_post_handle_for_chunk=call_post_handle_for_chunk,
         )
         return

@@ -121,9 +131,15 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini
         )
         return

-    def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: int, uninit_reqs, ok_finished_reqs,
-                             extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
-                             call_post_handle_for_chunk: bool = False):
+    def overlap_prefill_reqs(
+        self,
+        prefill_reqs: List[InferReq],
+        max_prefill_num: int,
+        uninit_reqs,
+        ok_finished_reqs,
+        extra_post_req_handle_func: Optional[Callable[[InferReq, int, float], None]] = None,
+        call_post_handle_for_chunk: bool = False,
+    ):
         (
             micro_input,
             run_reqs,
@@ -148,8 +164,12 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
         next_token_ids = next_token_ids.detach().cpu().numpy()
         next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
         self._post_handle(
-            all_run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=True, do_filter_finished_reqs=False,
+            all_run_reqs,
+            next_token_ids,
+            next_token_logprobs,
+            is_chuncked_mode=True,
+            do_filter_finished_reqs=False,
             extra_post_req_handle_func=extra_post_req_handle_func,
-            call_post_handle_for_chunk=call_post_handle_for_chunk
+            call_post_handle_for_chunk=call_post_handle_for_chunk,
         )
         return
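Both prefill paths in this file convert sampled probabilities into log-probabilities with torch.log(...).detach().cpu().numpy() before handing them to _post_handle. A self-contained sketch of that conversion (the probability values are made up):

    import torch

    # One sampled probability per running request (assumed values).
    next_token_probs = torch.tensor([0.91, 0.40, 0.07])
    next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
    print(next_token_logprobs)  # approximately [-0.0943, -0.9163, -2.6593]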

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_base.py
Lines changed: 61 additions & 53 deletions

@@ -62,16 +62,20 @@ def init_custom(self):
         self.page_scheduer = SafePageIndexScheduler(self.nixl_agent.num_pages)

         self.nixl_meta_queue.put(
-            (self.nixl_agent.agent_metadata, self.nixl_agent.num_tokens, self.nixl_agent.num_pages,
-             self.nixl_agent.local_mem_desc, self.nixl_agent.local_page_mem_desc)
+            (
+                self.nixl_agent.agent_metadata,
+                self.nixl_agent.num_tokens,
+                self.nixl_agent.num_pages,
+                self.nixl_agent.local_mem_desc,
+                self.nixl_agent.local_page_mem_desc,
+            )
         )

     def _start_async_loop(self, async_loop_func):
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         loop.run_until_complete(async_loop_func())

-
     async def _handle_remote_prefill(self, req_status: RemotePrefillStatus):
         group_req_id = req_status.group_req_id
         status = req_status.status
@@ -80,29 +84,36 @@ async def _handle_remote_prefill(self, req_status: RemotePrefillStatus):

         ret = None
         if run_req := self.remote_prefilled_reqs.get(group_req_id, None):
-            if req_status.transfer_type == RemoteTransferType.PAGE_TRANSFER and status == RemoteTransferStatusType.SUCCESS:
+            if (
+                req_status.transfer_type == RemoteTransferType.PAGE_TRANSFER
+                and status == RemoteTransferStatusType.SUCCESS
+            ):
                 kv_start, kv_len = req_status.kv_start, req_status.kv_len
-                token_ids = g_infer_context.req_manager.req_to_token_indexs[run_req.req_idx][kv_start: kv_start + kv_len] # gpu tensor
-                self.model.mem_manager.kv_buffer[:, token_ids, :, :] = self.model.mem_manager.kv_move_buffer[req_status.page_id][:kv_len].transpose(0, 1)
+                token_ids = g_infer_context.req_manager.req_to_token_indexs[run_req.req_idx][
+                    kv_start : kv_start + kv_len
+                ] # gpu tensor
+                self.model.mem_manager.kv_buffer[:, token_ids, :, :] = self.model.mem_manager.kv_move_buffer[
+                    req_status.page_id
+                ][:kv_len].transpose(0, 1)
                 ret = PageTransferAck(group_req_id=group_req_id, page_id=req_status.page_id)

             if req_status.is_last or status != RemoteTransferStatusType.SUCCESS:
-                shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req
-                shm_req.set_pd_req_rank_state(self.rank_in_dp, status.value)
-                self.remote_prefilled_reqs.pop(group_req_id)
-                self.request_to_first_token[group_req_id] = (req_status.next_token_id, req_status.next_token_logprob)
+                shm_req: PDNIXLChunkedPrefillReq = run_req.shm_req
+                shm_req.set_pd_req_rank_state(self.rank_in_dp, status.value)
+                self.remote_prefilled_reqs.pop(group_req_id)
+                self.request_to_first_token[group_req_id] = (req_status.next_token_id, req_status.next_token_logprob)

-                if self.is_master_in_dp:
-                    # return page ids
-                    if group_req_id in self.request_to_page_ids:
-                        self.page_scheduer.return_(self.request_to_page_ids[group_req_id])
-                        del self.request_to_page_ids[group_req_id]
-
-                    logger.info(
-                        f"remote prefill reqeust: {group_req_id} done with status: {status} "
-                        f"took: {time.time() - run_req.remote_prefill_start} seconds"
-                    )
-                ret = None
+                if self.is_master_in_dp:
+                    # return page ids
+                    if group_req_id in self.request_to_page_ids:
+                        self.page_scheduer.return_(self.request_to_page_ids[group_req_id])
+                        del self.request_to_page_ids[group_req_id]
+
+                    logger.info(
+                        f"remote prefill reqeust: {group_req_id} done with status: {status} "
+                        f"took: {time.time() - run_req.remote_prefill_start} seconds"
+                    )
+                ret = None

         else:
             if self.is_master_in_dp:
@@ -112,7 +123,7 @@ async def _handle_remote_prefill(self, req_status: RemotePrefillStatus):

     async def _prefill_wait_loop_async(self):
         while True:
-            # from local
+            # from local
             try:
                 req_status = self.from_remote_queue.get_nowait()
                 await self._handle_remote_prefill(req_status)
@@ -141,7 +152,7 @@ async def _prefill_wait_loop_async(self):

             await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL)

-    def _handle_chunked_transfer(self, req: InferReq, next_token_id: int=None, next_token_logprob: float=None):
+    def _handle_chunked_transfer(self, req: InferReq, next_token_id: int = None, next_token_logprob: float = None):
         if next_token_id:
             next_token_id = int(next_token_id)
             next_token_logprob = float(next_token_logprob)
@@ -164,7 +175,7 @@ def _handle_chunked_transfer(self, req: InferReq, next_token_id: int=None, next_
                 free_page_ids=remote_request.data.page_ids.copy(),
                 next_token_id=next_token_id,
                 next_token_logprob=next_token_logprob,
-                lock=threading.Lock()
+                lock=threading.Lock(),
             )
             shm_req.set_pd_req_rank_state(self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value)
             req.in_prefill_or_transfer = True
@@ -179,7 +190,6 @@ def _handle_chunked_transfer(self, req: InferReq, next_token_id: int=None, next_
             transfer_state.next_token_id = next_token_id
             transfer_state.next_token_logprob = next_token_logprob

-
     async def _transfer_kv_to_remote_paged_batch(self, transfer_reqs: List[KVMoveRequest]):
         start = time.time()
         requests_by_agents = dict()
@@ -198,26 +208,30 @@ async def _transfer_kv_to_remote_paged_batch(self, transfer_reqs: List[KVMoveReq

             start_kv_len = transfer_state.transfered_kv_len
             trans_kv_len = min(trans_req.cur_kv_len - trans_req.prev_kv_len, self.nixl_agent.page_size)
-            trans_kv_index = transfer_state.token_index[start_kv_len: start_kv_len + trans_kv_len]
-            self.model.mem_manager.kv_move_buffer[page_index][:trans_kv_len] = self.model.mem_manager.kv_buffer[:,trans_kv_index, :, : ].transpose(0, 1)
+            trans_kv_index = transfer_state.token_index[start_kv_len : start_kv_len + trans_kv_len]
+            self.model.mem_manager.kv_move_buffer[page_index][:trans_kv_len] = self.model.mem_manager.kv_buffer[
+                :, trans_kv_index, :, :
+            ].transpose(0, 1)

             receive_page = transfer_state.free_page_ids.pop(0)
             requests_by_agents[decode_id][0].append(page_index)
             requests_by_agents[decode_id][1].append(receive_page)
-            is_last = (transfer_state.is_finished and start_kv_len + trans_kv_len == transfer_state.current_kv_len)
-
-            requests_by_agents[decode_id][2].append(RemotePrefillStatus(
-                transfer_type=RemoteTransferType.PAGE_TRANSFER,
-                group_req_id=group_req_id,
-                status=RemoteTransferStatusType.SUCCESS,
-                chunk_id=transfer_state.current_chunk_id,
-                is_last=is_last,
-                page_id=receive_page,
-                kv_start=start_kv_len,
-                kv_len=trans_kv_len,
-                next_token_id=transfer_state.next_token_id,
-                next_token_logprob=transfer_state.next_token_logprob
-            ))
+            is_last = transfer_state.is_finished and start_kv_len + trans_kv_len == transfer_state.current_kv_len
+
+            requests_by_agents[decode_id][2].append(
+                RemotePrefillStatus(
+                    transfer_type=RemoteTransferType.PAGE_TRANSFER,
+                    group_req_id=group_req_id,
+                    status=RemoteTransferStatusType.SUCCESS,
+                    chunk_id=transfer_state.current_chunk_id,
+                    is_last=is_last,
+                    page_id=receive_page,
+                    kv_start=start_kv_len,
+                    kv_len=trans_kv_len,
+                    next_token_id=transfer_state.next_token_id,
+                    next_token_logprob=transfer_state.next_token_logprob,
+                )
+            )
             transfer_state.transfered_kv_len += trans_kv_len

         # wait copy done
@@ -227,11 +241,7 @@ async def _transfer_kv_to_remote_paged_batch(self, transfer_reqs: List[KVMoveReq
             # transfer
             self.nixl_agent.write_blocks_paged(decode_id, transfer_pages, receive_pages, notifications)

-
-        logger.info(
-            f"transfer kv to remote paged batch: {len(transfer_reqs)} "
-            f"took: {time.time() - start} seconds"
-        )
+        logger.info(f"transfer kv to remote paged batch: {len(transfer_reqs)} " f"took: {time.time() - start} seconds")

     async def _handle_transfer_loop(self):
         while True:
@@ -312,7 +322,6 @@ async def _wait_page_transfer_loop(self):

             await asyncio.sleep(PDNIXLBackendBase._THREAD_WAIT_INTERVAL)

-
     async def _wait_transfer_loop(self):
         while True:
             done_req_ids = self.nixl_agent.get_done_tranfers()
@@ -375,7 +384,7 @@ def _transfer_kv_to_remote(self, req: InferReq, group_req_id: int, cur_kv_len: i

         kv_transfer_req = KVMoveRequest(
             group_req_id=group_req_id,
-            token_ids=token_index[: cur_kv_len].tolist(),
+            token_ids=token_index[:cur_kv_len].tolist(),
             prev_kv_len=transfer_state.current_kv_len,
             cur_kv_len=cur_kv_len,
         )
@@ -403,11 +412,11 @@ def _post_remote_prefill(self, req: InferReq, success: bool = True):
         if self.is_master_in_dp:
             req.shm_req.shm_cur_kv_len = req.cur_kv_len

+        group_req_id = req.shm_req.group_req_id
         if not success:
             self.request_to_first_token.pop(group_req_id, None)
             return

-        group_req_id = req.shm_req.group_req_id
         assert group_req_id in self.request_to_first_token
         token_id, token_logprob = self.request_to_first_token.pop(group_req_id)

@@ -520,14 +529,13 @@ def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]):
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0])
         mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0])

-
         req_to_token_indexs = g_infer_context.req_manager.req_to_token_indexs
         for idx, req_idx in enumerate(nopad_b_req_idx):
             cur_kv_len = req_objs[idx].cur_kv_len
             seq_len = nopad_b_seq_len[idx]
             mem_start = nopad_b_start_loc[idx]
-            mem_end = nopad_b_start_loc[idx+1]
-            req_to_token_indexs[req_idx, cur_kv_len:nopad_b_seq_len[idx]] = mem_indexes[mem_start:mem_end]
+            mem_end = nopad_b_start_loc[idx + 1]
+            req_to_token_indexs[req_idx, cur_kv_len : nopad_b_seq_len[idx]] = mem_indexes[mem_start:mem_end]

         kwargs = {
             "batch_size": len(run_reqs),
@@ -547,4 +555,4 @@ def _prefill_abort_remote(self, req_objs: List[InferReq]):
                 self.nixl_agent.send_abort_notify(self.remote_prefill_requests[group_req_id].decode_id, group_req_id)
                 del self.remote_prefill_requests[group_req_id]
             if group_req_id in self.inflght_transfer_requests:
-                del self.inflght_transfer_requests[group_req_id]
+                del self.inflght_transfer_requests[group_req_id]
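The page-transfer hunks above stage KV cache slices through kv_move_buffer, swapping the layer and token axes with transpose(0, 1) in each direction. A toy sketch of that layout swap with made-up dimensions (LightLLM's real buffer shapes may differ):

    import torch

    layers, tokens, heads, dim, page_size = 2, 8, 4, 16, 4
    kv_buffer = torch.randn(layers, tokens, heads, dim)          # [layer, token, head, dim]
    kv_move_buffer = torch.zeros(page_size, layers, heads, dim)  # [token, layer, head, dim]

    token_ids = torch.tensor([3, 5, 6])  # token slots picked for one page
    kv_len = len(token_ids)
    # Prefill side: stage the selected tokens into the page buffer.
    kv_move_buffer[:kv_len] = kv_buffer[:, token_ids, :, :].transpose(0, 1)
    # Decode side: the inverse copy restores the [layer, token] layout.
    kv_buffer[:, token_ids, :, :] = kv_move_buffer[:kv_len].transpose(0, 1)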

lightllm/server/router/model_infer/mode_backend/pd_nixl/impl_for_pd_decode.py
Lines changed: 18 additions & 9 deletions

@@ -9,7 +9,12 @@
 from lightllm.utils.log_utils import init_logger
 from lightllm.server.multimodal_params import MultimodalParams

-from .pd_remote_prefill_obj import RemotePrefillTask, RemotePrefillServerInfo, RemotePrefillRequest, RemoteTransferStatusType
+from .pd_remote_prefill_obj import (
+    RemotePrefillTask,
+    RemotePrefillServerInfo,
+    RemotePrefillRequest,
+    RemoteTransferStatusType,
+)

 from .impl_for_pd_base import PDNIXLBackendBase

@@ -22,9 +27,9 @@ def __init__(self, prefill_task_queue: mp.Queue, prefill_done_queue: mp.Queue, n

     def init_custom(self):
         super().init_custom()
-        self.wait_prefill_thread = threading.Thread(target=self._start_async_loop,
-                                                    args=(self._prefill_wait_loop_async,),
-                                                    daemon=True)
+        self.wait_prefill_thread = threading.Thread(
+            target=self._start_async_loop, args=(self._prefill_wait_loop_async,), daemon=True
+        )
         self.wait_move_page_pool = ThreadPoolExecutor(max_workers=4)
         self.wait_prefill_thread.start()
         return
@@ -45,7 +50,7 @@ def _build_remote_prefill_task(self, index: int, kwargs: Dict, req: InferReq):
             multimodal_params=MultimodalParams.from_dict(req.multimodal_params),
             local_cached_len=req.cur_kv_len,
             token_ids=mem_indexes[b_start_loc[index] : b_start_loc[index + 1]],
-            page_ids=self.page_scheduer.borrow() # get page ids for this request, blocking when not enough pages
+            page_ids=self.page_scheduer.borrow(), # get page ids for this request, blocking when not enough pages
         )
         return RemotePrefillTask(server_info=prefill_node_info, prefill_request=prefill_request)

@@ -82,16 +87,20 @@ def decode(self):
                 if self.is_master_in_dp:
                     run_req.remote_prefill_start = time.time()
                 # since this function may blocking the calling thread, so we do it in a thread pool
-                self.wait_move_page_pool.submit(self._trigger_remote_prefill,
-                                                shm_req.group_req_id, idx, kwargs, run_req)
+                self.wait_move_page_pool.submit(
+                    self._trigger_remote_prefill, shm_req.group_req_id, idx, kwargs, run_req
+                )

-                shm_req.set_pd_req_rank_state(self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value) # set in progress state
+                shm_req.set_pd_req_rank_state(
+                    self.rank_in_dp, RemoteTransferStatusType.IN_PROGRESS.value
+                ) # set in progress state
                 run_req.in_prefill_or_transfer = True
                 self.remote_prefilled_reqs[shm_req.group_req_id] = run_req

         if decode_reqs:
             ContinuesBatchBackend.normal_decode(
-                self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs)
+                self, decode_reqs=decode_reqs, uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs
+            )

         self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True)
         return
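init_custom here hands _prefill_wait_loop_async to _start_async_loop (defined in impl_for_pd_base.py above) on a daemon thread, so the asyncio polling loop runs off the main thread. A minimal standalone sketch of the same pattern, with a stub coroutine standing in for the real wait loop:

    import asyncio
    import threading

    def start_async_loop(async_loop_func):
        # Same pattern as PDNIXLBackendBase._start_async_loop: give the
        # thread its own event loop and run the coroutine to completion.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        loop.run_until_complete(async_loop_func())

    async def wait_loop():  # stand-in for _prefill_wait_loop_async
        while True:
            await asyncio.sleep(0.01)  # poll queues and handle statuses here

    t = threading.Thread(target=start_async_loop, args=(wait_loop,), daemon=True)
    t.start()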
