Commit d88a978

NickLucche authored and BoyuanFeng committed
[CI][Nixl] Check kv cache layout during handshake (vllm-project#22745)
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Boyuan Feng <[email protected]>
1 parent 0991c73 commit d88a978

File tree: 2 files changed (+56, −3)
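To make the intent of the change concrete, below is a minimal, self-contained sketch (not code from the commit) of the handshake-time check it introduces: the decode-side worker compares the `kv_cache_layout` advertised by the remote agent against its own, but only when the tensor-parallel sizes differ (`tp_ratio > 1`). The `RemoteAgentMeta` class and `check_remote_agent` helper are illustrative stand-ins, not vLLM APIs; the real logic lives in the connector worker's `add_remote_agent` and uses `NixlAgentMetadata`.

```python
from dataclasses import dataclass


@dataclass
class RemoteAgentMeta:
    """Illustrative stand-in for the NixlAgentMetadata fields used here."""
    attn_backend_name: str
    kv_cache_layout: str  # e.g. "NHD" or "HND"


def check_remote_agent(local_backend: str, local_layout: str,
                       remote: RemoteAgentMeta, tp_ratio: int) -> None:
    """Mimics the handshake-time checks sketched by this commit."""
    # The attention backend must always match.
    assert remote.attn_backend_name == local_backend
    # The KV cache layout only has to match under heterogeneous TP
    # (tp_ratio > 1); with homogeneous TP whole blocks are transferred,
    # so the layout is not checked.
    if tp_ratio > 1:
        assert remote.kv_cache_layout == local_layout, (
            "Heterogeneous TP expects the same kv_cache_layout")


# Homogeneous TP (tp_ratio == 1): a layout mismatch is tolerated.
check_remote_agent("FLASH_ATTN", "NHD",
                   RemoteAgentMeta("FLASH_ATTN", "HND"), tp_ratio=1)

# Heterogeneous TP (tp_ratio == 2): the same call now raises AssertionError.
try:
    check_remote_agent("FLASH_ATTN", "NHD",
                       RemoteAgentMeta("FLASH_ATTN", "HND"), tp_ratio=2)
except AssertionError as exc:
    print("rejected:", exc)
```

Homogeneous TP moves whole blocks, so a layout mismatch is tolerated there; heterogeneous TP reads sub-slices of remote blocks and therefore requires matching layouts.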

tests/v1/kv_connector/unit/test_nixl_connector.py

Lines changed: 46 additions & 0 deletions

```diff
@@ -419,6 +419,52 @@ def test_concurrent_load_kv(
                 return
         raise TimeoutError("Took too long to complete async handshake.")
 
+    @patch(
+        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
+        FakeNixlWrapper)
+    def test_handshake_fails_on_kv_cache_layout_mismatch(self, dist_init):
+        """
+        Verify that adding a remote agent fails if kv_cache_layout differs.
+        This test is only relevant for heterogeneous TP.
+        """
+        vllm_config = create_vllm_config()
+
+        # Mock TP world size to 2 to force heterogeneous TP when
+        # remote_tp_size=1
+        with patch(
+                "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.get_tensor_model_parallel_world_size",  # noqa: E501
+                return_value=2):
+            # Initialize connector and worker (with fake NIXL wrapper)
+            connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
+            connector.connector_worker = FakeNixlConnectorWorker(
+                vllm_config, connector.engine_id, hand_shake_latency=0)
+            worker = connector.connector_worker
+
+            # Minimal local registration params used by add_remote_agent
+            worker.slot_size_bytes = 4096
+            worker.block_len = worker.slot_size_bytes * worker.block_size
+            worker.num_blocks = 1
+            worker.dst_num_blocks[worker.engine_id] = worker.num_blocks
+
+            # Metadata with different kv_cache_layout than local worker
+            mismatched_layout = "HND" if worker.kv_cache_layout != "HND" \
+                else "NHD"
+            meta = NixlAgentMetadata(
+                engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
+                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
+                kv_caches_base_addr=[0],
+                num_blocks=1,
+                block_len=worker.block_len,
+                attn_backend_name=worker.backend_name,
+                kv_cache_layout=mismatched_layout,
+            )
+
+            # We don't check layout for homogeneous TP and MLA for now, as the
+            # whole block is moved.
+            worker.add_remote_agent(meta, remote_tp_size=2)
+            with pytest.raises(AssertionError):
+                worker.add_remote_agent(meta, remote_tp_size=1)
+
 
 # NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
 # we put here is important. First run ray, it will clean up the resources, then
```
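To exercise the new case locally, it should be selectable on its own (assuming a working vLLM development environment) with `pytest tests/v1/kv_connector/unit/test_nixl_connector.py -k test_handshake_fails_on_kv_cache_layout_mismatch`. Note how the test mocks the local TP world size to 2: the first `add_remote_agent` call (`remote_tp_size=2`, homogeneous TP) is expected to succeed, while the second (`remote_tp_size=1`, heterogeneous TP) must trip the new layout assertion.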

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 10 additions & 3 deletions

```diff
@@ -30,6 +30,7 @@
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import make_zmq_path, make_zmq_socket
+from vllm.v1.attention.backends.utils import get_kv_cache_layout
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.request import RequestStatus
 
@@ -73,6 +74,7 @@ class NixlAgentMetadata(
     num_blocks: int
     block_len: int
     attn_backend_name: str
+    kv_cache_layout: str
 
 
 @dataclass
@@ -538,7 +540,9 @@ def __init__(self, vllm_config: VllmConfig, engine_id: str):
         attn_backend = backend_name_to_enum(self.backend_name)
         self._use_flashinfer = attn_backend == _Backend.FLASHINFER_VLLM_V1
         self._use_pallas_v1 = attn_backend == _Backend.PALLAS_VLLM_V1
+        self.kv_cache_layout = get_kv_cache_layout()
         logger.debug("Detected attention backend %s", self.backend_name)
+        logger.debug("Detected kv cache layout %s", self.kv_cache_layout)
 
         self._tp_size: dict[EngineId, int] = {self.engine_id: self.world_size}
         # With heterogeneous TP, P must wait for all assigned D TP workers to
@@ -839,7 +843,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
             num_blocks=self.num_blocks,
             block_len=self.block_len,
-            attn_backend_name=self.backend_name)
+            attn_backend_name=self.backend_name,
+            kv_cache_layout=self.kv_cache_layout)
         ready_event = threading.Event()
         self._nixl_handshake_listener_t = threading.Thread(
             target=self._nixl_handshake_listener,
@@ -900,8 +905,7 @@ def add_remote_agent(self,
             self._tp_size[engine_id] = remote_tp_size
         else:
             assert self._tp_size[engine_id] == remote_tp_size
-        # We may eventually enable this after asserting equality in cache
-        # layout and close outputs.
+        # TODO We may eventually want to skip enforcing the same attn backend.
         assert nixl_agent_meta.attn_backend_name == self.backend_name
 
         remote_agent_name = self.nixl_wrapper.add_remote_agent(
@@ -930,6 +934,9 @@ def add_remote_agent(self,
             if self._use_flashinfer:
                 # Account for joint KV in FlashInfer.
                 remote_block_size //= 2
+            if tp_ratio > 1:
+                # Heterogeneous TP expects same kv_cache_layout.
+                assert nixl_agent_meta.kv_cache_layout == self.kv_cache_layout
 
             assert nixl_agent_meta.block_len == self.block_len * tp_ratio, (
                 "Remote P worker KV layer cache must be of shape [2, N, "
```
