11# SPDX-License-Identifier: Apache-2.0
22import copy
3+ from dataclasses import dataclass
34from typing import TYPE_CHECKING , Any , Optional
45
56import torch
# Module-level logger, created through the project's init_logger helper
# and keyed by this module's name.
logger = init_logger(__name__)
2223
2324
@dataclass
class MultiKVConnectorMetadata(KVConnectorMetadata):
    """Aggregate connector metadata used by ``MultiConnector``.

    Bundles the metadata objects produced by the wrapped connectors with
    an optional snapshot of the extra-async-save counters, so that both
    travel together in a single ``KVConnectorMetadata`` instance.
    """

    # One metadata object per wrapped connector. It is built by iterating
    # ``MultiConnector._connectors`` and consumed by zipping against that
    # same sequence, so positions must line up with the connector order.
    metadata: tuple[KVConnectorMetadata, ...]

    # Counts of *additional* async saves (beyond the first) still to be
    # finished, keyed by string -- presumably the request id; confirm
    # against the code that populates ``_extra_async_saves``.
    # ``None`` means there is nothing to propagate.
    extra_async_saves: Optional[dict[str, int]] = None
2729
2830
2931class MultiConnector (KVConnectorBase_V1 ):
@@ -54,6 +56,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
5456 # Keeps track of *additional* remaining async saves (beyond 1) to be
5557 # finished per request. Not needed for async loads since we only allow
5658 # a single connector to load.
59+ # Propagated from scheduler to worker side via the connector metadata.
5760 self ._extra_async_saves : dict [str , int ] = {}
5861
5962 def register_kv_caches (self , kv_caches : dict [str , torch .Tensor ]):
@@ -66,7 +69,10 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
def bind_connector_metadata(
        self, connector_metadata: KVConnectorMetadata) -> None:
    """Bind the aggregated metadata to this connector and its children.

    Args:
        connector_metadata: Must be a ``MultiKVConnectorMetadata``; its
            ``metadata`` entries are handed to the wrapped connectors
            in order.

    Raises:
        AssertionError: If ``connector_metadata`` is not a
            ``MultiKVConnectorMetadata``.
    """
    assert isinstance(connector_metadata, MultiKVConnectorMetadata)
    # Merge any extra-async-save counters carried by the metadata into
    # the local bookkeeping. A ``None`` or empty dict is skipped.
    if connector_metadata.extra_async_saves:
        self._extra_async_saves.update(
            connector_metadata.extra_async_saves)
    # Pair each wrapped connector with the metadata at the same index.
    for c, cm in zip(self._connectors, connector_metadata.metadata):
        c.bind_connector_metadata(cm)
7177
7278 def clear_connector_metadata (self ) -> None :
@@ -152,8 +158,13 @@ def update_state_after_alloc(self, request: "Request",
def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput) -> MultiKVConnectorMetadata:
    """Build the aggregated metadata for one scheduler output.

    Args:
        scheduler_output: Passed unchanged to every wrapped connector's
            own ``build_connector_meta``.

    Returns:
        A ``MultiKVConnectorMetadata`` holding one metadata object per
        wrapped connector, in connector order, plus any pending
        extra-async-save counters.
    """
    metadata = MultiKVConnectorMetadata(metadata=tuple(
        c.build_connector_meta(scheduler_output)
        for c in self._connectors))
    if self._extra_async_saves:
        # Hand the accumulated counters over to the metadata object and
        # start a fresh dict, so each batch of counters is attached to
        # exactly one metadata object.
        metadata.extra_async_saves = self._extra_async_saves
        self._extra_async_saves = {}
    return metadata
157168
158169 def request_finished (
159170 self ,
0 commit comments