@@ -341,14 +341,13 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self.num_regions = 0
         self.num_layers = 0
 
-        # nixl_prepped_dlist_handle. Different dst TP sizes require preparing
-        # xfer layout differently.
-        self.src_xfer_side_handle: dict[int, int] = dict()
+        # nixl_prepped_dlist_handle.
+        self.src_xfer_side_handle: int = 0
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
         self.dst_xfer_side_handles: dict[str, int] = dict()
 
-        # Map of engine_id -> num_blocks. Remote TP ranks will have the same
-        # number of blocks.
+        # Map of engine_id -> num_blocks. All ranks in the same deployment will
+        # have the same number of blocks.
         self.dst_num_blocks: dict[str, int] = dict()
         self._registered_descs: list[Any] = []
 
@@ -375,8 +374,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Optimization for models with local attention (Llama 4)
         # List of block window sizes for each layer for local attention
         self.block_window_per_layer: list[Optional[int]] = []
-        self._tp_size = {self.engine_id: self.world_size}
 
+        self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
         self.consumer_notification_counts_by_req = defaultdict(int)
@@ -475,15 +474,15 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             self.block_size, kv_latent_dim = block_shape
-            self.kv_dim = kv_elem_size * kv_latent_dim
+            self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
             self.block_size, n_kv_heads, head_dim = block_shape
             # head size in bytes.
-            self.kv_dim = kv_elem_size * n_kv_heads * head_dim
+            self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
 
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
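For reference, a minimal sketch of the size arithmetic performed by the non-MLA branch above, using hypothetical tensor dimensions (the shapes, the fp16 dtype, and the block_len name are illustrative assumptions, not values taken from this change):

import torch

# Hypothetical single-layer cache in the standard (non-MLA) layout:
# [2 (K and V), num_blocks, block_size, kv_heads, head_dim]
first_kv_cache = torch.zeros(2, 128, 16, 8, 128, dtype=torch.float16)

kv_elem_size = first_kv_cache.element_size()             # 2 bytes for fp16
block_size, n_kv_heads, head_dim = first_kv_cache.shape[-3:]

# Bytes covered by one token slot (all KV heads of one K or V entry) ...
slot_size_bytes = kv_elem_size * n_kv_heads * head_dim   # 2 * 8 * 128 = 2048
# ... and by one full block of block_size slots.
block_len = slot_size_bytes * block_size                 # 2048 * 16 = 32768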
@@ -544,6 +543,29 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Done registering descs")
         self._registered_descs.append(descs)
 
+        # Register local/src descr for NIXL xfer.
+        blocks_data = []
+        for base_addr in self.kv_caches_base_addr[self.engine_id]:
+            # NOTE With heter-TP, more blocks are prepared than what are
+            # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
+            # could create fewer, but then _get_block_descs_ids needs to
+            # select agent_meta.num_blocks instead of self.num_blocks for
+            # local descr, and that makes handling regular flow less clean.
+            for block_id in range(self.num_blocks):
+                block_offset = block_id * self.block_len
+                for slot_idx in range(self.block_size):
+                    slot_offset = slot_idx * self.slot_size_bytes
+                    addr = base_addr + block_offset + slot_offset
+                    # (addr, len, device id)
+                    blocks_data.append((addr, self.slot_size_bytes, self.rank))
+        logger.debug("Created %s blocks for src engine %s and rank %s",
+                     len(blocks_data), self.engine_id, self.rank)
+
+        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+        # NIXL_INIT_AGENT to be used for preparations of local descs.
+        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
+            "NIXL_INIT_AGENT", descs)
+
         # After KV Caches registered, listen for new connections.
         metadata = NixlAgentMetadata(
             engine_id=self.engine_id,
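For a rough sense of scale, the nested loops above create one NIXL descriptor per (cache region, block, slot), so the local dlist grows multiplicatively. A tiny worked example with purely hypothetical sizes:

# Hypothetical sizes, for illustration only.
num_regions = 2 * 32   # e.g. K and V regions for 32 layers (entries in kv_caches_base_addr)
num_blocks = 1024      # blocks in the local KV cache
block_size = 16        # token slots per block

# One descriptor per (region, block, slot) in the local dlist:
num_descs = num_regions * num_blocks * block_size
print(num_descs)       # 1048576

Preparing this list once at registration time is possible because, at slot granularity, the local layout no longer depends on the remote TP size; the old per-tp_ratio dict of source handles is removed further down in this change.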
@@ -564,6 +586,42 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
     def add_remote_agent(self,
                          nixl_agent_meta: NixlAgentMetadata,
                          remote_rank: int = 0):
+        """
+        Add the remote NIXL agent and prepare the descriptors for reading cache
+        blocks from remote.
+
+        In particular, handle both homogeneous and heterogeneous TP. The former
+        requires local rank_i to read from remote rank_i.
+        The latter, assuming D.world_size > P.world_size, requires that two or
+        more local TP workers share the xfer from a single remote TP worker.
+
+        Here's an example:
+
+        rank_offset     p_remote_rank
+        (kv split no)
+        --------------------------------
+            0                 0      Worker0  ---- 1st half of KV ----> Worker0  [ KV Cache ]
+                                                                       /
+            1                 0      Worker1  ---- 2nd half of KV ----/
+
+            0                 1      Worker2  ---- 1st half of KV ----> Worker1  [ KV Cache ]
+                                                                       /
+            1                 1      Worker3  ---- 2nd half of KV ----/
+
+
+                          Decoder TP workers                        Prefill TP workers
+                            (world_size=4)                            (world_size=2)
+                                           tp_ratio = 4 // 2 = 2
+
+        Considering the KV Caches, if P-Worker_i has cache size [2, num_blocksP, block_size, kv_heads, head_dim]
+        then D-Worker_j has [2, num_blocksD, block_size, kv_heads//tp_ratio, head_dim].
+        Assuming num_blocksD >= num_blocksP, D-Worker0 reads from P-Worker0 by preparing the first
+        kv_heads//tp_ratio heads from all the slots of all the blocks in the cache.
+        D-Worker1 will do the same, but reading the second split along the kv_heads dimension.
+
+        Note that the above also holds for the homogeneous TP case.
+        """  # noqa: E501
+
         engine_id = nixl_agent_meta.engine_id
         # TODO re-evaluate refreshing for scaling/recovery
         if (engine_id in self._remote_agents and \
@@ -577,43 +635,18 @@ def add_remote_agent(self,
             remote_rank] = self.nixl_wrapper.add_remote_agent(
                 nixl_agent_meta.agent_metadata)
 
-        d_workers_per_p_worker = self._tp_size[
-            self.engine_id] // self._tp_size[engine_id]
-        assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than"
+        # Number of D TP workers reading from a single P TP worker. This is
+        # 1 when P and D `--tensor-parallel-size` match.
+        tp_ratio = self._tp_size[self.engine_id] // self._tp_size[engine_id]
+        assert tp_ratio > 0, "Decode TP cannot be smaller than"
         " prefill TP"
 
         # TODO we should also check hidden_dim and kv precision, they must match
-        remote_block_size = nixl_agent_meta.block_len / (
-            self.kv_dim * d_workers_per_p_worker)
+        remote_block_size = nixl_agent_meta.block_len / (self.slot_size_bytes *
+                                                         tp_ratio)
         assert self.block_size == remote_block_size, "Remote P worker with "
         "different block size is not supported"
 
-        # Create src descs and xfer side handles.
-        if d_workers_per_p_worker not in self.src_xfer_side_handle:
-            blocks_data = []
-            for base_addr in self.kv_caches_base_addr[self.engine_id]:
-                # NOTE With heter-TP, more blocks are prepared than what are
-                # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
-                # could create fewer, but then _get_block_descs_ids needs to
-                # select agent_meta.num_blocks instead of self.num_blocks for
-                # local descr, and that makes handling regular flow less clean.
-                for block_id in range(self.num_blocks):
-                    block_offset = block_id * self.block_len
-                    for b in range(self.block_size):
-                        head_offset = b * self.kv_dim
-                        addr = base_addr + block_offset + head_offset
-                        # (addr, len, device id)
-                        blocks_data.append((addr, self.kv_dim, self.rank))
-            logger.debug("Created %s blocks for src engine %s and rank %s",
-                         len(blocks_data), self.engine_id, self.rank)
-
-            # Register with NIXL.
-            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-            # NIXL_INIT_AGENT to be used for preparations of local descs.
-            self.src_xfer_side_handle[
-                d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist(
-                    "NIXL_INIT_AGENT", descs)
-
         # Create dst descs and xfer side handles. TP workers have same #blocks.
         if engine_id in self.dst_num_blocks:
             assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
@@ -625,23 +658,23 @@ def add_remote_agent(self,
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        p_remote_rank = self.rank // d_workers_per_p_worker
+        p_remote_rank = self.rank // tp_ratio
         # Only register the remote's descriptors if current rank pulls from it.
         if p_remote_rank == remote_rank:
             self.kv_caches_base_addr[
                 engine_id] = nixl_agent_meta.kv_caches_base_addr
-            rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim
+            rank_offset = self.rank % tp_ratio * self.slot_size_bytes
             # Register all remote blocks, but only the corresponding kv heads.
             for base_addr in nixl_agent_meta.kv_caches_base_addr:
                 for block_id in range(nixl_agent_meta.num_blocks):
                     block_offset = block_id * nixl_agent_meta.block_len
-                    for b in range(self.block_size):
-                        # Remote kv_dim = local kv_dim * d_workers_per_p_worker
-                        head_offset = b * self.kv_dim * d_workers_per_p_worker
-                        addr = base_addr + block_offset + head_offset
+                    for slot_idx in range(self.block_size):
+                        # Remote has `tp_ratio` times the kv_heads of local.
+                        slot_offset = slot_idx * self.slot_size_bytes * tp_ratio
+                        addr = base_addr + block_offset + slot_offset
                         # (addr, len, device id)
-                        blocks_data.append(
-                            (addr + rank_offset, self.kv_dim, remote_rank))
+                        blocks_data.append((addr + rank_offset,
+                                            self.slot_size_bytes, remote_rank))
             logger.debug(
                 "Created %s blocks for dst engine %s with remote rank %s and " \
                 "local rank %s",
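To make the offset arithmetic in this hunk concrete, below is a small standalone sketch of which byte ranges a decode (D) worker reads from its prefill (P) worker under heterogeneous TP. The function name and the numeric sizes are hypothetical; they only illustrate the rank_offset / slot_offset computation shown above:

# Pure-Python sketch of the offset arithmetic, using assumed sizes.
# Everything here (function name, numbers) is hypothetical, for illustration.
def remote_read_ranges(d_rank, tp_ratio, local_slot_size_bytes,
                       remote_block_len, num_blocks, block_size,
                       base_addr=0):
    """Yield (addr, nbytes) ranges D worker `d_rank` reads from its P worker."""
    remote_slot_size_bytes = local_slot_size_bytes * tp_ratio
    rank_offset = (d_rank % tp_ratio) * local_slot_size_bytes
    for block_id in range(num_blocks):
        block_offset = block_id * remote_block_len
        for slot_idx in range(block_size):
            slot_offset = slot_idx * remote_slot_size_bytes
            yield (base_addr + block_offset + slot_offset + rank_offset,
                   local_slot_size_bytes)

# D world_size=4, P world_size=2 -> tp_ratio=2; D rank 1 pulls from P rank 0
# (1 // 2 == 0) and reads the second half of every remote slot.
print(list(remote_read_ranges(d_rank=1, tp_ratio=2,
                              local_slot_size_bytes=1024,
                              remote_block_len=4096,
                              num_blocks=2, block_size=2)))
# [(1024, 1024), (3072, 1024), (5120, 1024), (7168, 1024)]

With tp_ratio = 2, each remote slot is twice the local slot width; D rank 1 maps to P rank 0 and reads the second half of every remote slot, i.e. its own kv_heads // tp_ratio share.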
@@ -826,15 +859,12 @@ def _read_blocks(
 
         # Get side handles.
         local_xfer_side_handle = self.src_xfer_side_handle
+        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
         # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
         # workers will issue xfers to parts of the P worker remote kv caches.
 
-        # Get side handles.
-        local_xfer_side_handle = self.src_xfer_side_handle[tp_ratio]
-        remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
-
         # Get descs ids.
         local_block_descs_ids: list[int] = []
         remote_block_descs_ids: list[int] = []