
Commit 5e1e462

tlrmchlsmth authored and garg-amit committed
[Bugfix] Bandaid fix for speculative decoding tests (vllm-project#9327)
Signed-off-by: Amit Garg <[email protected]>
1 parent d6d3698 commit 5e1e462

File tree

1 file changed: +18 -3 lines changed

vllm/worker/model_runner.py

Lines changed: 18 additions & 3 deletions
@@ -17,6 +17,7 @@
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.levels import CompilationLevel
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
@@ -1028,16 +1029,30 @@ def __init__(
         self.graph_block_tables = np.zeros(
             (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
             dtype=np.int32)
+
+        # Attention-free but stateful models like Mamba need a placeholder attn
+        # backend, as the attention metadata is needed to manage internal state.
+        # However we must bypass attention selection altogether for some models
+        # used for speculative decoding to avoid a divide-by-zero in
+        # model_config.get_head_size()
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
+
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
-        self.attn_state = self.attn_backend.get_state_cls()(
-            weakref.proxy(self))
+        ) if needs_attn_backend else None
+        if self.attn_backend:
+            self.attn_state = self.attn_backend.get_state_cls()(
+                weakref.proxy(self))
+        else:
+            self.attn_state = CommonAttentionState(weakref.proxy(self))

         # Multi-modal data support
         self.input_registry = input_registry
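
For context on the guard this diff introduces, here is a minimal, self-contained sketch of the failure mode, assuming head size is derived as hidden_size // num_attention_heads (the divide-by-zero the in-code comment points at; the real ModelConfig.get_head_size() may differ in detail). ToyModelConfig and select_backend are hypothetical stand-ins, not vLLM APIs.

class ToyModelConfig:
    """Hypothetical stand-in for vLLM's ModelConfig, reduced to the
    fields relevant here."""

    def __init__(self, hidden_size: int, num_attention_heads: int,
                 is_attention_free: bool = False):
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.is_attention_free = is_attention_free

    def get_head_size(self) -> int:
        # Assumption: head size = hidden_size // num_attention_heads.
        # Raises ZeroDivisionError when num_attention_heads == 0, which
        # is what some speculative-decoding draft models report.
        return self.hidden_size // self.num_attention_heads


def select_backend(config: ToyModelConfig):
    """Mirrors the guard added by the commit: only call get_head_size()
    when the model actually needs an attention backend."""
    needs_attn_backend = (config.num_attention_heads != 0
                          or config.is_attention_free)
    if not needs_attn_backend:
        # Caller falls back to a common placeholder state instead.
        return None
    return ("selected-backend", config.get_head_size())


# A draft model with zero attention heads: the unguarded code path would
# crash inside get_head_size(); the guarded path returns None instead.
draft = ToyModelConfig(hidden_size=4096, num_attention_heads=0)
assert select_backend(draft) is None

The design mirrors the diff: models that report zero attention heads and are not attention-free skip backend selection entirely, while attention-free models such as Mamba still receive CommonAttentionState as a placeholder so their metadata-driven state management keeps working.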
