Skip to content

Commit abc710a

Browse files
committed
fix req_finished sync bug
Signed-off-by: nicklucche <[email protected]>
1 parent 8274635 commit abc710a

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -773,8 +773,8 @@ def _get_new_notifs(self) -> set[str]:
773773
req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
774774
self.consumer_notification_counts_by_req[req_id] += 1
775775
# Wait all consumers (D) to be done reading before freeing.
776-
if self.consumer_notification_counts_by_req[
777-
req_id] == tp_ratio:
776+
if self.consumer_notification_counts_by_req[req_id] == int(
777+
tp_ratio):
778778
notified_req_ids.add(req_id)
779779
del self.consumer_notification_counts_by_req[req_id]
780780
return notified_req_ids
@@ -904,18 +904,20 @@ def _read_blocks(
904904
def _get_block_descs_ids(self, engine_id: str,
905905
block_ids: list[int]) -> list[int]:
906906
"""Get the descs ids for a set of block ids."""
907+
# TODO docs
907908

908909
# range(1) for MLA, range(2) otherwise.
909910
region_ids = range(self.num_regions)
911+
# TODO using a diff num of blocks here in dst and src
910912
num_blocks = self.dst_num_blocks[engine_id]
911913

912914
# Compute the desc ids for each block.
913915
descs_ids: list[int] = []
914916
for reg_id in region_ids:
915917
for block_id in block_ids:
916-
for kv_block in range(self.block_size):
918+
for slot_id in range(self.block_size):
917919
descs_ids.append(reg_id * num_blocks * self.block_size +
918-
block_id * self.block_size + kv_block)
920+
block_id * self.block_size + slot_id)
919921
return descs_ids
920922

921923

0 commit comments

Comments
 (0)