@@ -155,7 +155,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = NixlConnectorWorker(str(self.engine_id))
+            self.connector_worker = NixlConnectorWorker(
+                vllm_config, str(self.engine_id))
 
     ############################################################
     # Scheduler Side Methods
@@ -349,13 +350,17 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
-    def __init__(self, engine_id: str):
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
             raise RuntimeError("NIXL is not available")
         logger.info("Initializing NIXL wrapper")
         logger.info("Initializing NIXL worker %s", engine_id)
 
+        # Config.
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
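
For reference, here is a minimal sketch of what the constructor change amounts to, using hypothetical stand-in config classes rather than vLLM's real VllmConfig/CacheConfig: the worker keeps the whole config and caches cache_config.block_size so that later KV-cache registration can validate the tensor layout against it.

from dataclasses import dataclass

# Hypothetical stand-ins, only to illustrate the values the worker now caches.
@dataclass
class CacheConfig:
    block_size: int  # tokens per KV-cache block

@dataclass
class VllmConfig:
    cache_config: CacheConfig

class Worker:
    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        # Mirrors the change above: keep the config and the block size around.
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id

worker = Worker(VllmConfig(CacheConfig(block_size=16)), "engine-0")
print(worker.block_size)  # 16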
@@ -454,41 +459,39 @@ def _nixl_handshake(self, host: str, port: int):
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(sock, rank: int) -> NixlAgentMetadata:
+        def handshake(path: str, rank: int) -> NixlAgentMetadata:
             # Send query for the request.
-            sock.send(GET_META_MSG)
-            metadata_bytes = sock.recv()
-            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-            metadata = decoder.decode(metadata_bytes)
-            got_metadata_time = time.perf_counter()
-
-            # Register Remote agent.
-            self.add_remote_agent(metadata, rank)
-            setup_agent_time = time.perf_counter()
-
-            logger.debug("NIXL handshake: get metadata took: %s",
-                         got_metadata_time - start_time)
-            logger.debug("NIXL handshake: add agent took: %s",
-                         setup_agent_time - got_metadata_time)
-            return metadata
+            with zmq_ctx(zmq.REQ, path) as sock:
+                sock.send(GET_META_MSG)
+                metadata_bytes = sock.recv()
+                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+                metadata = decoder.decode(metadata_bytes)
+                got_metadata_time = time.perf_counter()
+
+                # Register Remote agent.
+                self.add_remote_agent(metadata, rank)
+                setup_agent_time = time.perf_counter()
+
+                logger.debug("NIXL handshake: get metadata took: %s",
+                             got_metadata_time - start_time)
+                logger.debug("NIXL handshake: add agent took: %s",
+                             setup_agent_time - got_metadata_time)
+                return metadata
 
         # Handshake with remote agent-rank0 first to get the tp_size of remote
         path = f"tcp://{host}:{port}"
         logger.debug("Querying master rank metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
-            metadata = handshake(sock, 0)
+        metadata = handshake(path, 0)
 
         # Handshake only with the other TP remote the current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        d_workers_per_p_worker = self._tp_size[
-            self.engine_id] // metadata.tp_size
-        p_remote_rank = self.rank // d_workers_per_p_worker
+        tp_rate = self._tp_size[self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // tp_rate
         if p_remote_rank > 0:
             path = f"tcp://{host}:{port + p_remote_rank}"
             logger.debug("Querying metadata on path: %s at remote rank %s",
                          path, p_remote_rank)
-            with zmq_ctx(zmq.REQ, path) as sock:
-                metadata = handshake(sock, p_remote_rank)
+            _ = handshake(path, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
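
As a side note on the rank mapping above, here is a small self-contained sketch (names are illustrative, not vLLM's) of how a decode-side rank picks the single prefill rank it handshakes with, assuming the local TP size divides evenly by the remote TP size as the integer division implies:

def remote_rank_to_query(local_tp_size: int, local_rank: int,
                         remote_tp_size: int) -> int:
    # Number of local (decode) workers that map onto one remote (prefill) worker.
    tp_rate = local_tp_size // remote_tp_size
    # Each local rank pulls its KV blocks from exactly one remote rank.
    return local_rank // tp_rate

# Homogeneous TP: local rank i handshakes with remote rank i.
assert remote_rank_to_query(4, 3, 4) == 3
# Heterogeneous TP: 8 decode ranks over 2 prefill ranks; local ranks 0-3 map to
# remote rank 0 (already covered by the rank-0 handshake), ranks 4-7 to rank 1.
assert remote_rank_to_query(8, 5, 2) == 1

The extra handshake is only issued when p_remote_rank > 0, and the diff's port + p_remote_rank path suggests each remote rank listens on its own consecutive port.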
@@ -503,17 +506,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, kv_latent_dim = block_shape
+            block_size, kv_latent_dim = block_shape
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, n_kv_heads, head_dim = block_shape
+            block_size, n_kv_heads, head_dim = block_shape
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
-
+        assert block_size == self.block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
         # block size in bytes
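
To make the new assertion concrete, here is a short illustrative sketch with made-up tensor dimensions, showing how the block size unpacked from a non-MLA cache tensor's shape is checked against the configured block size (the block_len name below follows the TODO's self.block_len, purely as an assumption):

import torch

configured_block_size = 16  # would come from cache_config.block_size

# Non-MLA layout: [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
kv_cache = torch.empty(2, 10, 16, 8, 128, dtype=torch.float16)
kv_elem_size = kv_cache.element_size()
num_blocks = kv_cache.shape[1]
block_size, n_kv_heads, head_dim = kv_cache.shape[-3:]
slot_size_bytes = kv_elem_size * n_kv_heads * head_dim  # bytes per token slot

# The check introduced above: the cache layout must match the configured size.
assert block_size == configured_block_size
block_len = slot_size_bytes * block_size  # bytes per block of this layer
print(num_blocks, block_len)  # 10 32768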