
Commit 0e8f82a

address review: move src_xfer_side_handle init out of hotpath and docs
Signed-off-by: nicklucche <[email protected]>
1 parent c59ca52 commit 0e8f82a

1 file changed (+81, -51 lines)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 81 additions & 51 deletions
@@ -341,14 +341,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.num_regions = 0
         self.num_layers = 0
 
-        # nixl_prepped_dlist_handle. Different dst TP sizes require preparing
-        # xfer layout differently.
-        self.src_xfer_side_handle: dict[int, int] = dict()
+        # nixl_prepped_dlist_handle.
+        self.src_xfer_side_handle: int = 0
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
         self.dst_xfer_side_handles: dict[str, int] = dict()
 
-        # Map of engine_id -> num_blocks. Remote TP ranks will have the same
-        # number of blocks.
+        # Map of engine_id -> num_blocks. All ranks in the same deployment will
+        # have the same number of blocks.
         self.dst_num_blocks: dict[str, int] = dict()
         self._registered_descs: list[Any] = []
 
@@ -375,8 +374,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Optimization for models with local attention (Llama 4)
         # List of block window sizes for each layer for local attention
         self.block_window_per_layer: list[Optional[int]] = []
-        self._tp_size = {self.engine_id: self.world_size}
 
+        self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict(int)
@@ -475,15 +474,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             self.block_size, kv_latent_dim = block_shape
-            self.kv_dim = kv_elem_size * kv_latent_dim
+            self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             self.block_size, n_kv_heads, head_dim = block_shape
             # head size in bytes.
-            self.kv_dim = kv_elem_size * n_kv_heads * head_dim
+            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
 
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
@@ -544,6 +543,29 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)
 
+        # Register local/src descr for NIXL xfer.
+        blocks_data = []
+        for base_addr in self.kv_caches_base_addr[self.engine_id]:
+            # NOTE With heter-TP, more blocks are prepared than what are
+            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
+            # could create fewer, but then _get_block_descs_ids needs to
+            # select agent_meta.num_blocks instead of self.num_blocks for
+            # local descr, and that makes handling regular flow less clean.
+            for block_id in range(self.num_blocks):
+                block_offset = block_id * self.block_len
+                for slot_idx in range(self.block_size):
+                    slot_offset = slot_idx * self.slot_size_bytes
+                    addr = base_addr + block_offset + slot_offset
+                    # (addr, len, device id)
+                    blocks_data.append((addr, self.slot_size_bytes, self.rank))
+        logger.debug("Created %s blocks for src engine %s and rank %s",
+                     len(blocks_data), self.engine_id, self.rank)
+
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        # NIXL_INIT_AGENT to be used for preparations of local descs.
+        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
+            "NIXL_INIT_AGENT", descs)
+
         # After KV Caches registered, listen for new connections.
         metadata = NixlAgentMetadata(
             engine_id=self.engine_id,
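
For reference, the new registration loop flattens (kv-cache region, block, slot) into one descriptor per slot. Below is a minimal standalone sketch of that address arithmetic; the sizes and base addresses are hypothetical and not taken from the commit:

# Hypothetical toy sizes, for illustration only.
block_size = 4            # slots (tokens) per block
slot_size_bytes = 256     # kv_elem_size * n_kv_heads * head_dim
block_len = block_size * slot_size_bytes
num_blocks = 2
kv_caches_base_addr = [0x1000, 0x9000]   # one base address per registered region
rank = 0

blocks_data = []
for base_addr in kv_caches_base_addr:
    for block_id in range(num_blocks):
        block_offset = block_id * block_len
        for slot_idx in range(block_size):
            slot_offset = slot_idx * slot_size_bytes
            # Same (addr, len, device id) layout the connector hands to NIXL.
            blocks_data.append((base_addr + block_offset + slot_offset,
                                slot_size_bytes, rank))

# 2 regions * 2 blocks * 4 slots = 16 descriptors, one per slot.
assert len(blocks_data) == len(kv_caches_base_addr) * num_blocks * block_size

Doing this once in register_kv_caches, independent of any remote TP size, is what lets src_xfer_side_handle shrink from a dict keyed by TP ratio to a single handle.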
@@ -564,6 +586,42 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
     def add_remote_agent(self,
                          nixl_agent_meta: NixlAgentMetadata,
                          remote_rank: int = 0):
+        """
+        Add the remote NIXL agent and prepare the descriptors for reading cache
+        blocks from remote.
+
+        In particular, handle both homogeneous and heterogeneous TP. The former
+        requires local rank_i to read from remote rank_i.
+        The latter, assuming D.world_size > P.world_size, requires that two or
+        more local TP workers share the xfer from a single remote TP worker.
+
+        Here's an example:
+
+        rank_offset     p_remote_rank
+        (kv split no)
+        --------------------------------
+             0                0      Worker0  ---- 1st half of KV ----> Worker0 [ KV Cache ]
+                                                                       /
+             1                0      Worker1  ---- 2nd half of KV ----/
+
+             0                1      Worker2  ---- 1st half of KV ----> Worker1 [ KV Cache ]
+                                                                       /
+             1                1      Worker3  ---- 2nd half of KV ----/
+
+
+                 Decoder TP workers                                   Prefill TP workers
+                  (world_size=4)                                        (world_size=2)
+                                     tp_ratio = 4 // 2 = 2
+
+        Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, block_size, kv_heads, head_dim]
+        then D-Worker_j has [2, num_blocksD, block_size, kv_heads//tp_ratio, head_dim].
+        Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the first
+        kv_heads//tp_ratio heads from all the slots of all the blocks in the cache.
+        D-Worker1 will do the same, but reading the second split along the kv_heads dimension.
+
+        Note that the above will also hold true for the homogeneous TP case.
+        """  # noqa: E501
+
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if (engine_id in self._remote_agents and \
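
The mapping in the docstring can be spelled out numerically. A small sketch, assuming the D world_size=4, P world_size=2 example and a hypothetical local slot size of 256 bytes:

d_world_size, p_world_size = 4, 2
slot_size_bytes = 256                      # local (D-side) slot size, hypothetical
tp_ratio = d_world_size // p_world_size    # 2 D workers share each P worker

for rank in range(d_world_size):
    p_remote_rank = rank // tp_ratio                   # which P worker to read from
    rank_offset = rank % tp_ratio * slot_size_bytes    # which kv_head split to read
    print(f"D-Worker{rank} -> P-Worker{p_remote_rank}, byte offset {rank_offset}")

# D-Worker0 -> P-Worker0, byte offset 0     (1st half of the KV heads)
# D-Worker1 -> P-Worker0, byte offset 256   (2nd half)
# D-Worker2 -> P-Worker1, byte offset 0
# D-Worker3 -> P-Worker1, byte offset 256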
@@ -577,43 +635,18 @@ def add_remote_agent(self,
                 remote_rank] = self.nixl_wrapper.add_remote_agent(
                     nixl_agent_meta.agent_metadata)
 
-        d_workers_per_p_worker = self._tp_size[
-            self.engine_id] // self._tp_size[engine_id]
-        assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than"
+        # Number of D TP workers reading from a single P TP worker. This is
+        # 1 when P and D `--tensor-parallel-size` match.
+        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
+        assert tp_ratio > 0, "Decode TP cannot be smaller than"
         " prefill TP"
 
         # TODO we should also check hidden_dim and kv precision, they must match
-        remote_block_size = nixl_agent_meta.block_len / (
-            self.kv_dim * d_workers_per_p_worker)
+        remote_block_size = nixl_agent_meta.block_len / (self.slot_size_bytes *
+                                                         tp_ratio)
         assert self.block_size == remote_block_size, "Remote P worker with "
         "different block size is not supported"
 
-        # Create src descs and xfer side handles.
-        if d_workers_per_p_worker not in self.src_xfer_side_handle:
-            blocks_data = []
-            for base_addr in self.kv_caches_base_addr[self.engine_id]:
-                # NOTE With heter-TP, more blocks are prepared than what are
-                # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
-                # could create fewer, but then _get_block_descs_ids needs to
-                # select agent_meta.num_blocks instead of self.num_blocks for
-                # local descr, and that makes handling regular flow less clean.
-                for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
-                    for b in range(self.block_size):
-                        head_offset = b * self.kv_dim
-                        addr = base_addr + block_offset + head_offset
-                        # (addr, len, device id)
-                        blocks_data.append((addr, self.kv_dim, self.rank))
-            logger.debug("Created %s blocks for src engine %s and rank %s",
-                         len(blocks_data), self.engine_id, self.rank)
-
-            # Register with NIXL.
-            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-            # NIXL_INIT_AGENT to be used for preparations of local descs.
-            self.src_xfer_side_handle[
-                d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist(
-                    "NIXL_INIT_AGENT", descs)
-
         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
             assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
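
The remote_block_size check divides out the wider remote slots: a P worker keeps tp_ratio times as many kv_heads per slot as each D worker, so its block_len is tp_ratio times larger for the same block_size. A worked sketch with hypothetical sizes:

kv_elem_size = 2                       # e.g. fp16, hypothetical
head_dim = 128
tp_ratio = 2
d_kv_heads = 4                         # kv heads per D worker
p_kv_heads = d_kv_heads * tp_ratio     # kv heads per P worker
block_size = 16                        # slots per block, same on both sides

slot_size_bytes = kv_elem_size * d_kv_heads * head_dim               # local slot size
remote_block_len = block_size * kv_elem_size * p_kv_heads * head_dim

remote_block_size = remote_block_len / (slot_size_bytes * tp_ratio)
assert remote_block_size == block_size    # mirrors the assert in add_remote_agent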
@@ -625,23 +658,23 @@ def add_remote_agent(self,
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        p_remote_rank = self.rank // d_workers_per_p_worker
+        p_remote_rank = self.rank // tp_ratio
         # Only register the remote's descriptors if current rank pulls from it.
         if p_remote_rank == remote_rank:
             self.kv_caches_base_addr[
                 engine_id] = nixl_agent_meta.kv_caches_base_addr
-            rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim
+            rank_offset = self.rank % tp_ratio * self.slot_size_bytes
             # Register all remote blocks, but only the corresponding kv heads.
             for base_addr in nixl_agent_meta.kv_caches_base_addr:
                 for block_id in range(nixl_agent_meta.num_blocks):
                     block_offset = block_id * nixl_agent_meta.block_len
-                    for b in range(self.block_size):
-                        # Remote kv_dim = local kv_dim * d_workers_per_p_worker
-                        head_offset = b * self.kv_dim * d_workers_per_p_worker
-                        addr = base_addr + block_offset + head_offset
+                    for slot_idx in range(self.block_size):
+                        # Remote has `tp_ratio` times the kv_heads of local.
+                        slot_offset = slot_idx * self.slot_size_bytes * tp_ratio
+                        addr = base_addr + block_offset + slot_offset
                         # (addr, len, device id)
-                        blocks_data.append(
-                            (addr + rank_offset, self.kv_dim, remote_rank))
+                        blocks_data.append((addr + rank_offset,
+                                            self.slot_size_bytes, remote_rank))
             logger.debug(
                 "Created %s blocks for dst engine %s with remote rank %s and " \
                 "local rank %s",
@@ -826,15 +859,12 @@ def _read_blocks(
 
         # Get side handles.
         local_xfer_side_handle = self.src_xfer_side_handle
+        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
         # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
         # workers will issue xfers to parts of the P worker remote kv caches.
 
-        # Get side handles.
-        local_xfer_side_handle = self.src_xfer_side_handle[tp_ratio]
-        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
-
         # Get descs ids.
         local_block_descs_ids: list[int] = []
         remote_block_descs_ids: list[int] = []
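
Because both side handles index descriptor lists built in the same (region, block, slot) order, a block's slot descriptors occupy a contiguous range in the flattened list. A hypothetical helper (not the connector's actual _get_block_descs_ids) showing the indexing that order implies:

def flat_desc_ids(region_id: int, block_id: int,
                  num_blocks: int, block_size: int) -> list[int]:
    # Descriptors were appended region by region, block by block, slot by slot,
    # so one block's slots form a contiguous run of block_size indices.
    start = (region_id * num_blocks + block_id) * block_size
    return list(range(start, start + block_size))

# With 2 blocks of 4 slots per region, block 1 of region 0 owns ids 4..7.
assert flat_desc_ids(region_id=0, block_id=1, num_blocks=2, block_size=4) == [4, 5, 6, 7]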
