|
17 | 17 | import vllm.envs as envs |
18 | 18 | from vllm.attention import AttentionMetadata, get_attn_backend |
19 | 19 | from vllm.attention.backends.abstract import AttentionState |
| 20 | +from vllm.attention.backends.utils import CommonAttentionState |
20 | 21 | from vllm.compilation.compile_context import set_compile_context |
21 | 22 | from vllm.compilation.levels import CompilationLevel |
22 | 23 | from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, |
@@ -1028,16 +1029,30 @@ def __init__( |
1028 | 1029 | self.graph_block_tables = np.zeros( |
1029 | 1030 | (self.max_batchsize_to_capture, self.get_max_block_per_batch()), |
1030 | 1031 | dtype=np.int32) |
| 1032 | + |
| 1033 | + # Attention-free but stateful models like Mamba need a placeholder attn |
| 1034 | + # backend, as the attention metadata is needed to manage internal state. |
| 1035 | +    # However, we must bypass attention selection altogether for some models |
| 1036 | + # used for speculative decoding to avoid a divide-by-zero in |
| 1037 | + # model_config.get_head_size() |
| 1038 | + num_attn_heads = self.model_config.get_num_attention_heads( |
| 1039 | + self.parallel_config) |
| 1040 | + needs_attn_backend = (num_attn_heads != 0 |
| 1041 | + or self.model_config.is_attention_free) |
| 1042 | + |
1031 | 1043 | self.attn_backend = get_attn_backend( |
1032 | 1044 | self.model_config.get_head_size(), |
1033 | 1045 | self.model_config.get_sliding_window(), |
1034 | 1046 | self.model_config.dtype, |
1035 | 1047 | self.kv_cache_dtype, |
1036 | 1048 | self.block_size, |
1037 | 1049 | self.model_config.is_attention_free, |
1038 | | - ) |
1039 | | - self.attn_state = self.attn_backend.get_state_cls()( |
1040 | | - weakref.proxy(self)) |
| 1050 | + ) if needs_attn_backend else None |
| 1051 | + if self.attn_backend: |
| 1052 | + self.attn_state = self.attn_backend.get_state_cls()( |
| 1053 | + weakref.proxy(self)) |
| 1054 | + else: |
| 1055 | + self.attn_state = CommonAttentionState(weakref.proxy(self)) |
1041 | 1056 |
|
1042 | 1057 | # Multi-modal data support |
1043 | 1058 | self.input_registry = input_registry |
|
0 commit comments