
Commit c6ac725

add KVTransferParams class back to base.py and change block handling in nixl connector
1 parent: 188ae95

File tree

2 files changed: +57 -8 lines

vllm/distributed/kv_transfer/kv_connector/v1/base.py

Lines changed: 39 additions & 7 deletions
@@ -32,7 +32,7 @@
 
 import enum
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Literal, Optional
 
 import torch
 
@@ -46,6 +46,12 @@
     from vllm.v1.core.kv_cache_manager import KVCacheBlocks
     from vllm.v1.request import Request
 
+# s_tensor_list, d_tensor_list, s_indices, d_indices, direction
+CopyBlocksOp = Callable[[
+    dict[str, torch.Tensor], dict[
+        str, torch.Tensor], list[int], list[int], Literal["h2d", "d2h"]
+], None]
+
 logger = init_logger(__name__)
 
 
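Note on the CopyBlocksOp alias added above: it names the signature of a host/device block-copy callback. As a rough, hypothetical sketch (not part of this commit; the per-block indexing on dim 0 and the pinned-host assumption are mine), a conforming function could look like:

from typing import Literal

import torch


def copy_kv_blocks(
    src_caches: dict[str, torch.Tensor],
    dst_caches: dict[str, torch.Tensor],
    src_block_ids: list[int],
    dst_block_ids: list[int],
    direction: Literal["h2d", "d2h"],
) -> None:
    """Hypothetical CopyBlocksOp: copy selected blocks for every layer."""
    # direction ("h2d" or "d2h") is informational in this sketch; the src/dst
    # tensors already determine where the data moves.
    for layer_name, src in src_caches.items():
        dst = dst_caches[layer_name]
        # Assumes each cache tensor is indexed by block id on dim 0.
        blocks = src[src_block_ids]
        # non_blocking copies are only safe if the host buffers are pinned.
        dst[dst_block_ids] = blocks.to(dst.device, non_blocking=True)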
@@ -60,7 +66,7 @@ class KVTransferParams:
     """
     Abstract KVTransferParams used to send KVTransfer
     parameters between instances of vLLM.
-
+
     Specific instances of KVConnector customize this
     method for serializing / deserializing msgs sent
     via the HTTP protocol.
@@ -72,7 +78,7 @@ def from_raw_dict(
                                Any]]) -> Optional["KVTransferParams"]:
         return None
 
-class KVConnectorMetadata:
+class KVConnectorMetadata(ABC):  # noqa: B024
     """
     Abstract Metadata used to communicate between the
     Scheduler KVConnector and Worker KVConnector.
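For context on the two classes touched above: concrete connectors subclass KVTransferParams and override from_raw_dict to parse the dict arriving over HTTP. A minimal hypothetical sketch (the field names are invented; only the base class and from_raw_dict come from this file):

from dataclasses import dataclass
from typing import Any, Optional

from vllm.distributed.kv_transfer.kv_connector.v1.base import KVTransferParams


@dataclass
class MyTransferParams(KVTransferParams):
    """Hypothetical per-request transfer parameters."""
    remote_engine_id: Optional[str] = None
    remote_block_ids: Optional[list[int]] = None

    @staticmethod
    def from_raw_dict(
            raw_dict: Optional[dict[str, Any]]
    ) -> Optional["MyTransferParams"]:
        # Tolerate a missing dict and ignore unknown keys sent over HTTP.
        if raw_dict is None:
            return None
        return MyTransferParams(
            remote_engine_id=raw_dict.get("remote_engine_id"),
            remote_block_ids=raw_dict.get("remote_block_ids"),
        )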
@@ -87,7 +93,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
         logger.warning(
             "Initializing KVConnectorBase_V1. This API is experimental and "
             "subject to change in the future as we iterate the design.")
-        self._connector_metadata = KVConnectorMetadata()
+        self._connector_metadata: Optional[KVConnectorMetadata] = None
         self._vllm_config = vllm_config
         self._role = role
 
@@ -118,7 +124,7 @@ def clear_connector_metadata(self) -> None:
         This function should be called by the model runner every time
         after the model execution.
         """
-        self._connector_metadata = KVConnectorMetadata()
+        self._connector_metadata = None
 
     def _get_connector_metadata(self) -> KVConnectorMetadata:
         """Get the connector metadata.
@@ -128,6 +134,9 @@ def _get_connector_metadata(self) -> KVConnectorMetadata:
         Returns:
             ConnectorMetadata: the connector metadata.
         """
+
+        # Should only be called while set to valid metadata.
+        assert self._connector_metadata is not None
     return self._connector_metadata
 
     def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
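On the metadata handling changed above: _connector_metadata now starts as None and is reset to None by clear_connector_metadata(), so _get_connector_metadata() asserts it is only read while metadata is bound for the current step. A connector's metadata is just a plain container subclassing the (now abstract) KVConnectorMetadata; a minimal hypothetical example:

from dataclasses import dataclass, field

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorMetadata)


@dataclass
class MyConnectorMetadata(KVConnectorMetadata):
    """Hypothetical payload from scheduler-side to worker-side connector."""
    # Request id -> block ids whose KV should be transferred this step.
    reqs_to_transfer: dict[str, list[int]] = field(default_factory=dict)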
@@ -140,6 +149,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
         """
         return
 
+    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
+        """
+        Set the xPU-specific ops for copying KV between host and device.
+        Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
+        """
+        return
+
     @abstractmethod
     def start_load_kv(self, forward_context: "ForwardContext",
                       **kwargs) -> None:
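The new set_host_xfer_buffer_ops hook above lets the model runner inject a platform-specific CopyBlocksOp when KV is staged through a host buffer. A rough illustration of how a worker-side connector might hold and use it (the class, attribute, and method bodies below are invented, not from this commit):

from typing import Optional

import torch

# CopyBlocksOp is the alias added to base.py in this commit.
from vllm.distributed.kv_transfer.kv_connector.v1.base import CopyBlocksOp


class HostBufferWorkerSketch:
    """Illustrative fragment, not a real connector implementation."""

    def __init__(self):
        self._copy_blocks: Optional[CopyBlocksOp] = None
        self.device_kv_caches: dict[str, torch.Tensor] = {}
        self.host_xfer_buffers: dict[str, torch.Tensor] = {}

    def set_host_xfer_buffer_ops(self, copy_operation: CopyBlocksOp):
        # Injected by the model runner on xPU platforms.
        self._copy_blocks = copy_operation

    def save_blocks_to_host(self, block_ids: list[int]) -> None:
        # Device -> host staging before handing blocks to the transfer engine.
        if self._copy_blocks is not None:
            self._copy_blocks(self.device_kv_caches, self.host_xfer_buffers,
                              block_ids, block_ids, "d2h")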
@@ -206,7 +222,9 @@ def get_finished(
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """
         Notifies worker-side connector ids of requests that have
-        finished generating tokens.
+        finished generating tokens on the worker.
+        The scheduler process (via the Executors) will use this output
+        to track which workers are done.
 
         Returns:
             ids of requests that have finished asynchronous transfer
@@ -226,7 +244,7 @@ def set_kv_transfer_params(self, request: "Request"):
         kv_transfer_params = self._KVTransferParams.from_raw_dict(
             request.raw_kv_transfer_params)
         request.kv_transfer_params = kv_transfer_params
-
+
     @abstractmethod
     def get_num_new_matched_tokens(
         self,
@@ -303,3 +321,17 @@ def request_finished(
             returned by the engine.
         """
         return False, None
+
+    @classmethod
+    def get_required_kvcache_layout(
+            cls, vllm_config: "VllmConfig") -> Optional[str]:
+        """
+        Get the required KV cache layout for this connector.
+        Args:
+            vllm_config (VllmConfig): the vllm config.
+
+        Returns:
+            str: the required KV cache layout. e.g. HND, or NHD.
+            None if the connector does not require a specific layout.
+        """
+        return None
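The new get_required_kvcache_layout classmethod above gives connectors a way to constrain the KV cache layout. A minimal hypothetical override (the subclass below is illustrative and omits the abstract worker/scheduler methods, so it cannot be instantiated as-is):

from typing import TYPE_CHECKING, Optional

from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1)

if TYPE_CHECKING:
    from vllm.config import VllmConfig


class LayoutPinnedConnector(KVConnectorBase_V1):
    """Hypothetical connector that only works with the NHD layout."""

    @classmethod
    def get_required_kvcache_layout(
            cls, vllm_config: "VllmConfig") -> Optional[str]:
        # Returning None would mean "no layout constraint".
        return "NHD"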

vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

Lines changed: 18 additions & 1 deletion
@@ -734,16 +734,33 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
             block_size, kv_latent_dim = block_shape
             self.slot_size_bytes = kv_elem_size * kv_latent_dim
         else:
+            # [2 (k and v), num_blocks, ...]
+            #if self._use_flashinfer:
+            #    # FlashInfer swaps 2<->num_blocks dimensions.
+            #    self.num_blocks = first_kv_cache.shape[0]
+            #    block_rank = 4  # [2, block_size, kv_heads, head_dim]
+            #else:
+            #    self.num_blocks = first_kv_cache.shape[1]
+            #    block_rank = 3  # [block_size, kv_heads, head_dim]
+            #block_shape = first_kv_cache.shape[-block_rank:]
+            #block_size, n_kv_heads, head_dim = block_shape[-3:]
+
+            # TODO see if below is necessary, else uncomment above
             # [2 (k and v), num_blocks, ...]
             if self._use_flashinfer:
                 # FlashInfer swaps 2<->num_blocks dimensions.
                 self.num_blocks = first_kv_cache.shape[0]
                 block_rank = 4  # [2, block_size, kv_heads, head_dim]
             else:
-                self.num_blocks = first_kv_cache.shape[1]
+                # habana kv_cache: [2, num_blocks*block_size, kv_heads, head_dim]
+                self.num_blocks = first_kv_cache.shape[1] // self.block_size
                 block_rank = 3  # [block_size, kv_heads, head_dim]
             block_shape = first_kv_cache.shape[-block_rank:]
+            block_shape = list(block_shape)
+            block_shape[0] = block_shape[0] // self.num_blocks
+            block_shape = torch.Size(block_shape)
             block_size, n_kv_heads, head_dim = block_shape[-3:]
+
             # head size in bytes.
             self.slot_size_bytes = kv_elem_size * n_kv_heads * head_dim
             assert block_size == self.block_size
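To sanity-check the reshaped bookkeeping in the new else-branch, here is a small numeric walk-through under the habana-style layout named in the diff comment, [2, num_blocks * block_size, kv_heads, head_dim]; the concrete sizes below are made up for illustration:

import torch

# Assumed sizes for illustration only.
block_size = 128
num_blocks_expected = 10
kv_heads, head_dim = 8, 64
first_kv_cache = torch.empty(2, num_blocks_expected * block_size, kv_heads,
                             head_dim, dtype=torch.bfloat16)
kv_elem_size = first_kv_cache.element_size()

# Mirrors the new else-branch in register_kv_caches:
num_blocks = first_kv_cache.shape[1] // block_size       # 1280 // 128 = 10
block_rank = 3                                           # [block_size, kv_heads, head_dim]
block_shape = list(first_kv_cache.shape[-block_rank:])   # [1280, 8, 64]
block_shape[0] = block_shape[0] // num_blocks            # 1280 // 10 = 128 (= block_size)
slot_size_bytes = kv_elem_size * kv_heads * head_dim     # 2 * 8 * 64 = 1024

assert num_blocks == num_blocks_expected
assert block_shape == [block_size, kv_heads, head_dim]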
