@@ -6,13 +6,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder)
+from vllm.attention.backends.utils import CommonAttentionState
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
 
-# Placeholder attention backend for models like Mamba that don't have attention.
-# Mainly exists to sidestep get_attn_backend.
-# The attention metadata is still needed for Mamba.
+# Placeholder attention backend for models like Mamba and embedding models that
+# lack attention.
 
 
 class PlaceholderAttentionBackend(AttentionBackend):
@@ -34,6 +34,10 @@ def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
     def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
         return PlaceholderAttentionMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
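
For context, a minimal sketch of how a caller might consume the new `get_state_cls()` hook; the import path is taken from the file under review, while the commented-out constructor call is an assumption for illustration, not part of this diff:

```python
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.backends.placeholder_attn import PlaceholderAttentionBackend

# The backend advertises which AttentionState implementation manages its
# shared per-step state (e.g. buffers reused across CUDA-graph replays).
state_cls = PlaceholderAttentionBackend.get_state_cls()
assert state_cls is CommonAttentionState
# attn_state = state_cls(runner)  # assumed constructor; `runner` would be
#                                 # the model runner that owns this backend
```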
@@ -118,11 +122,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
         assert self.context_lens_tensor is not None
         assert self.seq_start_loc is not None
 
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
         self._cached_prefill_metadata = PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
-            slot_mapping=None,
+            slot_mapping=slot_mapping,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
@@ -131,7 +139,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
             context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
-            block_tables=None,
+            block_tables=block_tables,
             use_cuda_graph=False,
         )
         return self._cached_prefill_metadata
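
The same placeholder pattern repeats in `decode_metadata` and `build` below. A short self-contained illustration (not from this diff) of why `torch.empty(0)` works as a stand-in for the removed `None` values:

```python
import torch

# A zero-element tensor satisfies tensor-typed consumers (dtype, device,
# slicing, .to(...)) while carrying no data, which None cannot do.
placeholder = torch.empty(0)
assert isinstance(placeholder, torch.Tensor)
assert placeholder.numel() == 0
placeholder = placeholder.to("cpu")  # tensor ops succeed on zero elements
```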
@@ -145,11 +153,15 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             return self._cached_decode_metadata
         assert self.seq_lens_tensor is not None
 
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
         self._cached_decode_metadata = PlaceholderAttentionMetadata(
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
-            slot_mapping=None,
+            slot_mapping=slot_mapping,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_query_len=None,
@@ -158,7 +170,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             query_start_loc=None,
             seq_start_loc=None,
             context_lens_tensor=None,
-            block_tables=None,
+            block_tables=block_tables,
             use_cuda_graph=self.use_cuda_graph,
         )
         return self._cached_decode_metadata
@@ -266,9 +278,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])
 
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
         return PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
-            slot_mapping=None,
+            slot_mapping=slot_mapping,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
@@ -279,7 +295,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             query_start_loc=query_start_loc,
             seq_start_loc=seq_start_loc,
             context_lens_tensor=context_lens_tensor,
-            block_tables=None,
+            block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
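
Together, the three call sites now guarantee that `slot_mapping` and `block_tables` are always tensors. A hedged before/after sketch of the failure mode this avoids (illustrative only, not code from the diff):

```python
import torch

field = None  # old behavior: placeholder fields were None
try:
    field.to("cpu")  # any tensor op on the field raised immediately
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'to'

field = torch.empty(0)  # new behavior: an empty but real tensor
field = field.to("cpu")  # downstream tensor ops are now safe no-ops
```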