
Commit f87a8e2
fixes
1 parent 9ba8734

File tree: 15 files changed, +60 −77 lines changed

examples/offline_inference.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 
 # Create an LLM.
-llm = LLM(model="facebook/opt-125m")
+llm = LLM(model="state-spaces/mamba-370m-hf")
 # Generate texts from the prompts. The output is a list of RequestOutput objects
 # that contain the prompt, generated text, and other information.
 outputs = llm.generate(prompts, sampling_params)
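
Note: taken together with the unchanged lines of the example, the updated script amounts to roughly the sketch below. The prompt list is an assumption added for illustration; only the SamplingParams, LLM, and generate calls are visible in this diff.

from vllm import LLM, SamplingParams

# Illustrative prompts; the real list lives earlier in examples/offline_inference.py.
prompts = ["Hello, my name is", "The future of AI is"]

# Sampling settings matching the example.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM backed by the Mamba checkpoint used after this commit.
llm = LLM(model="state-spaces/mamba-370m-hf")

# Generate texts from the prompts; each RequestOutput carries the prompt,
# the generated text, and other metadata.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, output.outputs[0].text)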

vllm/attention/backends/placeholder_attn.py

Lines changed: 25 additions & 9 deletions
@@ -6,13 +6,13 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder)
+from vllm.attention.backends.utils import CommonAttentionState
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import ModelInputForGPUBuilder
 
-# Placeholder attention backend for models like Mamba that don't have attention.
-# Mainly exists to sidestep get_attn_backend.
-# The attention metadata is still needed for Mamba.
+# Placeholder attention backend for models like Mamba and embedding models that
+# lack attention.
 
 
 class PlaceholderAttentionBackend(AttentionBackend):
@@ -34,6 +34,10 @@ def get_builder_cls() -> Type["PlaceholderAttentionMetadataBuilder"]:
     def get_metadata_cls() -> Type["PlaceholderAttentionMetadata"]:
         return PlaceholderAttentionMetadata
 
+    @staticmethod
+    def get_state_cls() -> Type["CommonAttentionState"]:
+        return CommonAttentionState
+
     @staticmethod
     def get_kv_cache_shape(
         num_blocks: int,
@@ -118,11 +122,15 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
         assert self.context_lens_tensor is not None
         assert self.seq_start_loc is not None
 
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
         self._cached_prefill_metadata = PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=0,
-            slot_mapping=None,
+            slot_mapping=slot_mapping,
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
@@ -131,7 +139,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
             context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
-            block_tables=None,
+            block_tables=block_tables,
             use_cuda_graph=False,
         )
         return self._cached_prefill_metadata
@@ -145,11 +153,15 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             return self._cached_decode_metadata
         assert self.seq_lens_tensor is not None
 
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
         self._cached_decode_metadata = PlaceholderAttentionMetadata(
             num_prefills=0,
             num_prefill_tokens=0,
             num_decode_tokens=self.num_decode_tokens,
-            slot_mapping=None,
+            slot_mapping=slot_mapping,
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
             max_query_len=None,
@@ -158,7 +170,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
             query_start_loc=None,
             seq_start_loc=None,
             context_lens_tensor=None,
-            block_tables=None,
+            block_tables=block_tables,
             use_cuda_graph=self.use_cuda_graph,
         )
         return self._cached_decode_metadata
@@ -266,9 +278,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                      dtype=query_start_loc.dtype,
                      out=query_start_loc[1:])
 
+        # Placeholders
+        slot_mapping = torch.empty(0)
+        block_tables = torch.empty(0)
+
         return PlaceholderAttentionMetadata(
             num_prefills=self.num_prefills,
-            slot_mapping=None,
+            slot_mapping=slot_mapping,
             num_prefill_tokens=self.num_prefill_tokens,
             num_decode_tokens=num_decode_tokens,
             seq_lens=seq_lens,
@@ -279,7 +295,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             query_start_loc=query_start_loc,
             seq_start_loc=seq_start_loc,
             context_lens_tensor=context_lens_tensor,
-            block_tables=None,
+            block_tables=block_tables,
             use_cuda_graph=use_captured_graph,
         )
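
Note: a minimal sketch (not part of the diff) of why the metadata now carries empty tensors rather than None: downstream code can move or inspect slot_mapping and block_tables without special-casing a missing value.

import torch

# Placeholders as used above: real tensors with zero elements.
slot_mapping = torch.empty(0)   # no real slots for an attention-free model
block_tables = torch.empty(0)   # no paged KV-cache blocks either

assert slot_mapping.numel() == 0 and block_tables.numel() == 0
slot_mapping = slot_mapping.to("cpu")   # works; None.to(...) would raise
print(block_tables.shape)               # torch.Size([0])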

vllm/attention/layer.py

Lines changed: 4 additions & 4 deletions
@@ -78,10 +78,10 @@ def __init__(
         # During model initialization, the default dtype is set as the model
         # weight and activation dtype.
         dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
-                                        sliding_window, dtype, kv_cache_dtype,
-                                        block_size, is_attention_free,
-                                        blocksparse_params is not None)
+        attn_backend = get_attn_backend(head_size, sliding_window, dtype,
+                                        kv_cache_dtype, block_size,
+                                        is_attention_free, blocksparse_params
+                                        is not None)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,

vllm/attention/selector.py

Lines changed: 3 additions & 8 deletions
@@ -89,14 +89,12 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
 
 @lru_cache(maxsize=None)
 def get_attn_backend(
-    num_heads: int,
     head_size: int,
-    num_kv_heads: int,
     sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
-    is_attention_free: bool, #TODO: pass in from all users
+    is_attention_free: bool,
     is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
@@ -107,9 +105,8 @@ def get_attn_backend(
             BlocksparseFlashAttentionBackend)
         return BlocksparseFlashAttentionBackend
 
-    backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
-                                sliding_window, dtype, kv_cache_dtype,
-                                block_size, is_attention_free)
+    backend = which_attn_to_use(head_size, sliding_window, dtype,
+                                kv_cache_dtype, block_size, is_attention_free)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
@@ -157,9 +154,7 @@ def get_attn_backend(
 
 
 def which_attn_to_use(
-    num_heads: int,
     head_size: int,
-    num_kv_heads: int,
     sliding_window: Optional[int],
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
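
Note: a hypothetical call of the slimmed-down selector after this change. The head_size, block_size, and dtype values below are assumptions for illustration, not taken from any vLLM config in this commit.

import torch
from vllm.attention.selector import get_attn_backend

# num_heads / num_kv_heads are no longer passed; only these arguments remain.
backend_cls = get_attn_backend(
    head_size=64,             # illustrative value
    sliding_window=None,
    dtype=torch.float16,
    kv_cache_dtype=None,
    block_size=16,            # illustrative value
    is_attention_free=True,   # e.g. a Mamba model
)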

vllm/model_executor/models/jamba.py

Lines changed: 5 additions & 5 deletions
@@ -616,10 +616,9 @@ def forward(self,
             num_mamba_layers = sum(
                 [layer_type == "mamba" for layer_type in layers_type])
 
-            self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype,
-                                                 num_mamba_layers,
-                                                 max_batch_size,
-                                                 *self._get_mamba_cache_shape())
+            self.mamba_cache = MambaCacheManager(
+                self.lm_head.weight.dtype, num_mamba_layers, max_batch_size,
+                *self._get_mamba_cache_shape())
 
         if "seqlen_agnostic_capture_inputs" not in kwargs:
             # We get here only on Prefill/Eager mode runs
@@ -645,7 +644,8 @@ def forward(self,
                                               mamba_cache_tensors[1])
         return hidden_states
 
-    def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    def _get_mamba_cache_shape(
+            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         world_size = get_tensor_model_parallel_world_size()
         hidden_size = self.config.hidden_size
         conv_state_shape = (

vllm/model_executor/models/mamba.py

Lines changed: 5 additions & 5 deletions
@@ -440,10 +440,9 @@ def forward(self,
             max_batch_size = (_get_graph_batch_size(
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
-            self.mamba_cache = MambaCacheManager(self.lm_head.weight.dtype,
-                                                 self.config.num_hidden_layers,
-                                                 max_batch_size,
-                                                 *self._get_mamba_cache_shape())
+            self.mamba_cache = MambaCacheManager(
+                self.lm_head.weight.dtype, self.config.num_hidden_layers,
+                max_batch_size, *self._get_mamba_cache_shape())
 
         if "seqlen_agnostic_capture_inputs" not in kwargs:
             # We get here only on Prefill/Eager mode runs
@@ -471,7 +470,8 @@ def forward(self,
 
         return hidden_states
 
-    def _get_mamba_cache_shape(self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
+    def _get_mamba_cache_shape(
+            self) -> Tuple[Tuple[int, int], Tuple[int, int]]:
         world_size = get_tensor_model_parallel_world_size()
         conv_state_shape = (
             self.config.intermediate_size // world_size,

vllm/model_executor/models/mamba_cache.py

Lines changed: 2 additions & 4 deletions
@@ -1,14 +1,12 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional
 
 import torch
 
-from vllm.distributed import get_tensor_model_parallel_world_size
-
 
 class MambaCacheManager:
 
     def __init__(self, dtype, num_mamba_layers, max_batch_size,
-                 conv_state_shape, temporal_state_shape):
+                 conv_state_shape, temporal_state_shape):
 
         conv_state = torch.empty(size=(num_mamba_layers, max_batch_size) +
                                  conv_state_shape,
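
Note: a sketch of constructing MambaCacheManager with this signature. The shape and layer-count values below are assumptions for illustration; in practice they come from _get_mamba_cache_shape() in mamba.py / jamba.py and depend on the model config and tensor-parallel size.

import torch
from vllm.model_executor.models.mamba_cache import MambaCacheManager

# Illustrative (conv, temporal) state shapes per layer and per sequence.
conv_state_shape = (1024, 3)
temporal_state_shape = (1024, 16)

# The manager allocates (num_mamba_layers, max_batch_size, *shape) tensors internally.
mamba_cache = MambaCacheManager(torch.float16,          # dtype
                                num_mamba_layers=24,    # assumed layer count
                                max_batch_size=8,       # assumed batch size
                                conv_state_shape=conv_state_shape,
                                temporal_state_shape=temporal_state_shape)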

vllm/worker/cache_engine.py

Lines changed: 6 additions & 6 deletions
@@ -52,12 +52,12 @@ def __init__(
         self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
 
         # Get attention backend.
-        self.attn_backend = get_attn_backend(
-            model_config.get_num_attention_heads(parallel_config),
-            self.head_size, self.num_kv_heads,
-            model_config.get_sliding_window(), model_config.dtype,
-            cache_config.cache_dtype, self.block_size,
-            model_config.is_attention_free())
+        self.attn_backend = get_attn_backend(self.head_size,
+                                             model_config.get_sliding_window(),
+                                             model_config.dtype,
+                                             cache_config.cache_dtype,
+                                             self.block_size,
+                                             model_config.is_attention_free())
 
         # Initialize the cache.
         self.gpu_cache = self._allocate_kv_cache(

vllm/worker/cpu_model_runner.py

Lines changed: 1 addition & 2 deletions
@@ -103,13 +103,12 @@ def __init__(
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
+            self.model_config.is_attention_free(),
         )
 
         # Multi-modal data support

vllm/worker/cpu_worker.py

Lines changed: 1 addition & 2 deletions
@@ -55,13 +55,12 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
 
         # Get attention backend.
         self.attn_backend = get_attn_backend(
-            self.model_config.get_num_attention_heads(self.parallel_config),
             self.model_config.get_head_size(),
-            self.model_config.get_num_kv_heads(self.parallel_config),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             cache_config.cache_dtype,
             self.block_size,
+            self.model_config.is_attention_free(),
         )
 
         # Initialize the cache.
