@@ -1,30 +1,37 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from typing import Optional
+
 import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 
-from vllm.attention.backends.abstract import AttentionMetadata
+from vllm import envs
+from vllm.config import get_current_vllm_config
 from vllm.distributed.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                MergedColumnParallelLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba_utils import (
+    MambaStateShapeCalculator)
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn, causal_conv1d_update)
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
 
 
 # Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
 @CustomOp.register("mamba_mixer")
-class MambaMixer(CustomOp):
+class MambaMixer(MambaBase, CustomOp):
     """
     Compute ∆, A, B, C, and D the state space parameters and compute
     the `contextualized_states`. A, D are input independent
@@ -47,13 +54,16 @@ def __init__(self,
                  rms_norm_has_weight: bool = True,
                  rms_norm_eps: float = 1e-5,
                  activation="silu",
-                 is_lora_enabled: bool = False):
+                 is_lora_enabled: bool = False,
+                 prefix: str = ""):
         super().__init__()
         self.time_step_rank = time_step_rank
         self.ssm_state_size = ssm_state_size
         self.use_rms_norm = use_rms_norm
         self.activation = activation
         self.is_lora_enabled = is_lora_enabled
+        self.conv_kernel_size = conv_kernel_size
+        self.intermediate_size = intermediate_size
 
         self.conv1d = ColumnParallelLinear(
             input_size=conv_kernel_size,
@@ -131,14 +141,62 @@ def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
             has_weight=rms_norm_has_weight,
         ) if use_rms_norm else None
 
-    def forward_native(self, hidden_states: torch.Tensor,
-                       conv_state: torch.Tensor, ssm_state: torch.Tensor):
+        if envs.VLLM_USE_V1:
+            compilation_config = get_current_vllm_config().compilation_config
+            if prefix in compilation_config.static_forward_context:
+                raise ValueError(f"Duplicate layer name: {prefix}")
+            compilation_config.static_forward_context[prefix] = self
+            # The outer list is for v0 PP virtual engine. Though this code path
+            # only runs for v1, we have to do this to unify with the interface
+            # of Attention + v0 PP.
+            # The inner tuple is (conv_state, ssm_state)
+            self.kv_cache = [(torch.tensor([]), torch.tensor([]))]
+
+        self.prefix = prefix
+
+    def forward(self,
+                hidden_states: torch.Tensor,
+                mamba_cache_params: Optional[MambaCacheParams] = None):
+        if not envs.VLLM_USE_V1:
+            return CustomOp.forward(self, hidden_states, mamba_cache_params)
+        else:
+            return self.forward_cuda(hidden_states, mamba_cache_params)
+
+    def forward_native(self,
+                       hidden_states: torch.Tensor,
+                       mamba_cache_params: Optional[MambaCacheParams] = None):
         pass
 
-    def forward_cuda(self, hidden_states: torch.Tensor,
-                     mamba_cache_params: MambaCacheParams):
+    def forward_cuda(self,
+                     hidden_states: torch.Tensor,
+                     mamba_cache_params: Optional[MambaCacheParams] = None):
+
+        forward_context: ForwardContext = get_forward_context()
+        attn_metadata = forward_context.attn_metadata
+
+        if envs.VLLM_USE_V1:
+            if attn_metadata is not None:
+                assert isinstance(attn_metadata, dict)
+                attn_metadata = attn_metadata[self.prefix]
+                mamba1_metadata = attn_metadata
+                assert isinstance(mamba1_metadata, Mamba1AttentionMetadata)
+                query_start_loc = mamba1_metadata.query_start_loc
+                state_indices_tensor = mamba1_metadata.state_indices_tensor
+                self_kv_cache = self.kv_cache[forward_context.virtual_engine]
+                conv_state = self_kv_cache[0].transpose(-1, -2)
+                ssm_state = self_kv_cache[1]
+                has_initial_state = mamba1_metadata.has_initial_states
+                context_lens_tensor = mamba1_metadata.context_lens_tensor
+        else:
+            assert mamba_cache_params is not None
+            conv_state = mamba_cache_params.conv_state
+            ssm_state = mamba_cache_params.ssm_state
+            state_indices_tensor = mamba_cache_params.state_indices_tensor
+            query_start_loc = attn_metadata.query_start_loc
+            context_lens_tensor = attn_metadata.context_lens_tensor
 
-        attn_metadata: AttentionMetadata = get_forward_context().attn_metadata
+            if context_lens_tensor is not None:
+                has_initial_state = context_lens_tensor > 0
 
         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -148,8 +206,12 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
                                                self.conv1d.weight.size(2))
 
-        if attn_metadata.query_start_loc is not None \
-            and attn_metadata.context_lens_tensor is not None:
+        if envs.VLLM_USE_V1 and attn_metadata is None:
+            # V1 profile run
+            hidden_states = hidden_states.contiguous()
+            return self.out_proj(hidden_states.transpose(-2, -1))[0]
+
+        if query_start_loc is not None and context_lens_tensor is not None:
             # |---------- N-1 iteration --------|
             # |---------------- N iteration ---------------------|
             # |- tokenA -|......................|-- newTokens ---|
@@ -161,18 +223,18 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 conv_weights,
                 bias=self.conv1d.bias,
                 activation=self.activation,
-                conv_states=mamba_cache_params.conv_state,
-                has_initial_state=attn_metadata.context_lens_tensor > 0,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                query_start_loc=attn_metadata.query_start_loc)
+                conv_states=conv_state,
+                has_initial_state=has_initial_state,
+                cache_indices=state_indices_tensor,
+                query_start_loc=query_start_loc)
         else:
             hidden_states = causal_conv1d_update(
                 hidden_states.transpose(0, 1),
-                mamba_cache_params.conv_state,
+                conv_state,
                 conv_weights,
                 self.conv1d.bias,
                 self.activation,
-                conv_state_indices=mamba_cache_params.state_indices_tensor)
+                conv_state_indices=state_indices_tensor)
             hidden_states = hidden_states.transpose(0, 1)
 
         # 3. State Space Model sequence transformation
@@ -203,11 +265,10 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         time_proj_bias = (self.dt_proj.bias.float() if hasattr(
             self.dt_proj, "bias") else None)
 
-        if attn_metadata.query_start_loc is not None \
-            and attn_metadata.context_lens_tensor is not None:
+        if query_start_loc is not None and context_lens_tensor is not None:
             scan_outputs = selective_scan_fn(
                 hidden_states,
-                mamba_cache_params.ssm_state,
+                ssm_state,
                 discrete_time_step,
                 self.A,
                 B.transpose(-2, -1),
@@ -216,24 +277,23 @@ def forward_cuda(self, hidden_states: torch.Tensor,
                 gate,
                 time_proj_bias,
                 delta_softplus=True,
-                cache_indices=mamba_cache_params.state_indices_tensor,
-                has_initial_state=attn_metadata.context_lens_tensor > 0,
-                query_start_loc=attn_metadata.query_start_loc)
+                cache_indices=state_indices_tensor,
+                has_initial_state=has_initial_state,
+                query_start_loc=query_start_loc)
         else:
             scan_outputs = torch.empty_like(hidden_states.transpose(0, 1))
-            selective_state_update(
-                mamba_cache_params.ssm_state,
-                hidden_states.transpose(0, 1),
-                discrete_time_step.transpose(0, 1),
-                self.A,
-                B,
-                C,
-                self.D,
-                gate.transpose(0, 1),
-                time_proj_bias,
-                dt_softplus=True,
-                state_batch_indices=mamba_cache_params.state_indices_tensor,
-                out=scan_outputs)
+            selective_state_update(ssm_state,
+                                   hidden_states.transpose(0, 1),
+                                   discrete_time_step.transpose(0, 1),
+                                   self.A,
+                                   B,
+                                   C,
+                                   self.D,
+                                   gate.transpose(0, 1),
+                                   time_proj_bias,
+                                   dt_softplus=True,
+                                   state_batch_indices=state_indices_tensor,
+                                   out=scan_outputs)
             scan_outputs = scan_outputs.transpose(0, 1)
 
         # 4. Final linear projection
@@ -245,3 +305,15 @@ def forward_cuda(self, hidden_states: torch.Tensor,
         contextualized_states = self.out_proj(
             scan_outputs.transpose(-2, -1))[0]
         return contextualized_states
+
+    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
+        return MambaStateShapeCalculator.mamba1_state_shape(
+            tp_world_size=get_tensor_model_parallel_world_size(),
+            intermediate_size=self.intermediate_size,
+            state_size=self.ssm_state_size,
+            conv_kernel=self.conv_kernel_size,
+        )
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba1"
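
The V1 path in this diff hinges on two conventions: each mixer registers itself in compilation_config.static_forward_context under its prefix, and at runtime forward_context.attn_metadata is a dict keyed by that same prefix (it is None during the profile run). A toy sketch of that lookup pattern follows; the class and field names are hypothetical stand-ins, not vLLM APIs.

from typing import Optional

class ToyForwardContext:
    def __init__(self, attn_metadata: Optional[dict] = None):
        # None during the profile run; otherwise {layer prefix: metadata}.
        self.attn_metadata = attn_metadata

class ToyMixer:
    def __init__(self, prefix: str, registry: dict):
        if prefix in registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registry[prefix] = self  # mirrors static_forward_context[prefix] = self
        self.prefix = prefix

    def forward(self, ctx: ToyForwardContext):
        if ctx.attn_metadata is None:
            return "profile-run fallback"
        return ctx.attn_metadata[self.prefix]  # per-layer metadata lookup

registry: dict = {}
mixer = ToyMixer("model.layers.0.mixer", registry)
ctx = ToyForwardContext({"model.layers.0.mixer": "mamba1 metadata"})
assert mixer.forward(ctx) == "mamba1 metadata"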