@@ -377,6 +377,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
377377 self .block_window_per_layer : list [Optional [int ]] = []
378378 self ._tp_size = {self .engine_id : self .world_size }
379379
380+ # With heterogeneous TP, P must wait for all assigned D TP workers to
381+ # finish reading before safely freeing the blocks.
382+ self .consumer_notification_counts_by_req = defaultdict (int )
383+
380384 @staticmethod
381385 def _nixl_handshake_listener (metadata : NixlAgentMetadata ,
382386 ready_event : threading .Event , rank : int ):
@@ -718,10 +722,15 @@ def _get_new_notifs(self) -> set[str]:
718722 """Get req_ids which got a remote xfer message."""
719723
720724 notified_req_ids : set [str ] = set ()
721- for req_ids in self .nixl_wrapper .get_new_notifs ().values ():
722- for req_id in req_ids :
723- assert req_id not in notified_req_ids
724- notified_req_ids .add (req_id .decode ("utf-8" ))
725+ for notifs in self .nixl_wrapper .get_new_notifs ().values ():
726+ for notif in notifs :
727+ req_id , tp_ratio = notif .decode ("utf-8" ).rsplit (":" , 1 )
728+ self .consumer_notification_counts_by_req [req_id ] += 1
729+ # Wait all consumers (D) to be done reading before freeing.
730+ if self .consumer_notification_counts_by_req [
731+ req_id ] == tp_ratio :
732+ notified_req_ids .add (req_id )
733+ del self .consumer_notification_counts_by_req [req_id ]
725734 return notified_req_ids
726735
727736 def _pop_done_transfers (self , transfers : dict [str , list [int ]]) -> set [str ]:
@@ -796,12 +805,17 @@ def _read_blocks(
796805 # saturate IB with heterogeneous TP sizes. We should remove the staging
797806 # blocks until we are ready.
798807
808+ # Number of D TP workers that will read from dst P. Propagate tp_ratio
809+ # on notification so that dst worker can wait before freeing blocks.
810+ tp_ratio = self ._tp_size [
811+ self .engine_id ] // self ._tp_size [dst_engine_id ]
812+ notif_id = f"{ request_id } :{ tp_ratio } " .encode ()
813+
799814 # Full prefix cache hit: do not need to read remote blocks,
800815 # just notify P worker that we have the blocks we need.
801816 num_local_blocks = len (local_block_ids )
802817 if num_local_blocks == 0 :
803- self .nixl_wrapper .send_notif (dst_engine_id ,
804- notif_msg = request_id .encode ("utf-8" ))
818+ self .nixl_wrapper .send_notif (dst_engine_id , notif_msg = notif_id )
805819 return
806820
807821 # Partial prefix cache hit: just read uncomputed blocks.
@@ -818,10 +832,7 @@ def _read_blocks(
818832 # workers will issue xfers to parts of the P worker remote kv caches.
819833
820834 # Get side handles.
821- d_workers_per_p_worker = self ._tp_size [
822- self .engine_id ] // self ._tp_size [dst_engine_id ]
823- local_xfer_side_handle = self .src_xfer_side_handle [
824- d_workers_per_p_worker ]
835+ local_xfer_side_handle = self .src_xfer_side_handle [tp_ratio ]
825836 remote_xfer_side_handle = self .dst_xfer_side_handles [dst_engine_id ]
826837
827838 # Get descs ids.
@@ -867,7 +878,7 @@ def _read_blocks(
867878 local_block_descs_ids ,
868879 remote_xfer_side_handle ,
869880 remote_block_descs_ids ,
870- notif_msg = request_id . encode ( "utf-8" ) ,
881+ notif_msg = notif_id ,
871882 )
872883
873884 # Begin async xfer.
0 commit comments