Skip to content

Commit 9ed29aa

Browse files
committed
address race condition
Signed-off-by: nicklucche <[email protected]>
1 parent 20ca3aa commit 9ed29aa

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
377377
self.block_window_per_layer: list[Optional[int]] = []
378378
self._tp_size = {self.engine_id: self.world_size}
379379

380+
# With heterogeneous TP, P must wait for all assigned D TP workers to
381+
# finish reading before safely freeing the blocks.
382+
self.consumer_notification_counts_by_req = defaultdict(int)
383+
380384
@staticmethod
381385
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
382386
ready_event: threading.Event, rank: int):
@@ -718,10 +722,15 @@ def _get_new_notifs(self) -> set[str]:
718722
"""Get req_ids which got a remote xfer message."""
719723

720724
notified_req_ids: set[str] = set()
721-
for req_ids in self.nixl_wrapper.get_new_notifs().values():
722-
for req_id in req_ids:
723-
assert req_id not in notified_req_ids
724-
notified_req_ids.add(req_id.decode("utf-8"))
725+
for notifs in self.nixl_wrapper.get_new_notifs().values():
726+
for notif in notifs:
727+
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
728+
self.consumer_notification_counts_by_req[req_id] += 1
729+
# Wait all consumers (D) to be done reading before freeing.
730+
if self.consumer_notification_counts_by_req[
731+
req_id] == tp_ratio:
732+
notified_req_ids.add(req_id)
733+
del self.consumer_notification_counts_by_req[req_id]
725734
return notified_req_ids
726735

727736
def _pop_done_transfers(self, transfers: dict[str, list[int]]) -> set[str]:
@@ -796,12 +805,17 @@ def _read_blocks(
796805
# saturate IB with heterogeneous TP sizes. We should remove the staging
797806
# blocks until we are ready.
798807

808+
# Number of D TP workers that will read from dst P. Propagate tp_ratio
809+
# on notification so that dst worker can wait before freeing blocks.
810+
tp_ratio = self._tp_size[
811+
self.engine_id] // self._tp_size[dst_engine_id]
812+
notif_id = f"{request_id}:{tp_ratio}".encode()
813+
799814
# Full prefix cache hit: do not need to read remote blocks,
800815
# just notify P worker that we have the blocks we need.
801816
num_local_blocks = len(local_block_ids)
802817
if num_local_blocks == 0:
803-
self.nixl_wrapper.send_notif(dst_engine_id,
804-
notif_msg=request_id.encode("utf-8"))
818+
self.nixl_wrapper.send_notif(dst_engine_id, notif_msg=notif_id)
805819
return
806820

807821
# Partial prefix cache hit: just read uncomputed blocks.
@@ -818,10 +832,7 @@ def _read_blocks(
818832
# workers will issue xfers to parts of the P worker remote kv caches.
819833

820834
# Get side handles.
821-
d_workers_per_p_worker = self._tp_size[
822-
self.engine_id] // self._tp_size[dst_engine_id]
823-
local_xfer_side_handle = self.src_xfer_side_handle[
824-
d_workers_per_p_worker]
835+
local_xfer_side_handle = self.src_xfer_side_handle[tp_ratio]
825836
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
826837

827838
# Get descs ids.
@@ -867,7 +878,7 @@ def _read_blocks(
867878
local_block_descs_ids,
868879
remote_xfer_side_handle,
869880
remote_block_descs_ids,
870-
notif_msg=request_id.encode("utf-8"),
881+
notif_msg=notif_id,
871882
)
872883

873884
# Begin async xfer.

0 commit comments

Comments
 (0)