@@ -318,6 +318,10 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         logger.info("Initializing NIXL wrapper")
         logger.info("Initializing NIXL worker %s", engine_id)
 
+        # Config.
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
@@ -378,7 +382,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         self._tp_size: dict[str, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
         # finish reading before safely freeing the blocks.
-        self.consumer_notification_counts_by_req = defaultdict(int)
+        self.consumer_notification_counts_by_req: dict[str,
+                                                        int] = defaultdict(int)
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
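The typed counter above is the bookkeeping for the heterogeneous-TP case described in the comment: a prefill (P) worker may serve several decode (D) workers, and a request's KV blocks can only be freed once every assigned D worker has reported that it finished reading. A minimal sketch of that pattern, with illustrative names (ConsumerTracker, notify_done, and the expected count are assumptions for this example, not the connector's actual API):

from collections import defaultdict

class ConsumerTracker:
    """Counts 'done reading' notifications per request (illustrative only)."""

    def __init__(self, d_workers_per_p_worker: int):
        # How many D workers must notify before a request's blocks can be freed.
        self.expected = d_workers_per_p_worker
        self.counts: dict[str, int] = defaultdict(int)

    def notify_done(self, req_id: str) -> bool:
        # Returns True once the last assigned D worker has finished reading,
        # i.e. the P worker may now safely free the request's KV blocks.
        self.counts[req_id] += 1
        if self.counts[req_id] == self.expected:
            del self.counts[req_id]
            return True
        return False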
@@ -424,41 +429,39 @@ def _nixl_handshake(self, host: str, port: int):
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(sock, rank: int) -> NixlAgentMetadata:
+        def handshake(path: str, rank: int) -> NixlAgentMetadata:
             # Send query for the request.
-            sock.send(GET_META_MSG)
-            metadata_bytes = sock.recv()
-            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-            metadata = decoder.decode(metadata_bytes)
-            got_metadata_time = time.perf_counter()
-
-            # Register Remote agent.
-            self.add_remote_agent(metadata, rank)
-            setup_agent_time = time.perf_counter()
-
-            logger.debug("NIXL handshake: get metadata took: %s",
-                         got_metadata_time - start_time)
-            logger.debug("NIXL handshake: add agent took: %s",
-                         setup_agent_time - got_metadata_time)
-            return metadata
+            with zmq_ctx(zmq.REQ, path) as sock:
+                sock.send(GET_META_MSG)
+                metadata_bytes = sock.recv()
+                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+                metadata = decoder.decode(metadata_bytes)
+                got_metadata_time = time.perf_counter()
+
+                # Register Remote agent.
+                self.add_remote_agent(metadata, rank)
+                setup_agent_time = time.perf_counter()
+
+                logger.debug("NIXL handshake: get metadata took: %s",
+                             got_metadata_time - start_time)
+                logger.debug("NIXL handshake: add agent took: %s",
+                             setup_agent_time - got_metadata_time)
+                return metadata
 
         # Handshake with remote agent-rank0 first to get the tp_size of remote
         path = f"tcp://{host}:{port}"
         logger.debug("Querying master rank metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
-            metadata = handshake(sock, 0)
+        metadata = handshake(path, 0)
 
         # Handshake only with the other TP remote the current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        d_workers_per_p_worker = self._tp_size[
-            self.engine_id] // metadata.tp_size
-        p_remote_rank = self.rank // d_workers_per_p_worker
+        tp_rate = self._tp_size[self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // tp_rate
         if p_remote_rank > 0:
             path = f"tcp://{host}:{port + p_remote_rank}"
             logger.debug("Querying metadata on path: %s at remote rank %s",
                          path, p_remote_rank)
-            with zmq_ctx(zmq.REQ, path) as sock:
-                metadata = handshake(sock, p_remote_rank)
+            _ = handshake(path, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -473,17 +476,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, kv_latent_dim = block_shape
+            block_size, kv_latent_dim = block_shape
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, n_kv_heads, head_dim = block_shape
+            block_size, n_kv_heads, head_dim = block_shape
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
-
+        assert block_size == self.block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
         # block size in bytes
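With block_size now read from vllm_config.cache_config.block_size in __init__, register_kv_caches only checks that the layout of the first KV cache tensor agrees with it while deriving slot_size_bytes. A short sketch of that derivation for the two layouts named in the diff (tensor shapes and dtype below are made up for illustration):

import torch

block_size = 16  # would come from vllm_config.cache_config.block_size

# MLA layout: [num_blocks, block_size, latent_dim]
mla_cache = torch.empty(128, block_size, 576, dtype=torch.bfloat16)
_, bs, kv_latent_dim = mla_cache.shape
assert bs == block_size
mla_slot_size_bytes = mla_cache.element_size() * kv_latent_dim

# Non-MLA layout: [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
kv_cache = torch.empty(2, 128, block_size, 8, 128, dtype=torch.bfloat16)
_, _, bs, n_kv_heads, head_dim = kv_cache.shape
assert bs == block_size
slot_size_bytes = kv_cache.element_size() * n_kv_heads * head_dim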