|
17 | 17 | import vllm.envs as envs |
18 | 18 | from vllm.attention import AttentionMetadata, get_attn_backend |
19 | 19 | from vllm.attention.backends.abstract import AttentionState |
| 20 | +from vllm.attention.backends.utils import CommonAttentionState |
20 | 21 | from vllm.compilation.compile_context import set_compile_context |
21 | 22 | from vllm.compilation.levels import CompilationLevel |
22 | 23 | from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, |
@@ -1001,16 +1002,30 @@ def __init__( |
1001 | 1002 | self.graph_block_tables = np.zeros( |
1002 | 1003 | (self.max_batchsize_to_capture, self.get_max_block_per_batch()), |
1003 | 1004 | dtype=np.int32) |
| 1005 | + |
| 1006 | + # Attention-free but stateful models like Mamba need a placeholder attn |
| 1007 | + # backend, as the attention metadata is needed to manage internal state. |
| 1008 | + # However we must bypass attention selection altogether for some models |
| 1009 | + # used for speculative decoding to avoid a divide-by-zero in |
| 1010 | + # model_config.get_head_size() |
| 1011 | + num_attn_heads = self.model_config.get_num_attention_heads( |
| 1012 | + self.parallel_config) |
| 1013 | + needs_attn_backend = (num_attn_heads != 0 |
| 1014 | + or self.model_config.is_attention_free) |
| 1015 | + |
1004 | 1016 | self.attn_backend = get_attn_backend( |
1005 | 1017 | self.model_config.get_head_size(), |
1006 | 1018 | self.model_config.get_sliding_window(), |
1007 | 1019 | self.model_config.dtype, |
1008 | 1020 | self.kv_cache_dtype, |
1009 | 1021 | self.block_size, |
1010 | 1022 | self.model_config.is_attention_free, |
1011 | | - ) |
1012 | | - self.attn_state = self.attn_backend.get_state_cls()( |
1013 | | - weakref.proxy(self)) |
| 1023 | + ) if needs_attn_backend else None |
| 1024 | + if self.attn_backend: |
| 1025 | + self.attn_state = self.attn_backend.get_state_cls()( |
| 1026 | + weakref.proxy(self)) |
| 1027 | + else: |
| 1028 | + self.attn_state = CommonAttentionState(weakref.proxy(self)) |
1014 | 1029 |
|
1015 | 1030 | # Multi-modal data support |
1016 | 1031 | self.input_registry = input_registry |
|
0 commit comments