66
77from lightllm .utils .log_utils import init_logger
88
9- from .pd_remote_prefill_obj import RemoteAgent , KVMoveRequest , PrefillRequest , RemotePrefillStatus , ThreadSafeDict
9+ from .pd_remote_prefill_obj import (
10+ RemoteAgent , KVMoveRequest , PrefillRequest ,
11+ RemotePrefillStatus , ThreadSafeDict , KVMoveRequestState
12+ )
1013
1114
1215logger = init_logger (__name__ )
@@ -108,11 +111,20 @@ def _get_token_desc_ids(self, token_ids: List[int]):
108111 descs_ids .append (layer_id * self .num_tokens + token_id )
109112 return descs_ids
110113
111- def write_blocks (self , request : KVMoveRequest , prefill_request : PrefillRequest ):
114+ def write_blocks (self , request : KVMoveRequest , prefill_request : PrefillRequest , is_finished : bool ):
112115 group_reqeust_id = request .group_req_id
113116 skip_kv_move_len = prefill_request .data .local_cached_len
114- src_token_ids = request .token_ids [skip_kv_move_len :]
115- dst_token_ids = prefill_request .data .token_ids
117+
118+ # current kv len is less than remote cached kv len, just skip
119+ if request .cur_kv_len <= skip_kv_move_len :
120+ return
121+
122+ kv_move_start = max (skip_kv_move_len , request .prev_kv_len )
123+ kv_move_end = request .cur_kv_len
124+
125+ src_token_ids = request .token_ids [kv_move_start :]
126+ dst_token_ids = prefill_request .data .token_ids [kv_move_start - skip_kv_move_len : kv_move_end ]
127+
116128 remote_agent : RemoteAgent = self .remote_agents [prefill_request .decode_id ][
117129 self .tp_idx
118130 ] # TODO one-one mapping now
@@ -124,52 +136,85 @@ def write_blocks(self, request: KVMoveRequest, prefill_request: PrefillRequest):
124136
125137 src_handle = self .local_xfer_handles
126138 dst_handle = remote_agent .kv_xfer_handles
127- notify_status = RemotePrefillStatus (group_req_id = group_reqeust_id , status = 1 )
139+ notify_status = RemotePrefillStatus (
140+ group_req_id = group_reqeust_id ,
141+ status = 1 ,
142+ chunk_id = prefill_request .transfer_state .current_chunk_id ,
143+ is_last = is_finished )
144+
128145 handle = self .nixl_agent .make_prepped_xfer (
129146 "WRITE" , src_handle , src_token_descs , dst_handle , dst_token_descs , notify_status .serialize ()
130147 )
131148
132149 status = self .nixl_agent .transfer (handle )
133150 assert status != "ERR"
134151
135- self .inflight_transfers [group_reqeust_id ] = (handle , remote_agent , False )
152+ if group_reqeust_id not in self .inflight_transfers :
153+ self .inflight_transfers [group_reqeust_id ] = KVMoveRequestState (
154+ handles = [],
155+ done_handles = [],
156+ remote_agent = remote_agent ,
157+ abort = False
158+ )
159+ self .inflight_transfers [group_reqeust_id ].handles .append (handle )
136160
137161 return handle
138162
139163 return None
140164
141165 def send_abort_notify (self , remote_id : int , group_reqeust_id ):
142166 remote_agent : RemoteAgent = self .remote_agents [remote_id ][self .tp_idx ]
143- notify_status = RemotePrefillStatus (group_req_id = group_reqeust_id , status = - 1 )
167+ notify_status = RemotePrefillStatus (group_req_id = group_reqeust_id , status = - 1 , chunk_id = - 1 , is_last = True )
144168 self .nixl_agent .send_notif (remote_agent .name , notify_status .serialize ())
145169
146170 if group_reqeust_id in self .inflight_transfers :
147- self .inflight_transfers [group_reqeust_id ][ 2 ] = True
171+ self .inflight_transfers [group_reqeust_id ]. abort = True
148172
149173 def get_done_tranfers (self ):
150174 done_req_ids = []
151175
152- for req_id , (handle , remote_agent , is_abort ) in self .inflight_transfers .items ():
153- if is_abort :
176+ for req_id , kv_move_state in self .inflight_transfers .items ():
177+ kv_move_state : KVMoveRequestState
178+ if kv_move_state .abort :
154179 logger .warning (f"{ req_id } Transfer aborted" )
155180 done_req_ids .append ((req_id , - 1 ))
156181 continue
157182
158- remote_agent : RemoteAgent
159- xfer_state = self .nixl_agent .check_xfer_state (handle )
160- if xfer_state == "DONE" :
161- done_req_ids .append ((req_id , 1 ))
162- elif xfer_state == "PROC" :
163- continue
164- else :
165- logger .warning (f"{ req_id } Transfer failed with state { xfer_state } " )
183+ remote_agent : RemoteAgent = kv_move_state .remote_agent
184+
185+ left_handles = []
186+ failed = False
187+ for handle in kv_move_state .handles :
188+ if failed :
189+ left_handles .append (handle )
190+ continue
191+
192+ xfer_state = self .nixl_agent .check_xfer_state (handle )
193+
194+ if xfer_state == "DONE" :
195+ kv_move_state .done_handles .append (handle )
196+ elif xfer_state == "PROC" :
197+ left_handles .append (handle )
198+ else :
199+ logger .warning (f"{ req_id } Transfer failed with state { xfer_state } " )
200+ failed = True
201+ kv_move_state .done_handles .append (handle )
202+ notify_failed_status = RemotePrefillStatus (group_req_id = req_id , status = - 1 , chunk_id = - 1 , is_last = True )
203+ self .nixl_agent .send_notif (remote_agent .name , notify_failed_status .serialize ())
204+
205+ kv_move_state .handles = left_handles
206+
207+ if failed :
166208 done_req_ids .append ((req_id , - 1 ))
167- notify_failed_status = RemotePrefillStatus ( group_req_id = req_id , status = - 1 )
168- self . nixl_agent . send_notif ( remote_agent . name , notify_failed_status . serialize ( ))
209+ elif len ( left_handles ) == 0 :
210+ done_req_ids . append (( req_id , 1 ))
169211
170212 for req_id , _ in done_req_ids :
171- # release will abort inflight transfer
172- self .nixl_agent .release_xfer_handle (self .inflight_transfers [req_id ][0 ])
213+ kv_move_state : KVMoveRequestState = self .inflight_transfers [req_id ]
214+ for handle in kv_move_state .handles + kv_move_state .done_handles :
215+ # release will abort inflight transfer
216+ self .nixl_agent .release_xfer_handle (handle )
217+
173218 del self .inflight_transfers [req_id ]
174219
175220 return done_req_ids
0 commit comments