
Commit 20ca3aa

split kv_cache along head dim

fix descr indexing; change remote worker selection indexing; test ptp2-dtp4

Signed-off-by: nicklucche <[email protected]>

1 parent: d981396


vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

1 file changed: 144 additions & 53 deletions
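The change teaches the NIXL connector to serve decode (D) workers whose tensor-parallel size is larger than the prefill (P) instance's: each D rank pulls only its slice of the kv heads from a single P rank. A minimal sketch of that rank-to-slice mapping, assuming the PTP2-DTP4 layout named in the commit message (all names and numbers illustrative):

# Hypothetical PTP2-DTP4 mapping: each decode (D) rank pulls its share of
# kv heads from exactly one prefill (P) rank.
prefill_tp = 2      # remote (P) tensor-parallel size
decode_tp = 4       # local (D) tensor-parallel size
d_workers_per_p_worker = decode_tp // prefill_tp    # 2

for d_rank in range(decode_tp):
    p_remote_rank = d_rank // d_workers_per_p_worker   # which P rank to pull from
    head_slice = d_rank % d_workers_per_p_worker        # which slice of P's kv heads
    print(f"D rank {d_rank} -> P rank {p_remote_rank}, kv-head slice {head_slice}")

# D rank 0 -> P rank 0, kv-head slice 0
# D rank 1 -> P rank 0, kv-head slice 1
# D rank 2 -> P rank 1, kv-head slice 0
# D rank 3 -> P rank 1, kv-head slice 1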
@@ -53,6 +53,8 @@ class NixlAgentMetadata(
     agent_metadata: bytes
     kv_caches_base_addr: list[int]
     num_blocks: int
+    tp_size: int
+    block_len: int
 
 
 @dataclass
@@ -318,8 +320,8 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
 
         # Agent.
         self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
-        # Map of engine_id -> agent_name.
-        self._remote_agents: dict[str, str] = {}
+        # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
+        self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
 
         # Metadata.
         self.engine_id = engine_id
@@ -330,21 +332,24 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # KV Caches and nixl tracking data.
         self.kv_caches: dict[str, torch.Tensor] = {}
 
-        # Map of engine_id -> kv_caches_base_addr
-        self.kv_caches_base_addr: dict[str, list[int]] = {}
+        # Map of engine_id -> kv_caches_base_addr. For TP case, each local
+        # rank will still only pull from a single remote TP worker.
+        self.kv_caches_base_addr: dict[str, list[int]] = dict()
 
         # Number of NIXL regions. Currently one region per cache
         # (so 1 per layer for MLA, otherwise 2 per layer)
         self.num_regions = 0
         self.num_layers = 0
 
-        # nixl_prepped_dlist_handle (int).
-        self.src_xfer_side_handle: int = 0
+        # nixl_prepped_dlist_handle. Different dst TP sizes require preparing
+        # xfer layout differently.
+        self.src_xfer_side_handle: dict[int, int] = dict()
         # Map of engine_id -> nixl_prepped_dlist_handle (int)].
-        self.dst_xfer_side_handles: dict[str, int] = {}
+        self.dst_xfer_side_handles: dict[str, int] = dict()
 
-        # Map of engine_id -> num_blocks.
-        self.dst_num_blocks: dict[str, int] = {}
+        # Map of engine_id -> num_blocks. Remote TP ranks will have the same
+        # number of blocks.
+        self.dst_num_blocks: dict[str, int] = dict()
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
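For illustration only (engine ids and handle values are invented), the tracking state of a decode worker at local rank 3 with DTP=4, after handshaking with a PTP=2 prefill engine, would be shaped roughly like this:

# Invented values; only the shapes of the new dicts are meaningful here.
_tp_size = {"decode-engine": 4, "prefill-engine": 2}
# engine_id -> {remote_rank: agent_name}. Rank 0 is always contacted first,
# then the P rank this worker actually pulls from (rank 3 // 2 = 1).
_remote_agents = {"prefill-engine": {0: "prefill-agent-0", 1: "prefill-agent-1"}}
# One local prepped-dlist handle per decode/prefill TP ratio (4 // 2 = 2).
src_xfer_side_handle = {2: 7}
# One remote prepped-dlist handle per remote engine id.
dst_xfer_side_handles = {"prefill-engine": 11}
# All remote TP ranks expose the same number of blocks.
dst_num_blocks = {"prefill-engine": 8192}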
@@ -370,6 +375,7 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         # Optimization for models with local attention (Llama 4)
         # List of block window sizes for each layer for local attention
         self.block_window_per_layer: list[Optional[int]] = []
+        self._tp_size = {self.engine_id: self.world_size}
 
     @staticmethod
     def _nixl_handshake_listener(metadata: NixlAgentMetadata,
@@ -410,12 +416,12 @@ def _nixl_handshake(self, host: str, port: int):
         """Do a NIXL handshake with a remote instance."""
 
         start_time = time.perf_counter()
+
         # NOTE(rob): we need each rank to have a unique port. This is
         # a hack to keep us moving. We will switch when moving to etcd
         # or where we have a single ZMQ socket in the scheduler.
-        path = make_zmq_path("tcp", host, port + self.rank)
-        logger.debug("Querying metadata on path: %s", path)
-        with zmq_ctx(zmq.REQ, path) as sock:
+
+        def handshake(sock, rank: int) -> NixlAgentMetadata:
             # Send query for the request.
             sock.send(GET_META_MSG)
             metadata_bytes = sock.recv()
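The next hunk decides which remote rank (and hence which port) each local rank handshakes with. As a quick illustration of that selection under the PTP2-DTP4 test configuration (host and port are placeholders):

# Which ZMQ paths a decode rank queries, assuming a remote tp_size of 2
# (the PTP2-DTP4 configuration from the commit message). Address is made up.
host, base_port = "10.0.0.1", 5600
local_tp_size = 4

def remote_paths(local_rank: int, remote_tp_size: int) -> list[str]:
    # Rank 0 of the remote is always queried first to learn its tp_size;
    # a second handshake is only needed when this rank maps to another P rank.
    d_workers_per_p_worker = local_tp_size // remote_tp_size
    p_remote_rank = local_rank // d_workers_per_p_worker
    paths = [f"tcp://{host}:{base_port}"]
    if p_remote_rank > 0:
        paths.append(f"tcp://{host}:{base_port + p_remote_rank}")
    return paths

for rank in range(local_tp_size):
    print(rank, remote_paths(rank, remote_tp_size=2))
# 0 ['tcp://10.0.0.1:5600']
# 1 ['tcp://10.0.0.1:5600']
# 2 ['tcp://10.0.0.1:5600', 'tcp://10.0.0.1:5601']
# 3 ['tcp://10.0.0.1:5600', 'tcp://10.0.0.1:5601']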
@@ -424,13 +430,32 @@ def _nixl_handshake(self, host: str, port: int):
             got_metadata_time = time.perf_counter()
 
             # Register Remote agent.
-            self.add_remote_agent(metadata)
+            self.add_remote_agent(metadata, rank)
             setup_agent_time = time.perf_counter()
 
             logger.debug("NIXL handshake: get metadata took: %s",
                          got_metadata_time - start_time)
             logger.debug("NIXL handshake: add agent took: %s",
                          setup_agent_time - got_metadata_time)
+            return metadata
+
+        # Handshake with remote agent-rank0 first to get the tp_size of remote
+        path = f"tcp://{host}:{port}"
+        logger.debug("Querying master rank metadata on path: %s", path)
+        with zmq_ctx(zmq.REQ, path) as sock:
+            metadata = handshake(sock, 0)
+
+        # Handshake only with the other TP remote the current local rank will
+        # pull from. With homogeneous TP it happens to be the same rank_i.
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // metadata.tp_size
+        p_remote_rank = self.rank // d_workers_per_p_worker
+        if p_remote_rank > 0:
+            path = f"tcp://{host}:{port + p_remote_rank}"
+            logger.debug("Querying metadata on path: %s at remote rank %s",
+                         path, p_remote_rank)
+            with zmq_ctx(zmq.REQ, path) as sock:
+                metadata = handshake(sock, p_remote_rank)
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """Register the KV Cache data in nixl."""
@@ -445,14 +470,20 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 2  # [block_size, latent_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            self.block_size, kv_latent_dim = block_shape
+            self.kv_dim = kv_elem_size * kv_latent_dim
         else:
-            # [2 (k and v), num_blocks, ...]
+            # [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
             self.num_blocks = first_kv_cache.shape[1]
             block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            self.block_size, n_kv_heads, head_dim = block_shape
+            # head size in bytes.
+            self.kv_dim = kv_elem_size * n_kv_heads * head_dim
 
         # TODO(tms): self.block_len needs to be per-layer for sliding window,
         # hybrid attn, etc
+        # block size in bytes
         self.block_len = kv_elem_size * math.prod(block_shape)
 
         logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
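To make the new per-token-slot bookkeeping concrete, a worked example with made-up cache dimensions (fp16, 8 kv heads of size 128, block_size 16); kv_dim is the byte size of one token slot's kv heads and block_len is the whole block:

import math

# Made-up non-MLA cache layout: [2, num_blocks, block_size, kv_heads, head_dim].
kv_elem_size = 2                                   # fp16
block_size, n_kv_heads, head_dim = 16, 8, 128
block_shape = (block_size, n_kv_heads, head_dim)

kv_dim = kv_elem_size * n_kv_heads * head_dim      # bytes per token slot: 2048
block_len = kv_elem_size * math.prod(block_shape)  # bytes per block: 32768
assert block_len == block_size * kv_dim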
@@ -507,7 +538,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         logger.debug("Registering descs: %s", caches_data)
         self.nixl_wrapper.register_memory(descs)
         logger.debug("Done registering descs")
-
         self._registered_descs.append(descs)
 
         # After KV Caches registered, listen for new connections.
@@ -516,7 +546,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             agent_metadata=self.nixl_wrapper.get_agent_metadata(),
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
-        )
+            tp_size=self.world_size,
+            block_len=self.block_len)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -526,49 +557,97 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         self._nixl_handshake_listener_t.start()
         ready_event.wait()
 
-    def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
+    def add_remote_agent(self,
+                         nixl_agent_meta: NixlAgentMetadata,
+                         remote_rank: int = 0):
         engine_id = nixl_agent_meta.engine_id
-        if engine_id in self._remote_agents:
+        # TODO re-evaluate refreshing for scaling/recovery
+        if (engine_id in self._remote_agents and \
+                remote_rank in self._remote_agents[engine_id]):
             return
 
-        self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
-            nixl_agent_meta.agent_metadata)
-        self.kv_caches_base_addr[
-            engine_id] = nixl_agent_meta.kv_caches_base_addr
+        if engine_id in self._tp_size:
+            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
+        self._tp_size[engine_id] = nixl_agent_meta.tp_size
+        self._remote_agents[engine_id][
+            remote_rank] = self.nixl_wrapper.add_remote_agent(
+                nixl_agent_meta.agent_metadata)
+
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // self._tp_size[engine_id]
+        assert d_workers_per_p_worker > 0, "Decode TP cannot be smaller than"
+        " prefill TP"
+
+        # TODO we should also check hidden_dim and kv precision, they must match
+        remote_block_size = nixl_agent_meta.block_len / (
+            self.kv_dim * d_workers_per_p_worker)
+        assert self.block_size == remote_block_size, "Remote P worker with "
+        "different block size is not supported"
 
         # Create src descs and xfer side handles.
-        blocks_data = []
-        for base_addr in self.kv_caches_base_addr[self.engine_id]:
-            for block_id in range(self.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for src engine %s and rank %s",
-                     len(blocks_data), self.engine_id, self.rank)
-
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
-            "NIXL_INIT_AGENT", descs)
-
-        # Create dst descs and xfer side handles.
+        if d_workers_per_p_worker not in self.src_xfer_side_handle:
+            blocks_data = []
+            for base_addr in self.kv_caches_base_addr[self.engine_id]:
+                # NOTE With heter-TP, more blocks are prepared than what are
+                # needed as self.num_blocks >= nixl_agent_meta.num_blocks. We
+                # could create fewer, but then _get_block_descs_ids needs to
+                # select agent_meta.num_blocks instead of self.num_blocks for
+                # local descr, and that makes handling regular flow less clean.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    for b in range(self.block_size):
+                        head_offset = b * self.kv_dim
+                        addr = base_addr + block_offset + head_offset
+                        # (addr, len, device id)
+                        blocks_data.append((addr, self.kv_dim, self.rank))
+            logger.debug("Created %s blocks for src engine %s and rank %s",
+                         len(blocks_data), self.engine_id, self.rank)
+
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            # NIXL_INIT_AGENT to be used for preparations of local descs.
+            self.src_xfer_side_handle[
+                d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist(
+                    "NIXL_INIT_AGENT", descs)
+
+        # Create dst descs and xfer side handles. TP workers have same #blocks.
+        if engine_id in self.dst_num_blocks:
+            assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks
+
         self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks
+
         blocks_data = []
-        for base_addr in self.kv_caches_base_addr[engine_id]:
-            for block_id in range(nixl_agent_meta.num_blocks):
-                block_offset = block_id * self.block_len
-                # (addr, len, device id)
-                blocks_data.append(
-                    (base_addr + block_offset, self.block_len, self.rank))
-        logger.debug("Created %s blocks for dst engine %s and rank %s",
-                     len(blocks_data), engine_id, self.rank)
-
-        # Register with NIXL.
-        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
-        self.dst_xfer_side_handles[
-            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
-                self._remote_agents[engine_id], descs)
+        # With homogeneous TP, D pulls the whole kv cache from corresponding
+        # rank. With heterogeneous TP, prepare the descriptors by splitting the
+        # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
+        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
+        p_remote_rank = self.rank // d_workers_per_p_worker
+        # Only register the remote's descriptors if current rank pulls from it.
+        if p_remote_rank == remote_rank:
+            self.kv_caches_base_addr[
+                engine_id] = nixl_agent_meta.kv_caches_base_addr
+            rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim
+            # Register all remote blocks, but only the corresponding kv heads.
+            for base_addr in nixl_agent_meta.kv_caches_base_addr:
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    for b in range(self.block_size):
+                        # Remote kv_dim = local kv_dim * d_workers_per_p_worker
+                        head_offset = b * self.kv_dim * d_workers_per_p_worker
+                        addr = base_addr + block_offset + head_offset
+                        # (addr, len, device id)
+                        blocks_data.append(
+                            (addr + rank_offset, self.kv_dim, remote_rank))
+            logger.debug(
+                "Created %s blocks for dst engine %s with remote rank %s and " \
+                "local rank %s",
+                len(blocks_data), engine_id, remote_rank, self.rank)
+
+            # Register with NIXL.
+            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
+            self.dst_xfer_side_handles[
+                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
+                    self._remote_agents[engine_id][remote_rank], descs)
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
@@ -733,6 +812,16 @@ def _read_blocks(
 
         # Get side handles.
         local_xfer_side_handle = self.src_xfer_side_handle
+
+        # NOTE (nicolo) With homogeneous TP, each TP worker loads KV from
+        # corresponding rank. With heterogeneous TP, fixing D>P, the D tp
+        # workers will issue xfers to parts of the P worker remote kv caches.
+
+        # Get side handles.
+        d_workers_per_p_worker = self._tp_size[
+            self.engine_id] // self._tp_size[dst_engine_id]
+        local_xfer_side_handle = self.src_xfer_side_handle[
+            d_workers_per_p_worker]
         remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]
 
         # Get descs ids.
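The remote descriptors prepared in add_remote_agent above are what these side handles index into. The address arithmetic can be checked in isolation; this sketch replays it with toy numbers for local rank 1 in a PTP1-DTP2 setup (so d_workers_per_p_worker = 2 and the remote block_len is twice the local per-slot size):

# Toy replay of the dst descriptor layout (all sizes invented).
base_addr = 0x1000
num_blocks, block_size = 2, 4
local_kv_dim = 256                   # bytes per token slot on the D side
d_workers_per_p_worker = 2
remote_block_len = block_size * local_kv_dim * d_workers_per_p_worker

local_rank = 1
# Offset selecting this rank's half of the remote kv heads.
rank_offset = local_rank % d_workers_per_p_worker * local_kv_dim

blocks_data = []
for block_id in range(num_blocks):
    block_offset = block_id * remote_block_len
    for b in range(block_size):
        # Remote kv_dim = local kv_dim * d_workers_per_p_worker.
        head_offset = b * local_kv_dim * d_workers_per_p_worker
        addr = base_addr + block_offset + head_offset
        blocks_data.append((addr + rank_offset, local_kv_dim))

# Each descriptor covers only this rank's share of one token slot's kv heads.
print(blocks_data[:2])   # [(4352, 256), (4864, 256)]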
@@ -818,7 +907,9 @@ def _get_block_descs_ids(self,
         descs_ids: list[int] = []
         for reg_id in region_ids:
             for block_id in block_ids:
-                descs_ids.append(reg_id * num_blocks + block_id)
+                for kv_block in range(self.block_size):
+                    descs_ids.append(reg_id * num_blocks * self.block_size +
+                                     block_id * self.block_size + kv_block)
         return descs_ids
 
 
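Since each block now contributes block_size descriptors (one per token slot) instead of one, the flattened descriptor index is scaled by block_size. A quick self-contained check of the layout with toy sizes:

# Descriptor ids are laid out as [region][block][token slot], so every block
# owns `block_size` consecutive ids. Toy sizes for illustration.
num_blocks, block_size = 4, 2

def descs_ids(region_ids, block_ids):
    ids = []
    for reg_id in region_ids:
        for block_id in block_ids:
            for kv_block in range(block_size):
                ids.append(reg_id * num_blocks * block_size +
                           block_id * block_size + kv_block)
    return ids

print(descs_ids(region_ids=[0, 1], block_ids=[3]))   # [6, 7, 14, 15]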