@@ -100,6 +100,7 @@ def __init__(
         self.mapping = mapping

         self.ctx_need_tp_sync = mapping.tp_size > 1 and (not mapping.enable_attention_dp)
+        self.ctx_need_pp_sync = mapping.pp_size > 1

         self.gen_need_sync = not (
             mapping.world_size == 1 or (mapping.enable_attention_dp and mapping.pp_size == 1)
@@ -381,24 +382,25 @@ def get_disaggregated_params(self) -> Dict[str, Any]:
381382
382383 def prepare_context_requests (self , requests : List [LlmRequest ]):
383384 # Place new generation-first context requests into wait state, then
384- # use tp_allgather consensus to promote ready requests to CONTEXT_INIT.
385+ # use allgather consensus to promote ready requests to CONTEXT_INIT.
385386 for req in requests :
386387 unique_rid = get_unique_rid (req )
387388 if unique_rid not in self .send_sessions :
388389 self .wait_req_id_to_request [unique_rid ] = req
389390 req .state = LlmRequestState .DISAGG_CONTEXT_WAIT_SCHEDULER
390391
391392 # Check which waiting requests have peer info locally, then use
392- # tp_allgather consensus so all TP ranks agree before promoting.
393+ # allgather consensus so all TP/PP ranks agree before promoting.
393394 # Without consensus, background peer info arriving at different
394395 # times on different ranks causes scheduling mismatches → hang.
395- # Place tp sync here because this function runs in every iteration
396+ # Place sync here because this function runs in every iteration
396397 # but check_context_transfer_status runs when can_queue is True
397398 local_ready_request_ids = []
398399 for request_id in self .wait_req_id_to_request .keys ():
399400 if self .transfer_worker .has_all_peer_req_infos_for_send (request_id ):
400401 local_ready_request_ids .append (request_id )
401402
403+ # TP consensus: ensure all TP ranks have peer info
402404 if self .ctx_need_tp_sync :
403405 ready_request_ids_all_ranks = self .dist .tp_allgather (local_ready_request_ids )
404406 else :
@@ -407,6 +409,16 @@ def prepare_context_requests(self, requests: List[LlmRequest]):
407409 sync_size = self .dist .tp_size if self .ctx_need_tp_sync else 1
408410 ready_request_ids = _find_consensus_request_ids (ready_request_ids_all_ranks , sync_size )
409411
412+ # PP consensus: ensure all PP ranks have peer info before promoting.
413+ # In PP, the first PP rank schedules and propagates to others. If a
414+ # request is promoted on the first rank but peer info hasn't arrived
415+ # on other ranks, respond_and_send_async on those ranks would fail
416+ # to dispatch the KV transfer (gen-first skips listener dispatch).
417+ if self .ctx_need_pp_sync :
418+ ready_request_ids_pp = self .dist .pp_allgather (ready_request_ids )
419+ pp_sync_size = self .mapping .pp_size
420+ ready_request_ids = _find_consensus_request_ids (ready_request_ids_pp , pp_sync_size )
421+
410422 for request_id in ready_request_ids :
411423 self .wait_req_id_to_request [request_id ].state = LlmRequestState .CONTEXT_INIT
412424 del self .wait_req_id_to_request [request_id ]
0 commit comments