@@ -62,16 +62,20 @@ def init_custom(self):
6262 self .page_scheduer = SafePageIndexScheduler (self .nixl_agent .num_pages )
6363
6464 self .nixl_meta_queue .put (
65- (self .nixl_agent .agent_metadata , self .nixl_agent .num_tokens , self .nixl_agent .num_pages ,
66- self .nixl_agent .local_mem_desc , self .nixl_agent .local_page_mem_desc )
65+ (
66+ self .nixl_agent .agent_metadata ,
67+ self .nixl_agent .num_tokens ,
68+ self .nixl_agent .num_pages ,
69+ self .nixl_agent .local_mem_desc ,
70+ self .nixl_agent .local_page_mem_desc ,
71+ )
6772 )
6873
def _start_async_loop(self, async_loop_func):
    """Run an async loop function to completion on a fresh event loop.

    Creates a new asyncio event loop, installs it as the current event
    loop, and blocks until the awaitable returned by ``async_loop_func()``
    finishes. The loop is always closed afterwards so its resources are
    released even when the awaitable raises.

    Args:
        async_loop_func: Callable invoked with no arguments; its return
            value is passed to ``loop.run_until_complete``.

    Raises:
        Any exception raised while running the awaitable propagates to
        the caller unchanged (after the loop has been closed).
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        loop.run_until_complete(async_loop_func())
    finally:
        # Without this the loop object is left open when the coroutine
        # returns or fails.
        loop.close()
7378
74-
7579 async def _handle_remote_prefill (self , req_status : RemotePrefillStatus ):
7680 group_req_id = req_status .group_req_id
7781 status = req_status .status
@@ -80,29 +84,36 @@ async def _handle_remote_prefill(self, req_status: RemotePrefillStatus):
8084
8185 ret = None
8286 if run_req := self .remote_prefilled_reqs .get (group_req_id , None ):
83- if req_status .transfer_type == RemoteTransferType .PAGE_TRANSFER and status == RemoteTransferStatusType .SUCCESS :
87+ if (
88+ req_status .transfer_type == RemoteTransferType .PAGE_TRANSFER
89+ and status == RemoteTransferStatusType .SUCCESS
90+ ):
8491 kv_start , kv_len = req_status .kv_start , req_status .kv_len
85- token_ids = g_infer_context .req_manager .req_to_token_indexs [run_req .req_idx ][kv_start : kv_start + kv_len ] # gpu tensor
86- self .model .mem_manager .kv_buffer [:, token_ids , :, :] = self .model .mem_manager .kv_move_buffer [req_status .page_id ][:kv_len ].transpose (0 , 1 )
92+ token_ids = g_infer_context .req_manager .req_to_token_indexs [run_req .req_idx ][
93+ kv_start : kv_start + kv_len
94+ ] # gpu tensor
95+ self .model .mem_manager .kv_buffer [:, token_ids , :, :] = self .model .mem_manager .kv_move_buffer [
96+ req_status .page_id
97+ ][:kv_len ].transpose (0 , 1 )
8798 ret = PageTransferAck (group_req_id = group_req_id , page_id = req_status .page_id )
8899
89100 if req_status .is_last or status != RemoteTransferStatusType .SUCCESS :
90- shm_req : PDNIXLChunkedPrefillReq = run_req .shm_req
91- shm_req .set_pd_req_rank_state (self .rank_in_dp , status .value )
92- self .remote_prefilled_reqs .pop (group_req_id )
93- self .request_to_first_token [group_req_id ] = (req_status .next_token_id , req_status .next_token_logprob )
101+ shm_req : PDNIXLChunkedPrefillReq = run_req .shm_req
102+ shm_req .set_pd_req_rank_state (self .rank_in_dp , status .value )
103+ self .remote_prefilled_reqs .pop (group_req_id )
104+ self .request_to_first_token [group_req_id ] = (req_status .next_token_id , req_status .next_token_logprob )
94105
95- if self .is_master_in_dp :
96- # return page ids
97- if group_req_id in self .request_to_page_ids :
98- self .page_scheduer .return_ (self .request_to_page_ids [group_req_id ])
99- del self .request_to_page_ids [group_req_id ]
100-
101- logger .info (
102- f"remote prefill reqeust: { group_req_id } done with status: { status } "
103- f"took: { time .time () - run_req .remote_prefill_start } seconds"
104- )
105- ret = None
106+ if self .is_master_in_dp :
107+ # return page ids
108+ if group_req_id in self .request_to_page_ids :
109+ self .page_scheduer .return_ (self .request_to_page_ids [group_req_id ])
110+ del self .request_to_page_ids [group_req_id ]
111+
112+ logger .info (
113+ f"remote prefill reqeust: { group_req_id } done with status: { status } "
114+ f"took: { time .time () - run_req .remote_prefill_start } seconds"
115+ )
116+ ret = None
106117
107118 else :
108119 if self .is_master_in_dp :
@@ -112,7 +123,7 @@ async def _handle_remote_prefill(self, req_status: RemotePrefillStatus):
112123
113124 async def _prefill_wait_loop_async (self ):
114125 while True :
115- # from local
126+ # from local
116127 try :
117128 req_status = self .from_remote_queue .get_nowait ()
118129 await self ._handle_remote_prefill (req_status )
@@ -141,7 +152,7 @@ async def _prefill_wait_loop_async(self):
141152
142153 await asyncio .sleep (PDNIXLBackendBase ._THREAD_WAIT_INTERVAL )
143154
144- def _handle_chunked_transfer (self , req : InferReq , next_token_id : int = None , next_token_logprob : float = None ):
155+ def _handle_chunked_transfer (self , req : InferReq , next_token_id : int = None , next_token_logprob : float = None ):
145156 if next_token_id :
146157 next_token_id = int (next_token_id )
147158 next_token_logprob = float (next_token_logprob )
@@ -164,7 +175,7 @@ def _handle_chunked_transfer(self, req: InferReq, next_token_id: int=None, next_
164175 free_page_ids = remote_request .data .page_ids .copy (),
165176 next_token_id = next_token_id ,
166177 next_token_logprob = next_token_logprob ,
167- lock = threading .Lock ()
178+ lock = threading .Lock (),
168179 )
169180 shm_req .set_pd_req_rank_state (self .rank_in_dp , RemoteTransferStatusType .IN_PROGRESS .value )
170181 req .in_prefill_or_transfer = True
@@ -179,7 +190,6 @@ def _handle_chunked_transfer(self, req: InferReq, next_token_id: int=None, next_
179190 transfer_state .next_token_id = next_token_id
180191 transfer_state .next_token_logprob = next_token_logprob
181192
182-
183193 async def _transfer_kv_to_remote_paged_batch (self , transfer_reqs : List [KVMoveRequest ]):
184194 start = time .time ()
185195 requests_by_agents = dict ()
@@ -198,26 +208,30 @@ async def _transfer_kv_to_remote_paged_batch(self, transfer_reqs: List[KVMoveReq
198208
199209 start_kv_len = transfer_state .transfered_kv_len
200210 trans_kv_len = min (trans_req .cur_kv_len - trans_req .prev_kv_len , self .nixl_agent .page_size )
201- trans_kv_index = transfer_state .token_index [start_kv_len : start_kv_len + trans_kv_len ]
202- self .model .mem_manager .kv_move_buffer [page_index ][:trans_kv_len ] = self .model .mem_manager .kv_buffer [:,trans_kv_index , :, : ].transpose (0 , 1 )
211+ trans_kv_index = transfer_state .token_index [start_kv_len : start_kv_len + trans_kv_len ]
212+ self .model .mem_manager .kv_move_buffer [page_index ][:trans_kv_len ] = self .model .mem_manager .kv_buffer [
213+ :, trans_kv_index , :, :
214+ ].transpose (0 , 1 )
203215
204216 receive_page = transfer_state .free_page_ids .pop (0 )
205217 requests_by_agents [decode_id ][0 ].append (page_index )
206218 requests_by_agents [decode_id ][1 ].append (receive_page )
207- is_last = (transfer_state .is_finished and start_kv_len + trans_kv_len == transfer_state .current_kv_len )
208-
209- requests_by_agents [decode_id ][2 ].append (RemotePrefillStatus (
210- transfer_type = RemoteTransferType .PAGE_TRANSFER ,
211- group_req_id = group_req_id ,
212- status = RemoteTransferStatusType .SUCCESS ,
213- chunk_id = transfer_state .current_chunk_id ,
214- is_last = is_last ,
215- page_id = receive_page ,
216- kv_start = start_kv_len ,
217- kv_len = trans_kv_len ,
218- next_token_id = transfer_state .next_token_id ,
219- next_token_logprob = transfer_state .next_token_logprob
220- ))
219+ is_last = transfer_state .is_finished and start_kv_len + trans_kv_len == transfer_state .current_kv_len
220+
221+ requests_by_agents [decode_id ][2 ].append (
222+ RemotePrefillStatus (
223+ transfer_type = RemoteTransferType .PAGE_TRANSFER ,
224+ group_req_id = group_req_id ,
225+ status = RemoteTransferStatusType .SUCCESS ,
226+ chunk_id = transfer_state .current_chunk_id ,
227+ is_last = is_last ,
228+ page_id = receive_page ,
229+ kv_start = start_kv_len ,
230+ kv_len = trans_kv_len ,
231+ next_token_id = transfer_state .next_token_id ,
232+ next_token_logprob = transfer_state .next_token_logprob ,
233+ )
234+ )
221235 transfer_state .transfered_kv_len += trans_kv_len
222236
223237 # wait copy done
@@ -227,11 +241,7 @@ async def _transfer_kv_to_remote_paged_batch(self, transfer_reqs: List[KVMoveReq
227241 # transfer
228242 self .nixl_agent .write_blocks_paged (decode_id , transfer_pages , receive_pages , notifications )
229243
230-
231- logger .info (
232- f"transfer kv to remote paged batch: { len (transfer_reqs )} "
233- f"took: { time .time () - start } seconds"
234- )
244+ logger .info (f"transfer kv to remote paged batch: { len (transfer_reqs )} " f"took: { time .time () - start } seconds" )
235245
236246 async def _handle_transfer_loop (self ):
237247 while True :
@@ -312,7 +322,6 @@ async def _wait_page_transfer_loop(self):
312322
313323 await asyncio .sleep (PDNIXLBackendBase ._THREAD_WAIT_INTERVAL )
314324
315-
316325 async def _wait_transfer_loop (self ):
317326 while True :
318327 done_req_ids = self .nixl_agent .get_done_tranfers ()
@@ -375,7 +384,7 @@ def _transfer_kv_to_remote(self, req: InferReq, group_req_id: int, cur_kv_len: i
375384
376385 kv_transfer_req = KVMoveRequest (
377386 group_req_id = group_req_id ,
378- token_ids = token_index [: cur_kv_len ].tolist (),
387+ token_ids = token_index [:cur_kv_len ].tolist (),
379388 prev_kv_len = transfer_state .current_kv_len ,
380389 cur_kv_len = cur_kv_len ,
381390 )
@@ -403,11 +412,11 @@ def _post_remote_prefill(self, req: InferReq, success: bool = True):
403412 if self .is_master_in_dp :
404413 req .shm_req .shm_cur_kv_len = req .cur_kv_len
405414
415+ group_req_id = req .shm_req .group_req_id
406416 if not success :
407417 self .request_to_first_token .pop (group_req_id , None )
408418 return
409419
410- group_req_id = req .shm_req .group_req_id
411420 assert group_req_id in self .request_to_first_token
412421 token_id , token_logprob = self .request_to_first_token .pop (group_req_id )
413422
@@ -520,14 +529,13 @@ def _prepare_remote_prefill_inputs(self, req_objs: List[InferReq]):
520529 g_infer_context .radix_cache .free_radix_cache_to_get_enough_token (input_ids .shape [0 ])
521530 mem_indexes = g_infer_context .req_manager .mem_manager .alloc (input_ids .shape [0 ])
522531
523-
524532 req_to_token_indexs = g_infer_context .req_manager .req_to_token_indexs
525533 for idx , req_idx in enumerate (nopad_b_req_idx ):
526534 cur_kv_len = req_objs [idx ].cur_kv_len
527535 seq_len = nopad_b_seq_len [idx ]
528536 mem_start = nopad_b_start_loc [idx ]
529- mem_end = nopad_b_start_loc [idx + 1 ]
530- req_to_token_indexs [req_idx , cur_kv_len : nopad_b_seq_len [idx ]] = mem_indexes [mem_start :mem_end ]
537+ mem_end = nopad_b_start_loc [idx + 1 ]
538+ req_to_token_indexs [req_idx , cur_kv_len : nopad_b_seq_len [idx ]] = mem_indexes [mem_start :mem_end ]
531539
532540 kwargs = {
533541 "batch_size" : len (run_reqs ),
@@ -547,4 +555,4 @@ def _prefill_abort_remote(self, req_objs: List[InferReq]):
547555 self .nixl_agent .send_abort_notify (self .remote_prefill_requests [group_req_id ].decode_id , group_req_id )
548556 del self .remote_prefill_requests [group_req_id ]
549557 if group_req_id in self .inflght_transfer_requests :
550- del self .inflght_transfer_requests [group_req_id ]
558+ del self .inflght_transfer_requests [group_req_id ]
0 commit comments