Skip to content

Commit 747915e

Browse files
committed
[TRTLLM-8922][fix] Fix gen-first disagg scheduling with pipeline parallelism
Add PP consensus in prepare_context_requests so all PP ranks agree before promoting gen-first context requests, and call _check_disagg_ctx_schedulable_status in the PP executor loop so requests are not stuck in DISAGG_CONTEXT_WAIT_SCHEDULER state. Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com>
1 parent 6a7863d commit 747915e

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

tensorrt_llm/_torch/disaggregation/native/py_cache_transceiver.py

Lines changed: 15 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ def __init__(
100100
self.mapping = mapping
101101

102102
self.ctx_need_tp_sync = mapping.tp_size > 1 and (not mapping.enable_attention_dp)
103+
self.ctx_need_pp_sync = mapping.pp_size > 1
103104

104105
self.gen_need_sync = not (
105106
mapping.world_size == 1 or (mapping.enable_attention_dp and mapping.pp_size == 1)
@@ -381,24 +382,25 @@ def get_disaggregated_params(self) -> Dict[str, Any]:
381382

382383
def prepare_context_requests(self, requests: List[LlmRequest]):
383384
# Place new generation-first context requests into wait state, then
384-
# use tp_allgather consensus to promote ready requests to CONTEXT_INIT.
385+
# use allgather consensus to promote ready requests to CONTEXT_INIT.
385386
for req in requests:
386387
unique_rid = get_unique_rid(req)
387388
if unique_rid not in self.send_sessions:
388389
self.wait_req_id_to_request[unique_rid] = req
389390
req.state = LlmRequestState.DISAGG_CONTEXT_WAIT_SCHEDULER
390391

391392
# Check which waiting requests have peer info locally, then use
392-
# tp_allgather consensus so all TP ranks agree before promoting.
393+
# allgather consensus so all TP/PP ranks agree before promoting.
393394
# Without consensus, background peer info arriving at different
394395
# times on different ranks causes scheduling mismatches → hang.
395-
# Place tp sync here because this function runs in every iteration
396+
# Place sync here because this function runs in every iteration
396397
# but check_context_transfer_status runs when can_queue is True
397398
local_ready_request_ids = []
398399
for request_id in self.wait_req_id_to_request.keys():
399400
if self.transfer_worker.has_all_peer_req_infos_for_send(request_id):
400401
local_ready_request_ids.append(request_id)
401402

403+
# TP consensus: ensure all TP ranks have peer info
402404
if self.ctx_need_tp_sync:
403405
ready_request_ids_all_ranks = self.dist.tp_allgather(local_ready_request_ids)
404406
else:
@@ -407,6 +409,16 @@ def prepare_context_requests(self, requests: List[LlmRequest]):
407409
sync_size = self.dist.tp_size if self.ctx_need_tp_sync else 1
408410
ready_request_ids = _find_consensus_request_ids(ready_request_ids_all_ranks, sync_size)
409411

412+
# PP consensus: ensure all PP ranks have peer info before promoting.
413+
# In PP, the first PP rank schedules and propagates to others. If a
414+
# request is promoted on the first rank but peer info hasn't arrived
415+
# on other ranks, respond_and_send_async on those ranks would fail
416+
# to dispatch the KV transfer (gen-first skips listener dispatch).
417+
if self.ctx_need_pp_sync:
418+
ready_request_ids_pp = self.dist.pp_allgather(ready_request_ids)
419+
pp_sync_size = self.mapping.pp_size
420+
ready_request_ids = _find_consensus_request_ids(ready_request_ids_pp, pp_sync_size)
421+
410422
for request_id in ready_request_ids:
411423
self.wait_req_id_to_request[request_id].state = LlmRequestState.CONTEXT_INIT
412424
del self.wait_req_id_to_request[request_id]

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1255,6 +1255,7 @@ def _executor_loop_pp(self):
12551255
self._handle_control_request()
12561256

12571257
if self.kv_cache_transceiver:
1258+
self._check_disagg_ctx_schedulable_status(new_requests)
12581259
self._check_disagg_gen_transfer_status()
12591260

12601261
if self.enable_iter_perf_stats:

0 commit comments

Comments (0)