1 change: 1 addition & 0 deletions vllm/config.py
@@ -3442,6 +3442,7 @@ class KVTransferConfig(BaseModel):
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}


def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
195 changes: 140 additions & 55 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -14,7 +14,7 @@
from typing_extensions import Optional

from vllm import envs
from vllm.config import VllmConfig
from vllm.config import VllmConfig, KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (
@@ -52,6 +52,8 @@
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
tp_size: int
block_len: int


@dataclass
Expand Down Expand Up @@ -98,7 +100,7 @@
self.connector_worker: Optional[NixlConnectorWorker] = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = NixlConnectorWorker(str(self.engine_id))
self.connector_worker = NixlConnectorWorker(
    str(self.engine_id), vllm_config.kv_transfer_config)

############################################################
# Scheduler Side Methods
@@ -214,7 +216,7 @@
class NixlConnectorWorker:
"""Implementation of Worker side methods"""

def __init__(self, engine_id: str):
def __init__(self, engine_id: str, kv_config: KVTransferConfig):
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
@@ -223,32 +225,36 @@

# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> agent_name.
self._remote_agents: dict[str, str] = {}
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)

# Metadata.
self.engine_id = engine_id
self.rank = get_tensor_model_parallel_rank()
self.world_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()

# Remote tracking data structures only contain one entry for our own
# TP group: engine_id-self.rank.
# KV Caches and nixl tracking data.

self.kv_caches: dict[str, torch.Tensor] = {}

# Map of engine_id -> kv_caches_base_addr
self.kv_caches_base_addr: dict[str, list[int]] = {}
# Map of engine_id -> kv_caches_base_addr. For TP case, each local
# rank will still only pull from a single remote TP worker.
self.kv_caches_base_addr: dict[str, list[int]] = dict()

# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
self.num_regions = 0

# nixl_prepped_dlist_handle (int).
self.src_xfer_side_handle: int = 0
# nixl_prepped_dlist_handle. Different dst TP sizes require preparing
# xfer layout differently.
self.src_xfer_side_handle: dict[int, int] = {}
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[str, int] = {}
self.dst_xfer_side_handles: dict[str, int] = dict()

# Map of engine_id -> num_blocks.
self.dst_num_blocks: dict[str, int] = {}
# Map of engine_id -> num_blocks. Remote TP ranks will have the same
# number of blocks.
self.dst_num_blocks: dict[str, int] = dict()
self._registered_descs: list[Any] = []

# In progress transfers.
@@ -266,6 +272,8 @@

# Background thread for establishing new connections.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None

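# Map of engine_id -> tp_size learned from handshakes, seeded with this
# worker's own TP world size.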
self._tp_size = {self.engine_id: self.world_size}

@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
@@ -278,7 +286,8 @@
# move this into the scheduler rather than worker, since
# each rank needs the metadata of all other ranks (whereas
# in this setup, each rank only gets one other rank's meta.
# TODO iterate over all ranks to handshake with M.
# Can we get M from config?

encoder = msgspec.msgpack.Encoder()
encoded_data = encoder.encode(metadata)
size_in_bytes = len(encoded_data)
Expand All @@ -290,6 +299,7 @@
# NOTE(rob): we need each rank to have a unique port. This
# hack to keeps us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
# TODO get rank port util
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
path = f"tcp://{host}:{port}"
logger.debug("Starting listening on path: %s", path)
@@ -309,9 +319,8 @@
# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
path = f"tcp://{host}:{port + self.rank}"
logger.debug("Querying metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:

def handshake(sock, rank: int) -> NixlAgentMetadata:
# Send query for the request.
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
@@ -320,13 +329,33 @@
got_metadata_time = time.perf_counter()

# Register Remote agent.
self.add_remote_agent(metadata)
self.add_remote_agent(metadata, rank)
setup_agent_time = time.perf_counter()

logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
setup_agent_time - got_metadata_time)
return metadata

# Handshake with the remote rank-0 agent first to get the remote tp_size.
path = f"tcp://{host}:{port}"
logger.debug("Querying master rank metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:
metadata = handshake(sock, 0)

# TODO should we skip this if remote world_size == world_size
# (homogeneous)?

# Handshake only with the remote TP rank that the current local rank
# will pull from. With homogeneous TP it happens to be the same rank_i.
p_remote_rank = self.rank % metadata.tp_size
if p_remote_rank > 0:
path = f"tcp://{host}:{port + p_remote_rank}"
logger.debug("Querying metadata on path: %s at remote rank %s", path, p_remote_rank)
with zmq_ctx(zmq.REQ, path) as sock:
metadata = handshake(sock, p_remote_rank)



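Aside, not part of the diff: the handshake above first contacts the remote rank-0 agent to learn the remote TP size, then talks only to the one remote rank this local rank will pull from, via self.rank % metadata.tp_size. A minimal standalone sketch of that mapping, using assumed example TP sizes:

def remote_rank_to_pull_from(local_rank: int, remote_tp_size: int) -> int:
    # Mirrors p_remote_rank = self.rank % metadata.tp_size above.
    return local_rank % remote_tp_size

# Homogeneous TP (decode TP=4, prefill TP=4): each rank pulls from the same index.
assert [remote_rank_to_pull_from(r, 4) for r in range(4)] == [0, 1, 2, 3]
# Heterogeneous TP (decode TP=4, prefill TP=2): decode ranks 0/2 pull from
# prefill rank 0, decode ranks 1/3 pull from prefill rank 1.
assert [remote_rank_to_pull_from(r, 2) for r in range(4)] == [0, 1, 0, 1]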
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@@ -340,6 +369,7 @@
# MLA case.
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
# TODO does this include tp dependent size?
Review comment from @tlrmchlsmth (Member), May 9, 2025:
For MLA we replicate the KV cache across TP ranks, so in this case the prefiller would need to send the same blocks to all decoders. This is the same when TP size is greater than the num kv heads.

block_shape = first_kv_cache.shape[-block_rank:]
else:
# [2 (k and v), num_blocks, ...]
@@ -350,6 +380,7 @@
# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
self.block_len = kv_elem_size * math.prod(block_shape)
print(f"\n\n{self.block_len=}\n\n")

logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
first_kv_cache.shape)
@@ -377,6 +408,9 @@
kv_caches_base_addr.append(base_addr)
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
self.num_regions = len(caches_data)
print("************************BLOCKS SETUP")
print(f"Number of blocks {len(kv_caches_base_addr)=}\n")
print(f"{self.num_blocks=}, {self.block_len=}, {self.num_regions=}\n")

descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
@@ -391,6 +425,8 @@
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len
)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
@@ -401,49 +437,93 @@
self._nixl_handshake_listener_t.start()
ready_event.wait()

def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
def add_remote_agent(self,
                     nixl_agent_meta: NixlAgentMetadata,
                     remote_rank: int = 0):
# FIXME: another approach tried was loading half of every remote block
# instead of half of the blocks; it didn't seem to make much difference.
engine_id = nixl_agent_meta.engine_id
if engine_id in self._remote_agents:
# TODO re-evaluate refreshing for scaling/recovery

if (engine_id in self._remote_agents
        and remote_rank in self._remote_agents[engine_id]):
return

self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
if engine_id in self._tp_size:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
self._tp_size[engine_id] = nixl_agent_meta.tp_size
self._remote_agents[engine_id][remote_rank] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr

# TODO enforce tp sizes are exact multiples
d_workers_per_p_worker = (self._tp_size[self.engine_id] //
                          self._tp_size[engine_id])
assert d_workers_per_p_worker > 0, (
    "Decode TP cannot be smaller than prefill TP")
dst_num_blocks_per_local_rank = (nixl_agent_meta.num_blocks //
                                 d_workers_per_p_worker)

# Create src descs and xfer side handles.
if d_workers_per_p_worker not in self.src_xfer_side_handle:
blocks_data = []

for base_addr in self.kv_caches_base_addr[self.engine_id]:

for block_id in range(dst_num_blocks_per_local_rank):
block_offset = block_id * nixl_agent_meta.block_len
# (addr, len, device id)
# use the block size of the dst/P node to make sure regions match
blocks_data.append(
(base_addr + block_offset, nixl_agent_meta.block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle[d_workers_per_p_worker] = (
    self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs))

# Create dst descs and xfer side handles. TP workers have the same
# number of blocks. When D_TP > P_TP, P blocks are split between D
# workers, so we may record a fraction of the total num_blocks in P.
self.dst_num_blocks[engine_id] = dst_num_blocks_per_local_rank

blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)

# Create dst descs and xfer side handles.
self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
blocks_data = []
for base_addr in self.kv_caches_base_addr[engine_id]:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for dst engine %s and rank %s",
len(blocks_data), engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id], descs)
# With heterogeneous TP, prepare the descriptors by splitting the P KV
# cache into chunks of D worker's size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[KV_0 | KV_1] (contiguous view).
p_remote_rank = self.rank % nixl_agent_meta.tp_size
# Only register the remote's descriptor if current rank pulls from it
if p_remote_rank == remote_rank:
self.kv_caches_base_addr[
    engine_id] = nixl_agent_meta.kv_caches_base_addr

# TODO in case sizes aren't exactly divisible, we may want to swap
# self.block_len with meta.block_len // d_workers_per_p_worker
# (eg when dividing by 3) and handle final block. src_xfer too.
# assert nixl_agent_meta.block_len % self.block_len == 0

# Split the kv memory inside a nixl region to guarantee each local
# rank is pulling the kv cache of all layers of a remote worker.
# TODO what if the region_len of P and D don't match in size due to
# some TP overhead? Also, this assumes the memory utilized is the same
# on both sides.
rank_offset = (self.rank // nixl_agent_meta.tp_size *
               nixl_agent_meta.block_len * dst_num_blocks_per_local_rank)
print(f"Local Rank {self.rank} remote {remote_rank}: {rank_offset=}/ Remote region_len {nixl_agent_meta.num_blocks*nixl_agent_meta.block_len}\n\n")
print(f"{nixl_agent_meta.num_blocks=}, {dst_num_blocks_per_local_rank=}")
# DECODE TP2 || self.num_blocks=33769, self.block_len=16384, self.num_regions=56
# PREFILL TP1 || self.num_blocks=17371, self.block_len=32768, self.num_regions=56
# FIXME assume num_blocks and block_len are actually divisible and all
# is nice. This needs to be enforced (eg diff mem usage might break it).
for base_addr in nixl_agent_meta.kv_caches_base_addr:

base_addr += rank_offset
for block_id in range(dst_num_blocks_per_local_rank):
block_offset = block_id * nixl_agent_meta.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, nixl_agent_meta.block_len, self.rank))
logger.debug("Created %s blocks for dst engine %s with remote rank %s and local rank %s",
len(blocks_data), engine_id, remote_rank, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[engine_id] = (
    self.nixl_wrapper.prep_xfer_dlist(
        self._remote_agents[engine_id][remote_rank], descs))

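Aside, not part of the diff: the heterogeneous-TP layout built above reduces to a few integer divisions. A minimal sketch of the same arithmetic, using small hypothetical sizes (prefill TP=1, decode TP=2, 16 remote blocks of 4096 bytes):

local_tp_size = 2          # decode side, self._tp_size[self.engine_id]
remote_tp_size = 1         # prefill side, nixl_agent_meta.tp_size
remote_num_blocks = 16     # nixl_agent_meta.num_blocks (hypothetical)
remote_block_len = 4096    # nixl_agent_meta.block_len in bytes (hypothetical)

d_workers_per_p_worker = local_tp_size // remote_tp_size                      # 2
dst_num_blocks_per_local_rank = remote_num_blocks // d_workers_per_p_worker   # 8

for local_rank in range(local_tp_size):
    # Mirrors the rank_offset computation above.
    rank_offset = (local_rank // remote_tp_size *
                   remote_block_len * dst_num_blocks_per_local_rank)
    # local_rank 0 -> offset 0 (first half of the remote region),
    # local_rank 1 -> offset 8 * 4096 = 32768 (second half).
    print(local_rank, rank_offset)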
def get_finished(self) -> tuple[set[str], set[str]]:
"""
@@ -580,6 +660,7 @@
request_id: str,
):
# NOTE(rob): this takes ~2s. We need to get this off the hotpath.
# TODO check remote_rank in here too?
if dst_engine_id not in self._remote_agents:
self._nixl_handshake(remote_host, remote_port)

@@ -595,9 +676,13 @@

assert len(local_block_ids) > 0
assert len(local_block_ids) == len(remote_block_ids)
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
# corresponding rank. With heterogeneous TP, fixing D>P, the D tp
# workers will issue xfers to parts of the P worker remote kv caches.

# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
d_workers_per_p_worker = (self._tp_size[self.engine_id] //
                          self._tp_size[dst_engine_id])
local_xfer_side_handle = self.src_xfer_side_handle[
    d_workers_per_p_worker]
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

# Get descs ids.
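Aside, not part of the diff: in _read_blocks the local prepared-descriptor handle is now looked up by the decode-to-prefill TP ratio, while the remote handle stays keyed by the remote engine id. A tiny sketch of that lookup, with made-up handle values and engine id:

src_xfer_side_handle = {2: 101}                  # keyed by D_TP // P_TP (hypothetical)
dst_xfer_side_handles = {"prefill-engine": 202}  # keyed by engine id (hypothetical)

local_tp_size, remote_tp_size = 2, 1
d_workers_per_p_worker = local_tp_size // remote_tp_size                 # 2
local_xfer_side_handle = src_xfer_side_handle[d_workers_per_p_worker]    # 101
remote_xfer_side_handle = dst_xfer_side_handles["prefill-engine"]        # 202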