@@ -53,6 +53,8 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
+    tp_size: int
+    block_len: int
 
 
 @dataclass
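The two added fields carry the prefill instance's tensor-parallel size and its per-block byte length, which the decode side needs before it can register remote memory. A minimal sketch of that validation, using made-up stand-in types (`PeerMetadata`, `LocalInfo`, `tp_ratio`) rather than the real vLLM classes:

    # Minimal sketch, assuming the checks done in add_remote_agent below;
    # names here are illustrative, not the vLLM API.
    from dataclasses import dataclass

    @dataclass
    class PeerMetadata:        # stand-in for NixlAgentMetadata
        tp_size: int           # prefill tensor-parallel world size
        block_len: int         # bytes per KV block on the prefill worker

    @dataclass
    class LocalInfo:
        tp_size: int           # decode tensor-parallel world size
        kv_dim: int            # bytes of one token's KV slice on this worker
        block_size: int        # tokens per block

    def tp_ratio(local: LocalInfo, peer: PeerMetadata) -> int:
        # Decode TP must not be smaller than prefill TP.
        d_workers_per_p_worker = local.tp_size // peer.tp_size
        assert d_workers_per_p_worker > 0
        # Remote blocks must hold the same number of tokens once the remote's
        # larger per-token slice (d_workers_per_p_worker * kv_dim) is factored in.
        assert peer.block_len == local.block_size * local.kv_dim * d_workers_per_p_worker
        return d_workers_per_p_worker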
@@ -318,8 +320,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
-        # Map of engine_id -> agent_name.
-        self._remote_agents: dict[str, str] = {}
+        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
+        self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
 
         # Metadata.
         self.engine_id = engine_id
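Since one decode rank may now talk to several workers of the same remote engine, the agent registry becomes a two-level mapping. An illustrative shape (engine ids and agent names below are made up):

    from collections import defaultdict

    _remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
    _remote_agents["prefill-engine-0"][0] = "nixl-agent-prefill0-rank0"
    _remote_agents["prefill-engine-0"][1] = "nixl-agent-prefill0-rank1"
    # Lookups later key on both the engine and the remote TP rank:
    agent_name = _remote_agents["prefill-engine-0"][1]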
@@ -330,21 +332,24 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # KV Caches and nixl tracking data.
         self.kv_caches: dict[str, torch.Tensor] = {}
 
-        # Map of engine_id -> kv_caches_base_addr
-        self.kv_caches_base_addr: dict[str, list[int]] = {}
+        # Map of engine_id -> kv_caches_base_addr. For TP case, each local
+        # rank will still only pull from a single remote TP worker.
+        self.kv_caches_base_addr: dict[str, list[int]] = dict()
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
         self.num_layers = 0
 
-        # nixl_prepped_dlist_handle (int).
-        self.src_xfer_side_handle: int = 0
+        # nixl_prepped_dlist_handle. Different dst TP sizes require preparing
+        # xfer layout differently.
+        self.src_xfer_side_handle: dict[int, int] = dict()
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: dict[str, int] = {}
+        self.dst_xfer_side_handles: dict[str, int] = dict()
 
-        # Map of engine_id -> num_blocks.
-        self.dst_num_blocks: dict[str, int] = {}
+        # Map of engine_id -> num_blocks. Remote TP ranks will have the same
+        # number of blocks.
+        self.dst_num_blocks: dict[str, int] = dict()
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
@@ -370,6 +375,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Optimization for models with local attention (Llama 4)
         # List of block window sizes for each layer for local attention
         self.block_window_per_layer: list[Optional[int]] = []
+        self._tp_size = {self.engine_id: self.world_size}
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
@@ -410,12 +416,12 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
+
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.rank)
-        logger.debug("Querying metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
+
+        def handshake(sock, rank: int) -> NixlAgentMetadata:
             # Send query for the request.
             sock.send(GET_META_MSG)
             metadata_bytes = sock.recv()
@@ -424,13 +430,32 @@ def _nixl_handshake(self, host: str, port: int):
             got_metadata_time = time.perf_counter()
 
             # Register Remote agent.
-            self.add_remote_agent(metadata)
+            self.add_remote_agent(metadata, rank)
             setup_agent_time = time.perf_counter()
 
             logger.debug("NIXL handshake: get metadata took: %s",
                          got_metadata_time - start_time)
             logger.debug("NIXL handshake: add agent took: %s",
                          setup_agent_time - got_metadata_time)
+            return metadata
+
+        # Handshake with remote agent-rank0 first to get the tp_size of remote
+        path = f"tcp://{host}:{port}"
+        logger.debug("Querying master rank metadata on path: %s", path)
+        with zmq_ctx(zmq.REQ, path) as sock:
+            metadata = handshake(sock, 0)
+
+        # Handshake only with the other TP remote the current local rank will
+        # pull from. With homogeneous TP it happens to be the same rank_i.
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // d_workers_per_p_worker
+        if p_remote_rank > 0:
+            path = f"tcp://{host}:{port + p_remote_rank}"
+            logger.debug("Querying metadata on path: %s at remote rank %s",
+                         path, p_remote_rank)
+            with zmq_ctx(zmq.REQ, path) as sock:
+                metadata = handshake(sock, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
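To make the handshake targeting concrete, here is a worked example with hypothetical sizes: a decode instance at TP=4 pulling from a prefill instance at TP=2. Each prefill rank serves two decode ranks, and only decode ranks that map to a prefill rank greater than zero need the extra handshake on `port + p_remote_rank`:

    # Hypothetical TP sizes, matching the arithmetic above.
    local_tp, remote_tp = 4, 2
    d_workers_per_p_worker = local_tp // remote_tp   # 2 decode ranks per prefill rank
    for rank in range(local_tp):
        p_remote_rank = rank // d_workers_per_p_worker
        print(f"decode rank {rank} pulls from prefill rank {p_remote_rank}")
    # decode rank 0 -> prefill rank 0 (already covered by the rank-0 handshake)
    # decode rank 1 -> prefill rank 0
    # decode rank 2 -> prefill rank 1 (needs a handshake on port + 1)
    # decode rank 3 -> prefill rank 1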
@@ -445,14 +470,20 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            self.block_size, kv_latent_dim = block_shape
+            self.kv_dim = kv_elem_size * kv_latent_dim
         else:
-            # [2 (k and v), num_blocks, ...]
+            # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            self.block_size, n_kv_heads, head_dim = block_shape
+            # head size in bytes.
+            self.kv_dim = kv_elem_size * n_kv_heads * head_dim
 
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
+        # block size in bytes
         self.block_len = kv_elem_size * math.prod(block_shape)
 
         logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
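A quick worked example of these byte sizes, using hypothetical non-MLA dimensions (fp16, block_size 16, 8 KV heads per worker, head_dim 128):

    kv_elem_size = 2                                  # fp16 bytes
    block_size, n_kv_heads, head_dim = 16, 8, 128     # assumed cache block shape
    kv_dim = kv_elem_size * n_kv_heads * head_dim     # 2048 bytes per token
    block_len = kv_elem_size * block_size * n_kv_heads * head_dim
    assert block_len == block_size * kv_dim == 32768  # bytes per block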
@@ -507,7 +538,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Registering descs: %s", caches_data)
         self.nixl_wrapper.register_memory(descs)
         logger.debug("Done registering descs")
-
         self._registered_descs.append(descs)
 
         # After KV Caches registered, listen for new connections.
@@ -516,7 +546,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-        )
+            tp_size=self.world_size,
+            block_len=self.block_len)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -526,49 +557,97 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self._nixl_handshake_listener_t.start()
         ready_event.wait()
 
-    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
+    def add_remote_agent(self,
+                         nixl_agent_meta: NixlAgentMetadata,
+                         remote_rank: int = 0):
         engine_id = nixl_agent_meta.engine_id
-        if engine_id in self._remote_agents:
+        # TODO re-evaluate refreshing for scaling/recovery
+        if (engine_id in self._remote_agents and \
+                remote_rank in self._remote_agents[engine_id]):
             return
 
-        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
-            nixl_agent_meta.agent_metadata)
-        self.kv_caches_base_addr[
-            engine_id] = nixl_agent_meta.kv_caches_base_addr
+        if engine_id in self._tp_size:
+            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
+        self._tp_size[engine_id] = nixl_agent_meta.tp_size
+        self._remote_agents[engine_id][
+            remote_rank] = self.nixl_wrapper.add_remote_agent(
+                nixl_agent_meta.agent_metadata)
+
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // self._tp_size[engine_id]
+        assert d_workers_per_p_worker > 0, (
+            "Decode TP cannot be smaller than prefill TP")
+
+        # TODO we should also check hidden_dim and kv precision, they must match
+        remote_block_size = nixl_agent_meta.block_len / (
+            self.kv_dim * d_workers_per_p_worker)
+        assert self.block_size == remote_block_size, (
+            "Remote P worker with different block size is not supported")
 
         # Create src descs and xfer side handles.
-        blocks_data = []
-        for base_addr in self.kv_caches_base_addr[self.engine_id]:
-            for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
-
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
-            "NIXL_INIT_AGENT", descs)
-
-        # Create dst descs and xfer side handles.
+        if d_workers_per_p_worker not in self.src_xfer_side_handle:
+            blocks_data = []
+            for base_addr in self.kv_caches_base_addr[self.engine_id]:
+                # NOTE With heter-TP, more blocks are prepared than what are
+                # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
+                # could create fewer, but then _get_block_descs_ids needs to
+                # select agent_meta.num_blocks instead of self.num_blocks for
+                # local descr, and that makes handling regular flow less clean.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    for b in range(self.block_size):
+                        head_offset = b * self.kv_dim
+                        addr = base_addr + block_offset + head_offset
+                        # (addr, len, device id)
+                        blocks_data.append((addr, self.kv_dim, self.rank))
+            logger.debug("Created %s blocks for src engine %s and rank %s",
+                         len(blocks_data), self.engine_id, self.rank)
+
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            # NIXL_INIT_AGENT to be used for preparations of local descs.
+            self.src_xfer_side_handle[
+                d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist(
+                    "NIXL_INIT_AGENT", descs)
+
+        # Create dst descs and xfer side handles. TP workers have same #blocks.
+        if engine_id in self.dst_num_blocks:
+            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
+
         self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
+
         blocks_data = []
-        for base_addr in self.kv_caches_base_addr[engine_id]:
-            for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
-
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.dst_xfer_side_handles[
-            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id], descs)
+        # With homogeneous TP, D pulls the whole kv cache from corresponding
+        # rank. With heterogeneous TP, prepare the descriptors by splitting the
+        # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
+        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+        p_remote_rank = self.rank // d_workers_per_p_worker
+        # Only register the remote's descriptors if current rank pulls from it.
+        if p_remote_rank == remote_rank:
+            self.kv_caches_base_addr[
+                engine_id] = nixl_agent_meta.kv_caches_base_addr
+            rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim
+            # Register all remote blocks, but only the corresponding kv heads.
+            for base_addr in nixl_agent_meta.kv_caches_base_addr:
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    for b in range(self.block_size):
+                        # Remote kv_dim = local kv_dim * d_workers_per_p_worker
+                        head_offset = b * self.kv_dim * d_workers_per_p_worker
+                        addr = base_addr + block_offset + head_offset
+                        # (addr, len, device id)
+                        blocks_data.append(
+                            (addr + rank_offset, self.kv_dim, remote_rank))
+            logger.debug(
+                "Created %s blocks for dst engine %s with remote rank %s and "
+                "local rank %s",
+                len(blocks_data), engine_id, remote_rank, self.rank)
+
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            self.dst_xfer_side_handles[
+                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                    self._remote_agents[engine_id][remote_rank], descs)
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
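To make the destination-side slicing concrete, here is a worked example with hypothetical sizes: prefill TP=1, decode TP=2, so d_workers_per_p_worker = 2, the remote per-token slice is twice the local kv_dim, and decode rank 1 reads the upper half of every token's KV:

    # Hypothetical values; addresses and counts are made up for illustration.
    kv_dim = 2048                                    # local per-token KV bytes
    block_size, d_workers_per_p_worker = 16, 2
    remote_block_len = block_size * kv_dim * d_workers_per_p_worker   # 65536
    rank, remote_rank = 1, 0
    rank_offset = rank % d_workers_per_p_worker * kv_dim              # 2048
    base_addr, num_remote_blocks = 0x7f0000000000, 2
    blocks_data = []
    for block_id in range(num_remote_blocks):
        block_offset = block_id * remote_block_len
        for b in range(block_size):
            head_offset = b * kv_dim * d_workers_per_p_worker
            blocks_data.append(
                (base_addr + block_offset + head_offset + rank_offset,
                 kv_dim, remote_rank))
    # 32 descriptors of kv_dim bytes each; within a block, consecutive tokens
    # are strided by the remote (larger) per-token size, skipping the heads
    # owned by decode rank 0.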
@@ -733,6 +812,16 @@ def _read_blocks(
 
         # Get side handles.
         local_xfer_side_handle = self.src_xfer_side_handle
+
+        # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
+        # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
+        # workers will issue xfers to parts of the P worker remote kv caches.
+
+        # Get side handles.
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // self._tp_size[dst_engine_id]
+        local_xfer_side_handle = self.src_xfer_side_handle[
+            d_workers_per_p_worker]
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # Get descs ids.
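A small sketch of why the local prepared-descriptor handle is keyed by the TP ratio rather than stored as a single int: one decode worker may serve prefill peers with different TP sizes (handle values and engine ids below are made up):

    src_xfer_side_handle = {1: 101, 2: 102}          # ratio -> prepped dlist handle
    _tp_size = {"decode-0": 4, "prefill-A": 4, "prefill-B": 2}
    for dst_engine_id in ("prefill-A", "prefill-B"):
        ratio = _tp_size["decode-0"] // _tp_size[dst_engine_id]
        handle = src_xfer_side_handle[ratio]         # 101 for A, 102 for B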
@@ -818,7 +907,9 @@ def _get_block_descs_ids(self,
         descs_ids: list[int] = []
         for reg_id in region_ids:
             for block_id in block_ids:
-                descs_ids.append(reg_id * num_blocks + block_id)
+                for kv_block in range(self.block_size):
+                    descs_ids.append(reg_id * num_blocks * self.block_size +
+                                     block_id * self.block_size + kv_block)
         return descs_ids
 
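A worked example of the new descriptor-id layout: with per-token descriptors, each (region, block) pair expands to block_size consecutive ids (small hypothetical sizes used for clarity):

    block_size = 4                       # tokens per block, small for illustration
    num_blocks = 10
    region_ids, block_ids = [0, 1], [3]
    descs_ids = []
    for reg_id in region_ids:
        for block_id in block_ids:
            for kv_block in range(block_size):
                descs_ids.append(reg_id * num_blocks * block_size +
                                 block_id * block_size + kv_block)
    # descs_ids == [12, 13, 14, 15, 52, 53, 54, 55]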