@@ -184,13 +184,20 @@ def end_transfer(self, request: LlmRequest):
184184 """
185185 request , block_id , counter = self .requests .pop (request .py_request_id )
186186
187+ should_terminate = False
188+
187189 if counter == 1 :
188190 if self .request_should_store_blocks .pop (request .py_request_id ):
189191 self .kv_cache_manager .unpin_blocks_by_id (block_id )
192+ else :
193+ should_terminate = True
194+
190195 request .state = LlmRequestState .DISAGG_CONTEXT_COMPLETE
191196 else :
192197 self .requests [request .py_request_id ] = (request , block_id ,
193198 counter - 1 )
199+
200+ return should_terminate
194201
195202
196203class PyExecutor :
@@ -1291,7 +1298,8 @@ def _kv_connector_terminate_requests(self):
12911298 if self .kv_connector_manager :
12921299 reqs_to_terminate = self .kv_connector_manager .get_finished ()
12931300 for req in reqs_to_terminate :
1294- self .async_transfer_manager .end_transfer (req )
1301+ if self .async_transfer_manager .end_transfer (req ):
1302+ self ._terminate_request (req )
12951303
12961304 def _kv_connector_wait_for_save (self ):
12971305 if self .kv_connector_manager is not None :
@@ -2104,6 +2112,9 @@ def kv_connector_request_finished(req: LlmRequest):
21042112 self .async_transfer_manager .start_transfer (
21052113 req , should_store_blocks = disagg_should_store_blocks )
21062114
2115+ if self .kv_cache_transceiver .kv_transfer_timeout_ms is not None :
2116+ req .py_kv_transfer_start_time = time .time ()
2117+
21072118 if self .kv_connector_manager :
21082119 if not self .disable_overlap_scheduler :
21092120 requests = self .previous_batch .sample_state .scheduled_requests .all_requests (
@@ -2133,7 +2144,32 @@ def _check_cache_transfer_errors(self, error_msg_prefix: str):
21332144
21342145 @nvtx_range ("_check_disagg_ctx_cache_transfer_status" )
21352146 def _check_disagg_ctx_cache_transfer_status (self , atLeastNum : int = 0 ):
2136- self .kv_cache_transceiver .check_context_transfer_status (atLeastNum )
2147+ finished_requests , error_requests = self .kv_cache_transceiver .check_context_transfer_status (
2148+ atLeastNum )
2149+
2150+ finished_error_ids = set (finished_requests + error_requests )
2151+
2152+ for request_id in finished_requests + error_requests :
2153+
2154+ request = self .async_transfer_manager .requests_in_transfer ().get (
2155+ request_id )
2156+
2157+ if self .async_transfer_manager .end_transfer (request ):
2158+ self ._terminate_request (request )
2159+
2160+ for request in list (
2161+ self .async_transfer_manager .requests_in_transfer ().values ()):
2162+ if request .py_kv_transfer_timed_out and request .py_request_id not in finished_error_ids :
2163+ is_cancelled = self .kv_cache_transceiver .cancel_request (request )
2164+ # If cancel is successful, mark as complete so it can be cleaned up
2165+ # Otherwise, try at next iteration
2166+ if is_cancelled :
2167+ request .py_kv_transfer_start_time = None
2168+ request .state = LlmRequestState .DISAGG_CONTEXT_COMPLETE
2169+
2170+ if self .async_transfer_manager .end_transfer (request ):
2171+ self ._terminate_request (request )
2172+
21372173 self ._check_cache_transfer_errors ("context requests" )
21382174
21392175 @nvtx_range ("_check_disagg_gen_cache_transfer_status" )
@@ -2511,10 +2547,7 @@ def _handle_responses(self):
25112547 if self .block_reuse_enabled and not self .kv_cache_manager .is_vswa :
25122548 requests_to_terminate .append (request )
25132549 else :
2514- if request .is_disagg_context_transmission_state :
2515- self .async_transfer_manager .begin_transfer (
2516- request , False )
2517- else :
2550+ if not request .is_disagg_context_transmission_state :
25182551 requests_to_terminate .append (request )
25192552 else :
25202553 new_active_requests .append (request )
0 commit comments