
Commit f0c503f

[Nixl] Heterogeneous TP support FlashInfer (vllm-project#20189)

Signed-off-by: NickLucche <[email protected]>
1 parent: f38035c

File tree: 1 file changed (+53, -9 lines)

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 53 additions & 9 deletions
@@ -715,7 +715,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         # are non-contiguous (it's not locally guaranteed that they will be)
         # Disadvantage is that the encoded NixlAgentMetadata is now larger
         # (roughly 8KB vs 5KB).
-        # Conversely for FlashInfer, K and V are transferred in the same tensor
+        # Conversely for FlashInfer, K and V are registered in the same region
         # to better exploit the memory layout (ie num_blocks is the first dim).
         split_k_and_v = not (self.use_mla or self._use_pallas_v1
                              or self._use_flashinfer)
@@ -758,12 +758,21 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         assert tensor_size_bytes % self.num_blocks == 0
         self.block_len = tensor_size_bytes // self.num_blocks
         self.slot_size_bytes = self.block_len // self.block_size
+        self.device_kv_caches = kv_caches
+        self.dst_num_blocks[self.engine_id] = self.num_blocks
         if self._use_flashinfer:
             assert self.slot_size_bytes % 2 == 0
             self.slot_size_bytes /= 2
-        self.device_kv_caches = kv_caches
-        self.dst_num_blocks[self.engine_id] = self.num_blocks
 
+            # NOTE (NickLucche) When FlashInfer is used, memory is registered
+            # with joint KV for each block. This minimizes the overhead in
+            # registerMem, allowing faster descs queries. In order to split
+            # on the kv_heads dim as required by heterogeneous TP, one must
+            # be able to index K/V separately. Hence we double the number of
+            # 'virtual' regions here and halve `block_len` below.
+            self.num_regions *= 2
+
         kv_block_len = self.get_backend_aware_kv_block_len()
         # Register local/src descr for NIXL xfer.
         blocks_data = []
         for base_addr in seen_base_addresses:
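
For readers following the region math, here is a minimal standalone sketch of the bookkeeping this hunk sets up, using invented fp16 toy sizes rather than the connector's real configuration:

# Toy sizes only; FlashInfer keeps K and V jointly per block, e.g. with a
# (num_blocks, 2, block_size, num_kv_heads, head_dim) cache layout.
block_size = 16        # tokens per KV cache block
num_kv_heads = 8
head_dim = 128
elem_size = 2          # bytes per element (fp16)

block_len = 2 * block_size * num_kv_heads * head_dim * elem_size  # 65536 bytes, joint KV
slot_size_bytes = block_len // block_size // 2                    # 2048 bytes per token, per K or V

num_regions = 1        # one joint KV region registered with NIXL per layer
use_flashinfer = True
if use_flashinfer:
    # Double the 'virtual' regions so K and V can be indexed separately for
    # heterogeneous TP, and describe each with half the block length.
    num_regions *= 2
kv_block_len = block_len // 2 if use_flashinfer else block_len    # 32768 bytes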
@@ -776,8 +785,18 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
                 block_offset = block_id * self.block_len
                 addr = base_addr + block_offset
                 # (addr, len, device id)
-                # TODO: does device_id matter to DRAM?
-                blocks_data.append((addr, self.block_len, self.tp_rank))
+                blocks_data.append((addr, kv_block_len, self.tp_rank))
+
+            if self._use_flashinfer:
+                # Separate and interleave K/V regions to maintain the same
+                # descs ordering. This is needed for selecting contiguous heads
+                # when split across TP ranks.
+                for block_id in range(self.num_blocks):
+                    block_offset = block_id * self.block_len
+                    addr = base_addr + block_offset
+                    # Register addresses for V cache (K registered first).
+                    v_addr = addr + kv_block_len
+                    blocks_data.append((v_addr, kv_block_len, self.tp_rank))
         logger.debug("Created %s blocks for src engine %s and rank %s",
                      len(blocks_data), self.engine_id, self.tp_rank)
 
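To make the "separate and interleave" ordering concrete, here is a hedged standalone sketch of how the source descriptor list ends up laid out; the base address, block count and sizes are invented for illustration:

# Per registered region: K descriptors for all blocks come first, then V
# descriptors, so the K and V halves keep the same per-block ordering.
base_addrs = [0x10000000]        # one KV cache region (toy address)
num_blocks = 4
block_len = 65536                # joint KV block, as in the sketch above
kv_block_len = block_len // 2    # K-only (or V-only) half
tp_rank = 0

blocks_data = []
for base_addr in base_addrs:
    # K half of every block in this region.
    for block_id in range(num_blocks):
        addr = base_addr + block_id * block_len
        blocks_data.append((addr, kv_block_len, tp_rank))
    # V half of every block, offset by kv_block_len inside the joint block.
    for block_id in range(num_blocks):
        addr = base_addr + block_id * block_len
        blocks_data.append((addr + kv_block_len, kv_block_len, tp_rank))

assert len(blocks_data) == 2 * num_blocks   # K0..K3 followed by V0..V3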
@@ -903,7 +922,7 @@ def add_remote_agent(self,
             remote_block_size = nixl_agent_meta.block_len // (
                 self.slot_size_bytes * tp_ratio)
             if self._use_flashinfer:
-                # Account for joint KV in FlashInfer.
+                # With flashinfer, KV are sent in the same message.
                 remote_block_size //= 2
             if tp_ratio > 1:
                 # Heterogeneous TP expects same kv_cache_layout.
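
As a quick, hedged sanity check of the remote_block_size arithmetic (all values invented; tp_ratio is the decode-to-prefill TP ratio used throughout this method):

# Check that a remote (prefill) joint-KV block divides back down to the
# local tokens-per-block. Toy values only.
tp_ratio = 2                   # e.g. prefill TP=1, decode TP=2
local_block_size = 16          # tokens per block on this decode worker
local_slot_size_bytes = 2048   # bytes per token for one K or V slot locally
# The prefill rank holds tp_ratio x as many kv heads, with K and V joint.
remote_block_len = 2 * local_slot_size_bytes * tp_ratio * local_block_size

remote_block_size = remote_block_len // (local_slot_size_bytes * tp_ratio)
use_flashinfer = True
if use_flashinfer:
    remote_block_size //= 2    # joint KV -> halve to get tokens per block
assert remote_block_size == local_block_size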
@@ -929,10 +948,10 @@ def add_remote_agent(self,
         # rank. With heterogeneous TP, prepare the descriptors by splitting the
         # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
         # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
-        # Only register the remote's descriptors if current rank pulls from it.
         self.kv_caches_base_addr[
             engine_id] = nixl_agent_meta.kv_caches_base_addr
-        rank_offset = self.tp_rank % tp_ratio * self.block_len \
+        kv_block_len = self.get_backend_aware_kv_block_len()
+        rank_offset = self.tp_rank % tp_ratio * kv_block_len \
             if not (self.use_mla or is_kv_replicated) else 0
         # Register all remote blocks, but only the corresponding kv heads.
         for base_addr in nixl_agent_meta.kv_caches_base_addr:
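
The rank_offset expression is what selects this decode rank's slice of kv heads inside each remote block; a hedged worked example, reusing the toy sizes from the sketches above:

# Byte offset inside a remote (prefill) block read by each decode rank.
# Assumes heads are contiguous within the block, as the kv_cache_layout
# check earlier in this method requires. Values are illustrative.
tp_ratio = 2             # decode TP / prefill TP
kv_block_len = 32768     # local K-only (or V-only) block length in bytes

for tp_rank in range(tp_ratio):            # decode ranks sharing one prefill rank
    rank_offset = tp_rank % tp_ratio * kv_block_len
    print(tp_rank, rank_offset)            # 0 -> 0, 1 -> 32768
# Rank 0 reads the first kv_block_len bytes of the remote K (or V) half;
# rank 1 reads the next kv_block_len bytes, i.e. the second half of the heads.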
@@ -943,7 +962,16 @@ def add_remote_agent(self,
                 # self.block_len == remote_block_len//tp_ratio bytes.
                 addr = base_addr + block_offset + rank_offset
                 # (addr, len, device id)
-                blocks_data.append((addr, self.block_len, remote_tp_rank))
+                blocks_data.append((addr, kv_block_len, remote_tp_rank))
+
+            if self._use_flashinfer:
+                # With FlashInfer index V separately to allow head splitting.
+                for block_id in range(nixl_agent_meta.num_blocks):
+                    block_offset = block_id * nixl_agent_meta.block_len
+                    addr = base_addr + block_offset + rank_offset
+                    v_addr = addr + nixl_agent_meta.block_len // 2
+                    blocks_data.append((v_addr, kv_block_len, remote_tp_rank))
+
         logger.debug(
             "Created %s blocks for dst engine %s with remote rank %s and "
             "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
@@ -1249,6 +1277,22 @@ def _get_block_descs_ids(self,
             descs_ids.append(reg_id * num_blocks + block_id)
         return descs_ids
 
+    def get_backend_aware_kv_block_len(self):
+        """
+        Get the block length for one K/V element (K and V have the same size).
+
+        For FA and other backends, this is equal to the length of the whole
+        block, as K and V are in separate regions.
+        For FlashInfer, this is half the length of the whole block, as K and V
+        share the same region.
+        """
+        if self._use_flashinfer:
+            # For indexing only half (either just the K or V part).
+            block_len = self.block_len // 2
+        else:
+            block_len = self.block_len
+        return block_len
+
 
 @contextlib.contextmanager
 def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
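
For completeness, a minimal toy stand-in illustrating the new helper's contract (not the real connector class):

# Toy class mirroring the get_backend_aware_kv_block_len() semantics above.
class _ToyConnector:
    def __init__(self, block_len: int, use_flashinfer: bool):
        self.block_len = block_len
        self._use_flashinfer = use_flashinfer

    def get_backend_aware_kv_block_len(self) -> int:
        # FlashInfer registers K and V jointly per block, so one K/V element
        # is half the block; other backends keep K and V in separate regions.
        return self.block_len // 2 if self._use_flashinfer else self.block_len

assert _ToyConnector(65536, use_flashinfer=True).get_backend_aware_kv_block_len() == 32768
assert _ToyConnector(65536, use_flashinfer=False).get_backend_aware_kv_block_len() == 65536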
