
Commit f674fc7

address review: clarity
Signed-off-by: nicklucche <[email protected]>
1 parent: bde92a5

File tree

1 file changed: +31 −28 lines


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 31 additions & 28 deletions
@@ -155,7 +155,8 @@ def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
             self.connector_worker: Optional[NixlConnectorWorker] = None
         elif role == KVConnectorRole.WORKER:
             self.connector_scheduler = None
-            self.connector_worker = NixlConnectorWorker(str(self.engine_id))
+            self.connector_worker = NixlConnectorWorker(
+                vllm_config, str(self.engine_id))
 
     ############################################################
     # Scheduler Side Methods
@@ -349,13 +350,17 @@ def request_finished(
 class NixlConnectorWorker:
     """Implementation of Worker side methods"""
 
-    def __init__(self, engine_id: str):
+    def __init__(self, vllm_config: VllmConfig, engine_id: str):
         if NixlWrapper is None:
             logger.error("NIXL is not available")
             raise RuntimeError("NIXL is not available")
         logger.info("Initializing NIXL wrapper")
         logger.info("Initializing NIXL worker %s", engine_id)
 
+        # Config.
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
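Note: with this change the worker reads the block size from the engine config at construction time instead of inferring it from the cache tensor shape later (see the assert added in register_kv_caches further down). A minimal sketch of that flow, using toy stand-ins for VllmConfig and CacheConfig rather than vLLM's real classes:

from dataclasses import dataclass

@dataclass
class CacheConfig:
    block_size: int  # tokens per KV-cache block

@dataclass
class VllmConfig:
    cache_config: CacheConfig

class Worker:
    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        # Cache the configured block size up front; the KV-cache
        # registration step can then assert the tensor shape agrees.
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id

worker = Worker(VllmConfig(CacheConfig(block_size=16)), engine_id="0")
assert worker.block_size == 16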
@@ -454,41 +459,39 @@ def _nixl_handshake(self, host: str, port: int):
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
 
-        def handshake(sock, rank: int) -> NixlAgentMetadata:
+        def handshake(path: str, rank: int) -> NixlAgentMetadata:
             # Send query for the request.
-            sock.send(GET_META_MSG)
-            metadata_bytes = sock.recv()
-            decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
-            metadata = decoder.decode(metadata_bytes)
-            got_metadata_time = time.perf_counter()
-
-            # Register Remote agent.
-            self.add_remote_agent(metadata, rank)
-            setup_agent_time = time.perf_counter()
-
-            logger.debug("NIXL handshake: get metadata took: %s",
-                         got_metadata_time - start_time)
-            logger.debug("NIXL handshake: add agent took: %s",
-                         setup_agent_time - got_metadata_time)
-            return metadata
+            with zmq_ctx(zmq.REQ, path) as sock:
+                sock.send(GET_META_MSG)
+                metadata_bytes = sock.recv()
+                decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
+                metadata = decoder.decode(metadata_bytes)
+                got_metadata_time = time.perf_counter()
+
+                # Register Remote agent.
+                self.add_remote_agent(metadata, rank)
+                setup_agent_time = time.perf_counter()
+
+                logger.debug("NIXL handshake: get metadata took: %s",
+                             got_metadata_time - start_time)
+                logger.debug("NIXL handshake: add agent took: %s",
+                             setup_agent_time - got_metadata_time)
+                return metadata
 
         # Handshake with remote agent-rank0 first to get the tp_size of remote
         path = f"tcp://{host}:{port}"
         logger.debug("Querying master rank metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
-            metadata = handshake(sock, 0)
+        metadata = handshake(path, 0)
 
         # Handshake only with the other TP remote the current local rank will
         # pull from. With homogeneous TP it happens to be the same rank_i.
-        d_workers_per_p_worker = self._tp_size[
-            self.engine_id] // metadata.tp_size
-        p_remote_rank = self.rank // d_workers_per_p_worker
+        tp_rate = self._tp_size[self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // tp_rate
         if p_remote_rank > 0:
             path = f"tcp://{host}:{port + p_remote_rank}"
             logger.debug("Querying metadata on path: %s at remote rank %s",
                          path, p_remote_rank)
-            with zmq_ctx(zmq.REQ, path) as sock:
-                metadata = handshake(sock, p_remote_rank)
+            _ = handshake(path, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -503,17 +506,17 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, kv_latent_dim = block_shape
+            block_size, kv_latent_dim = block_shape
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
             # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
-            self.block_size, n_kv_heads, head_dim = block_shape
+            block_size, n_kv_heads, head_dim = block_shape
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
-
+        assert block_size == self.block_size
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
         # block size in bytes
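Note: both branches now unpack into a local block_size and assert it matches the configured self.block_size; only the per-token slot size differs between the MLA and standard layouts. A hedged sketch of that slot-size arithmetic with toy shapes (not vLLM's actual allocator):

import torch

def slot_size_bytes(kv_cache: torch.Tensor, use_mla: bool) -> int:
    # Mirrors the shape handling above:
    # MLA cache:      [num_blocks, block_size, latent_dim]
    # standard cache: [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
    elem_size = kv_cache.element_size()
    if use_mla:
        _num_blocks, _block_size, latent_dim = kv_cache.shape
        return elem_size * latent_dim
    _kv, _num_blocks, _block_size, n_kv_heads, head_dim = kv_cache.shape
    return elem_size * n_kv_heads * head_dim

# fp16 standard layout, 8 KV heads of dim 128 -> 2 * 8 * 128 = 2048 bytes/slot.
cache = torch.empty(2, 4, 16, 8, 128, dtype=torch.float16)
assert slot_size_bytes(cache, use_mla=False) == 2048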
