Skip to content

Commit ef7f19a

Browse files
committed
Fix disagg timeout
Signed-off-by: jthomson04 <[email protected]>
1 parent a60fa96 commit ef7f19a

File tree

1 file changed, with 39 additions and 6 deletions.

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 39 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -184,13 +184,20 @@ def end_transfer(self, request: LlmRequest):
184184
"""
185185
request, block_id, counter = self.requests.pop(request.py_request_id)
186186

187+
should_terminate = False
188+
187189
if counter == 1:
188190
if self.request_should_store_blocks.pop(request.py_request_id):
189191
self.kv_cache_manager.unpin_blocks_by_id(block_id)
192+
else:
193+
should_terminate = True
194+
190195
request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
191196
else:
192197
self.requests[request.py_request_id] = (request, block_id,
193198
counter - 1)
199+
200+
return should_terminate
194201

195202

196203
class PyExecutor:
@@ -1291,7 +1298,8 @@ def _kv_connector_terminate_requests(self):
12911298
if self.kv_connector_manager:
12921299
reqs_to_terminate = self.kv_connector_manager.get_finished()
12931300
for req in reqs_to_terminate:
1294-
self.async_transfer_manager.end_transfer(req)
1301+
if self.async_transfer_manager.end_transfer(req):
1302+
self._terminate_request(req)
12951303

12961304
def _kv_connector_wait_for_save(self):
12971305
if self.kv_connector_manager is not None:
@@ -2104,6 +2112,9 @@ def kv_connector_request_finished(req: LlmRequest):
21042112
self.async_transfer_manager.start_transfer(
21052113
req, should_store_blocks=disagg_should_store_blocks)
21062114

2115+
if self.kv_cache_transceiver.kv_transfer_timeout_ms is not None:
2116+
req.py_kv_transfer_start_time = time.time()
2117+
21072118
if self.kv_connector_manager:
21082119
if not self.disable_overlap_scheduler:
21092120
requests = self.previous_batch.sample_state.scheduled_requests.all_requests(
@@ -2133,7 +2144,32 @@ def _check_cache_transfer_errors(self, error_msg_prefix: str):
21332144

21342145
@nvtx_range("_check_disagg_ctx_cache_transfer_status")
21352146
def _check_disagg_ctx_cache_transfer_status(self, atLeastNum: int = 0):
2136-
self.kv_cache_transceiver.check_context_transfer_status(atLeastNum)
2147+
finished_requests, error_requests = self.kv_cache_transceiver.check_context_transfer_status(
2148+
atLeastNum)
2149+
2150+
finished_error_ids = set(finished_requests + error_requests)
2151+
2152+
for request_id in finished_requests + error_requests:
2153+
2154+
request = self.async_transfer_manager.requests_in_transfer().get(
2155+
request_id)
2156+
2157+
if self.async_transfer_manager.end_transfer(request):
2158+
self._terminate_request(request)
2159+
2160+
for request in list(
2161+
self.async_transfer_manager.requests_in_transfer().values()):
2162+
if request.py_kv_transfer_timed_out and request.py_request_id not in finished_error_ids:
2163+
is_cancelled = self.kv_cache_transceiver.cancel_request(request)
2164+
# If cancel is successful, mark as complete so it can be cleaned up
2165+
# Otherwise, try at next iteration
2166+
if is_cancelled:
2167+
request.py_kv_transfer_start_time = None
2168+
request.state = LlmRequestState.DISAGG_CONTEXT_COMPLETE
2169+
2170+
if self.async_transfer_manager.end_transfer(request):
2171+
self._terminate_request(request)
2172+
21372173
self._check_cache_transfer_errors("context requests")
21382174

21392175
@nvtx_range("_check_disagg_gen_cache_transfer_status")
@@ -2511,10 +2547,7 @@ def _handle_responses(self):
25112547
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa:
25122548
requests_to_terminate.append(request)
25132549
else:
2514-
if request.is_disagg_context_transmission_state:
2515-
self.async_transfer_manager.begin_transfer(
2516-
request, False)
2517-
else:
2550+
if not request.is_disagg_context_transmission_state:
25182551
requests_to_terminate.append(request)
25192552
else:
25202553
new_active_requests.append(request)

0 commit comments

Comments
 (0)