
Commit db3e239

Josephasafg authored and asafgardin committed

[v1] - Mamba1 Attention Metadata (vllm-project#21249)

Signed-off-by: asafg <[email protected]>
Co-authored-by: asafg <[email protected]>
1 parent 517f8db commit db3e239

File tree

19 files changed: +377 −171 lines

csrc/mamba/mamba_ssm/selective_scan.h

Lines changed: 3 additions & 0 deletions

@@ -45,6 +45,9 @@ struct SSMParamsBase {
     index_t out_d_stride;
     index_t out_z_batch_stride;
     index_t out_z_d_stride;
+    index_t ssm_states_batch_stride;
+    index_t ssm_states_dim_stride;
+    index_t ssm_states_dstate_stride;

     // Common data pointers.
     void *__restrict__ A_ptr;

csrc/mamba/mamba_ssm/selective_scan_fwd.cu

Lines changed: 14 additions & 4 deletions

@@ -132,8 +132,10 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
     input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
     weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
     input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
-    input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
-
+    input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) +
+        cache_index * params.ssm_states_batch_stride +
+        dim_id * kNRows * params.ssm_states_dim_stride;
+
     float D_val[kNRows] = {0};
     if (params.D_ptr != nullptr) {
 #pragma unroll
@@ -248,7 +250,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             }
             // Initialize running total

-            scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
+            scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx * params.ssm_states_dstate_stride]): 0.0);

             SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
             typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
@@ -259,7 +261,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
             if (threadIdx.x == 0) {
                 smem_running_prefix[state_idx] = prefix_op.running_prefix;
                 if (chunk == n_chunks - 1) {
-                    ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
+                    ssm_states[state_idx * params.ssm_states_dstate_stride] = input_t(prefix_op.running_prefix.y);
                 }
             }
 #pragma unroll
@@ -481,6 +483,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
         params.out_batch_stride = out.stride(1);
         params.out_d_stride = out.stride(0);

+        params.ssm_states_batch_stride = ssm_states.stride(0);
+        params.ssm_states_dim_stride = ssm_states.stride(1);
+        params.ssm_states_dstate_stride = ssm_states.stride(2);
+
     }
     else{
         if (!is_variable_B) {
@@ -509,6 +515,10 @@ void set_ssm_params_fwd(SSMParamsBase &params,
         }
         params.out_batch_stride = out.stride(0);
         params.out_d_stride = out.stride(1);
+
+        params.ssm_states_batch_stride = ssm_states.stride(0);
+        params.ssm_states_dim_stride = ssm_states.stride(1);
+        params.ssm_states_dstate_stride = ssm_states.stride(2);
     }
 }

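The kernel change above replaces the flat `(cache_index * params.dim + dim_id * kNRows) * params.dstate` offset with explicit per-dimension strides read from the `ssm_states` tensor, so the cache no longer has to be densely packed. Below is a small, self-contained PyTorch sketch (illustration only, not vLLM code) of why that matters once the state cache is a non-contiguous view: the flat offset only lands on the right element for a dense layout, while the stride-based offset works for any layout.

```python
import torch

batch, dim, dstate = 4, 8, 16

# Pretend the SSM cache is carved out of a larger, padded buffer, so the
# resulting view is NOT contiguous and its strides differ from a dense
# (batch, dim, dstate) tensor.
padded = torch.zeros(batch, dim + 2, dstate)
ssm_states = padded[:, :dim, :]
assert not ssm_states.is_contiguous()

cache_index, d, s = 2, 5, 7
ssm_states[cache_index, d, s] = 42.0

# Old-style flat offset: only correct when the tensor is densely packed.
flat_offset = (cache_index * dim + d) * dstate + s

# New-style offset from explicit strides (what ssm_states_batch_stride,
# ssm_states_dim_stride and ssm_states_dstate_stride give the kernel).
sb, sd, ss = ssm_states.stride()
stride_offset = cache_index * sb + d * sd + s * ss

storage = padded.reshape(-1)           # the underlying dense storage
assert storage[stride_offset] == 42.0  # stride math finds the element
assert storage[flat_offset] != 42.0    # flat math points somewhere else
```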
docs/models/supported_models.md

Lines changed: 2 additions & 2 deletions

@@ -370,9 +370,9 @@ th {
 | `InternLM2ForCausalLM` | InternLM2 | `internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `InternLM3ForCausalLM` | InternLM3 | `internlm/internlm3-8b-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `JAISLMHeadModel` | Jais | `inceptionai/jais-13b`, `inceptionai/jais-13b-chat`, `inceptionai/jais-30b-v3`, `inceptionai/jais-30b-chat-v3`, etc. | | ✅︎ | ✅︎ |
-| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | |
+| `JambaForCausalLM` | Jamba | `ai21labs/AI21-Jamba-1.5-Large`, `ai21labs/AI21-Jamba-1.5-Mini`, `ai21labs/Jamba-v0.1`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `LlamaForCausalLM` | Llama 3.1, Llama 3, Llama 2, LLaMA, Yi | `meta-llama/Meta-Llama-3.1-405B-Instruct`, `meta-llama/Meta-Llama-3.1-70B`, `meta-llama/Meta-Llama-3-70B-Instruct`, `meta-llama/Llama-2-70b-hf`, `01-ai/Yi-34B`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | |
+| `MambaForCausalLM` | Mamba | `state-spaces/mamba-130m-hf`, `state-spaces/mamba-790m-hf`, `state-spaces/mamba-2.8b-hf`, etc. | | ✅︎ | ✅︎ |
 | `Mamba2ForCausalLM` | Mamba2 | `mistralai/Mamba-Codestral-7B-v0.1`, etc. | | ✅︎ | ✅︎ |
 | `MiMoForCausalLM` | MiMo | `XiaomiMiMo/MiMo-7B-RL`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `MiniCPMForCausalLM` | MiniCPM | `openbmb/MiniCPM-2B-sft-bf16`, `openbmb/MiniCPM-2B-dpo-bf16`, `openbmb/MiniCPM-S-1B-sft`, etc. | ✅︎ | ✅︎ | ✅︎ |

docs/usage/v1_guide.md

Lines changed: 5 additions & 7 deletions

@@ -83,7 +83,7 @@ based on assigned priority, with FCFS as a tie-breaker), configurable via the
 | **Decoder-only Models** | <nobr>🚀 Optimized</nobr> |
 | **Encoder-Decoder Models** | <nobr>🟠 Delayed</nobr> |
 | **Embedding Models** | <nobr>🟢 Functional</nobr> |
-| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟡 (Mamba-1)</nobr> |
+| **Mamba Models** | <nobr>🟢 (Mamba-2), 🟢 (Mamba-1)</nobr> |
 | **Multimodal Models** | <nobr>🟢 Functional</nobr> |

 vLLM V1 currently excludes model architectures with the `SupportsV0Only` protocol.
@@ -104,13 +104,11 @@ to enable simultaneous generation and embedding using the same engine instance i

 #### Mamba Models

-Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
-Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
-(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
-disabling prefix caching in V1.
+Models using selective state-space mechanisms instead of standard transformer attention are supported.
+Models that use Mamba-2 and Mamba-1 layers (e.g., `Mamba2ForCausalLM`, `MambaForCausalLM`) are supported. Please note that these models currently require disabling prefix caching in V1. Additionally, Mamba-1 models require `enforce_eager=True`.

-Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
-`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
+Models that combine Mamba-2 and Mamba-1 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
+`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM`, `GraniteMoeHybridForCausalLM` and `JambaForCausalLM`). Please note that
 these models currently require disabling prefix caching and using the FlashInfer attention backend in V1.

 #### Encoder-Decoder Models

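As a usage illustration of the updated guide text, here is a minimal sketch of running a Mamba-1 model on the V1 engine with prefix caching disabled and eager mode forced. Flag and environment-variable names follow vLLM's engine arguments around the time of this commit and may differ in other versions.

```python
import os
os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine

from vllm import LLM, SamplingParams

llm = LLM(
    model="state-spaces/mamba-130m-hf",
    enforce_eager=True,           # Mamba-1 in V1 currently requires eager mode
    enable_prefix_caching=False,  # prefix caching must be disabled for Mamba models
)

outputs = llm.generate(["The key idea behind state-space models is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```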
tests/models/language/generation/test_hybrid.py

Lines changed: 2 additions & 0 deletions

@@ -53,6 +53,8 @@
 ]

 V1_SUPPORTED_MODELS = [
+    "state-spaces/mamba-130m-hf",
+    "ai21labs/Jamba-tiny-dev",
     "mistralai/Mamba-Codestral-7B-v0.1",
     "ibm-ai-platform/Bamba-9B-v1",
     "Zyphra/Zamba2-1.2B-instruct",

tests/v1/test_oracle.py

Lines changed: 0 additions & 1 deletion

@@ -12,7 +12,6 @@
 UNSUPPORTED_MODELS_V1 = [
     "openai/whisper-large-v3",  # transcription
     "facebook/bart-large-cnn",  # encoder decoder
-    "state-spaces/mamba-130m-hf",  # mamba1
 ]

 MODEL = "meta-llama/Llama-3.2-1B-Instruct"

vllm/model_executor/layers/mamba/mamba_mixer.py

Lines changed: 108 additions & 36 deletions

@@ -1,30 +1,37 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from typing import Optional
+
 import torch
 from torch import nn
 from torch.nn.parameter import Parameter

-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm import envs
+from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata


 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer")
-class MambaMixer(CustomOp):
+class MambaMixer(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
@@ -47,13 +54,16 @@ def __init__(self,
                  rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
-                 is_lora_enabled: bool = False):
+                 is_lora_enabled: bool = False,
+                 prefix: str = ""):
         super().__init__()
         self.time_step_rank = time_step_rank
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
         self.is_lora_enabled = is_lora_enabled
+        self.conv_kernel_size = conv_kernel_size
+        self.intermediate_size = intermediate_size

         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
@@ -131,14 +141,62 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
             has_weight=rms_norm_has_weight,
         ) if use_rms_norm else None

-    def forward_native(self, hidden_states: torch.Tensor,
-                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+            # The outer list is for v0 PP virtual engine. Though this code path
+            # only runs for v1, we have to do this to unify with the interface
+            # of Attention + v0 PP.
+            # The inner tuple is (conv_state, ssm_state)
+            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+
+        self.prefix = prefix
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                mamba_cache_params: Optional[MambaCacheParams] = None):
+        if not envs.VLLM_USE_V1:
+            return CustomOp.forward(self, hidden_states, mamba_cache_params)
+        else:
+            return self.forward_cuda(hidden_states, mamba_cache_params)
+
+    def forward_native(self,
+                       hidden_states: torch.Tensor,
+                       mamba_cache_params: Optional[MambaCacheParams] = None):
         pass

-    def forward_cuda(self, hidden_states: torch.Tensor,
-                     mamba_cache_params: MambaCacheParams):
+    def forward_cuda(self,
+                     hidden_states: torch.Tensor,
+                     mamba_cache_params: Optional[MambaCacheParams] = None):
+
+        forward_context: ForwardContext = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                mamba1_metadata = attn_metadata
+                assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
+                query_start_loc = mamba1_metadata.query_start_loc
+                state_indices_tensor = mamba1_metadata.state_indices_tensor
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                conv_state = self_kv_cache[0].transpose(-1, -2)
+                ssm_state = self_kv_cache[1]
+                has_initial_state = mamba1_metadata.has_initial_states
+                context_lens_tensor = mamba1_metadata.context_lens_tensor
+        else:
+            assert mamba_cache_params is not None
+            conv_state = mamba_cache_params.conv_state
+            ssm_state = mamba_cache_params.ssm_state
+            state_indices_tensor = mamba_cache_params.state_indices_tensor
+            query_start_loc = attn_metadata.query_start_loc
+            context_lens_tensor = attn_metadata.context_lens_tensor

-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
+            if context_lens_tensor is not None:
+                has_initial_state = context_lens_tensor > 0

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -148,8 +206,12 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))

-        if attn_metadata.query_start_loc is not None \
-            and attn_metadata.context_lens_tensor is not None:
+        if envs.VLLM_USE_V1 and attn_metadata is None:
+            # V1 profile run
+            hidden_states = hidden_states.contiguous()
+            return self.out_proj(hidden_states.transpose(-2, -1))[0]
+
+        if query_start_loc is not None and context_lens_tensor is not None:
             # |---------- N-1 iteration --------|
             # |---------------- N iteration ---------------------|
             # |- tokenA -|......................|-- newTokens ---|
@@ -161,18 +223,18 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 conv_weights,
                 bias=self.conv1d.bias,
                 activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=attn_metadata.context_lens_tensor > 0,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc)
+                conv_states=conv_state,
+                has_initial_state=has_initial_state,
+                cache_indices=state_indices_tensor,
+                query_start_loc=query_start_loc)
         else:
             hidden_states = causal_conv1d_update(
                 hidden_states.transpose(0, 1),
-                mamba_cache_params.conv_state,
+                conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
+                conv_state_indices=state_indices_tensor)
             hidden_states = hidden_states.transpose(0, 1)

         # 3. State Space Model sequence transformation
@@ -203,11 +265,10 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         time_proj_bias = (self.dt_proj.bias.float() if hasattr(
             self.dt_proj, "bias") else None)

-        if attn_metadata.query_start_loc is not None \
-            and attn_metadata.context_lens_tensor is not None:
+        if query_start_loc is not None and context_lens_tensor is not None:
             scan_outputs = selective_scan_fn(
                 hidden_states,
-                mamba_cache_params.ssm_state,
+                ssm_state,
                 discrete_time_step,
                 self.A,
                 B.transpose(-2, -1),
@@ -216,24 +277,23 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                has_initial_state=attn_metadata.context_lens_tensor > 0,
-                query_start_loc=attn_metadata.query_start_loc)
+                cache_indices=state_indices_tensor,
+                has_initial_state=has_initial_state,
+                query_start_loc=query_start_loc)
         else:
             scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
-            selective_state_update(
-                mamba_cache_params.ssm_state,
-                hidden_states.transpose(0, 1),
-                discrete_time_step.transpose(0, 1),
-                self.A,
-                B,
-                C,
-                self.D,
-                gate.transpose(0, 1),
-                time_proj_bias,
-                dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor,
-                out=scan_outputs)
+            selective_state_update(ssm_state,
+                                   hidden_states.transpose(0, 1),
+                                   discrete_time_step.transpose(0, 1),
+                                   self.A,
+                                   B,
+                                   C,
+                                   self.D,
+                                   gate.transpose(0, 1),
+                                   time_proj_bias,
+                                   dt_softplus=True,
+                                   state_batch_indices=state_indices_tensor,
+                                   out=scan_outputs)
             scan_outputs = scan_outputs.transpose(0, 1)

         # 4. Final linear projection
@@ -245,3 +305,15 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         contextualized_states = self.out_proj(
             scan_outputs.transpose(-2, -1))[0]
         return contextualized_states
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.mamba1_state_shape(
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            intermediate_size=self.intermediate_size,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba1"

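The new `get_state_shape()` delegates to `MambaStateShapeCalculator.mamba1_state_shape`, whose body is not shown in this diff. As a rough sketch of what it is expected to report per layer (an assumption about the helper, not its actual source): a rolling convolution window of `conv_kernel - 1` positions plus a `state_size`-wide recurrent state, both sharded across tensor-parallel ranks.

```python
# Illustrative only: approximates the per-layer cache shapes the V1 cache
# manager would receive from get_state_shape() for a Mamba-1 layer.
def mamba1_state_shape_sketch(tp_world_size: int, intermediate_size: int,
                              state_size: int, conv_kernel: int):
    per_rank = intermediate_size // tp_world_size   # TP shard of the inner dim
    conv_state_shape = (per_rank, conv_kernel - 1)  # rolling conv window
    ssm_state_shape = (per_rank, state_size)        # recurrent SSM state
    return conv_state_shape, ssm_state_shape

# e.g. state-spaces/mamba-130m-hf: intermediate_size=1536, state_size=16,
# conv_kernel=4, single GPU
print(mamba1_state_shape_sketch(1, 1536, 16, 4))  # ((1536, 3), (1536, 16))
```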
vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 4 additions & 3 deletions

@@ -21,7 +21,7 @@
 from vllm.model_executor.layers.mamba.mamba2_metadata import (Mamba2Metadata,
                                                               update_metadata)
 from vllm.model_executor.layers.mamba.mamba_utils import (
-    extra_groups_for_head_shards, get_mamba_state_shape)
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
@@ -278,8 +278,9 @@ def __init__(
         # - for TP we shard conv_dim by sharding on n_groups,
         # - but if n_groups cannot divide tp_size, we need to
         #   extend some extra groups
-        self.n_groups = n_groups + extra_groups_for_head_shards(
+        groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
             n_groups, self.tp_size)
+        self.n_groups = n_groups + groups

         self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
         self.conv1d = ColumnParallelLinear(
@@ -732,7 +733,7 @@ def forward_cuda(
         output[:num_actual_tokens], _ = self.out_proj(hidden_states)

     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
-        return get_mamba_state_shape(
+        return MambaStateShapeCalculator.mamba2_state_shape(
             intermediate_size=self.intermediate_size,
             tp_world_size=get_tensor_model_parallel_world_size(),
             n_groups=self.n_groups,

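For context on the relocated helper: the comment in the hunk above says extra groups are added when `n_groups` cannot be split evenly across tensor-parallel ranks. A hedged sketch of that intent (not the actual `MambaStateShapeCalculator` source) is to pad the group count up to the next multiple of `tp_size`:

```python
# Sketch of the padding intent behind extra_groups_for_head_shards: if the
# group count cannot be split evenly across TP ranks, replicate groups until
# it can. This mirrors the comment in the diff, not the exact vLLM formula.
def extra_groups_sketch(n_groups: int, tp_size: int) -> int:
    remainder = n_groups % tp_size
    return 0 if remainder == 0 else tp_size - remainder

print(extra_groups_sketch(8, 4))  # 0 -> 8 groups already divide 4 ranks evenly
print(extra_groups_sketch(1, 4))  # 3 -> pad a single group to one per rank
```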