3232
3333import enum
3434from abc import ABC , abstractmethod
35- from typing import TYPE_CHECKING , Any , Optional
35+ from typing import TYPE_CHECKING , Any , Callable , Literal , Optional
3636
3737import torch
3838
4646 from vllm .v1 .core .kv_cache_manager import KVCacheBlocks
4747 from vllm .v1 .request import Request
4848
49+ # s_tensor_list, d_tensor_list, s_indices, d_indices, direction
50+ CopyBlocksOp = Callable [[
51+ dict [str , torch .Tensor ], dict [
52+ str , torch .Tensor ], list [int ], list [int ], Literal ["h2d" , "d2h" ]
53+ ], None ]
54+
4955logger = init_logger (__name__ )
5056
5157
@@ -60,7 +66,7 @@ class KVTransferParams:
6066 """
6167 Abstract KVTransferParams used to send KVTransfer
6268 parameters between instances of vLLM.
63-
69+
6470 Specific instances of KVConnector customize this
6571 method for serializing / deserializing msgs sent
6672 via the HTTP protocol.
@@ -72,7 +78,7 @@ def from_raw_dict(
7278 Any ]]) -> Optional ["KVTransferParams" ]:
7379 return None
7480
75- class KVConnectorMetadata :
81+ class KVConnectorMetadata ( ABC ): # noqa: B024
7682 """
7783 Abstract Metadata used to communicate between the
7884 Scheduler KVConnector and Worker KVConnector.
@@ -87,7 +93,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
8793 logger .warning (
8894 "Initializing KVConnectorBase_V1. This API is experimental and "
8995 "subject to change in the future as we iterate the design." )
90- self ._connector_metadata = KVConnectorMetadata ()
96+ self ._connector_metadata : Optional [ KVConnectorMetadata ] = None
9197 self ._vllm_config = vllm_config
9298 self ._role = role
9399
@@ -118,7 +124,7 @@ def clear_connector_metadata(self) -> None:
118124 This function should be called by the model runner every time
119125 after the model execution.
120126 """
121- self ._connector_metadata = KVConnectorMetadata ()
127+ self ._connector_metadata = None
122128
123129 def _get_connector_metadata (self ) -> KVConnectorMetadata :
124130 """Get the connector metadata.
@@ -128,6 +134,9 @@ def _get_connector_metadata(self) -> KVConnectorMetadata:
128134 Returns:
129135 ConnectorMetadata: the connector metadata.
130136 """
137+
138+ # Should only be called while set to valid metadata.
139+ assert self ._connector_metadata is not None
131140 return self ._connector_metadata
132141
133142 def register_kv_caches (self , kv_caches : dict [str , torch .Tensor ]):
@@ -140,6 +149,13 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
140149 """
141150 return
142151
152+ def set_host_xfer_buffer_ops (self , copy_operation : CopyBlocksOp ):
153+ """
154+ Set the xPU-specific ops for copying KV between host and device.
155+ Needed when host buffer is used for kv transfer (e.g., in NixlConnector)
156+ """
157+ return
158+
143159 @abstractmethod
144160 def start_load_kv (self , forward_context : "ForwardContext" ,
145161 ** kwargs ) -> None :
@@ -206,7 +222,9 @@ def get_finished(
206222 ) -> tuple [Optional [set [str ]], Optional [set [str ]]]:
207223 """
208224 Notifies worker-side connector ids of requests that have
209- finished generating tokens.
225+ finished generating tokens on the worker.
226+ The scheduler process (via the Executors) will use this output
227+ to track which workers are done.
210228
211229 Returns:
212230 ids of requests that have finished asynchronous transfer
@@ -226,7 +244,7 @@ def set_kv_transfer_params(self, request: "Request"):
226244 kv_transfer_params = self ._KVTransferParams .from_raw_dict (
227245 request .raw_kv_transfer_params )
228246 request .kv_transfer_params = kv_transfer_params
229-
247+
230248 @abstractmethod
231249 def get_num_new_matched_tokens (
232250 self ,
@@ -303,3 +321,17 @@ def request_finished(
303321 returned by the engine.
304322 """
305323 return False , None
324+
325+ @classmethod
326+ def get_required_kvcache_layout (
327+ cls , vllm_config : "VllmConfig" ) -> Optional [str ]:
328+ """
329+ Get the required KV cache layout for this connector.
330+ Args:
331+ vllm_config (VllmConfig): the vllm config.
332+
333+ Returns:
334+ str: the required KV cache layout. e.g. HND, or NHD.
335+ None if the connector does not require a specific layout.
336+ """
337+ return None
0 commit comments