docs/contributing/model/basic.md

To support a model with interleaving sliding windows, we need to take care of the following details:
- In the modeling code, parse the correct sliding window value for every layer, and pass it to the attention layer's `per_layer_sliding_window` argument. For reference, check [this line](https://github.com/vllm-project/vllm/blob/996357e4808ca5eab97d4c97c7d25b3073f46aab/vllm/model_executor/models/llama.py#L171).

With these two steps, interleaved sliding windows should work with the model.
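As a self-contained illustration of the second step, the following sketch derives a per-layer sliding window value that could then be passed to `per_layer_sliding_window`; the interleaving rule and the helper name are hypothetical, not taken from any particular model.

```python
# Hedged sketch: derive a per-layer sliding window from a hypothetical
# interleaving pattern. Real models read this from their HF config; the rule
# below (every `pattern`-th layer uses full attention) is purely illustrative.
from typing import Optional


def resolve_sliding_window(layer_idx: int,
                           sliding_window: int,
                           pattern: int) -> Optional[int]:
    """Return the sliding window for this layer, or None for full attention."""
    if (layer_idx + 1) % pattern == 0:
        return None
    return sliding_window


# Example: a 6-layer model where every 3rd layer uses full attention.
windows = [resolve_sliding_window(i, sliding_window=4096, pattern=3)
           for i in range(6)]
print(windows)  # [4096, 4096, None, 4096, 4096, None]

# In the modeling code, each decoder layer would then construct its attention
# layer with Attention(..., per_layer_sliding_window=windows[layer_idx]).
```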
### How to support models that use Mamba?
We consider 3 different scenarios:
1. Models that use Mamba layers (either Mamba-1 or Mamba-2) but do not use attention layers.
2. Models that combine Mamba layers (either Mamba-1 or Mamba-2) together with attention layers.
3. Models that combine Mamba-like mechanisms (e.g., Linear Attention, ShortConv) together with attention layers.

For case (1), we recommend looking at the implementation of [`MambaForCausalLM`](gh-file:vllm/model_executor/models/mamba.py) (for Mamba-1) or [`Mamba2ForCausalLM`](gh-file:vllm/model_executor/models/mamba2.py) (for Mamba-2) as a reference.
The model should inherit from the `IsAttentionFree` protocol and also implement the class methods `get_mamba_state_dtype_from_config` and `get_mamba_state_shape_from_config` to calculate the state shapes and data types from the config.
For the mamba layers themselves, please use the [`MambaMixer`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer.py) (for Mamba-1) or [`MambaMixer2`](gh-file:vllm/model_executor/layers/mamba/mamba_mixer2.py) (for Mamba-2) classes.
Please *do not* use the `MambaCacheManager` (deprecated in V1) or replicate any of the V0-specific code paths in the existing model implementations.
V0-only classes and code will be removed in the very near future.
The model should also be added to the `MODELS_CONFIG_MAP` dictionary in <gh-file:vllm/model_executor/models/config.py> to ensure that the runtime defaults are optimized.
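As a rough sketch of the case (1) skeleton: the class and config names below are illustrative, and the real classmethods in `Mamba2ForCausalLM` derive everything from vLLM's config objects rather than the toy dataclass used here.

```python
# Hedged sketch of the case (1) skeleton. In vLLM the class would also inherit
# from IsAttentionFree and build its layers with MambaMixer2; both are omitted
# here so the example stays runnable on its own.
from dataclasses import dataclass

import torch


@dataclass
class ExampleMambaConfig:
    # Illustrative stand-in for the fields a real HF config would provide.
    num_heads: int = 64
    head_dim: int = 64
    ssm_state_size: int = 128
    conv_kernel: int = 4
    intermediate_size: int = 4096
    n_groups: int = 1


class ExampleMamba2ForCausalLM:  # in vLLM this would also inherit IsAttentionFree
    @classmethod
    def get_mamba_state_dtype_from_config(cls, config: ExampleMambaConfig):
        # (conv state dtype, SSM state dtype); real models derive these from
        # the model/cache dtype in the vLLM config.
        return (torch.float16, torch.float32)

    @classmethod
    def get_mamba_state_shape_from_config(cls, config: ExampleMambaConfig):
        # Per-sequence state shapes following the Mamba-2 convention of a
        # rolling conv state plus an SSM state; the formula is illustrative.
        conv_dim = (config.intermediate_size
                    + 2 * config.n_groups * config.ssm_state_size)
        conv_state = (conv_dim, config.conv_kernel - 1)
        ssm_state = (config.num_heads, config.head_dim, config.ssm_state_size)
        return (conv_state, ssm_state)


cfg = ExampleMambaConfig()
print(ExampleMamba2ForCausalLM.get_mamba_state_shape_from_config(cfg))
# ((4352, 3), (64, 64, 128))
```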
For case (2), we recommend using as a reference the implementation of [`JambaForCausalLM`](gh-file:vllm/model_executor/models/jamba.py) (for an example of a model that uses Mamba-1 and attention together) or [`BambaForCausalLM`](gh-file:vllm/model_executor/models/bamba.py) (for an example of a model that uses Mamba-2 and attention together).
These models should follow the same instructions as case (1), but they should inherit from the `IsHybrid` protocol (instead of `IsAttentionFree`), and it is *not* necessary to add them to the `MODELS_CONFIG_MAP` (their runtime defaults will be inferred from the protocol).
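A minimal, self-contained sketch of the layer-interleaving idea behind such hybrid models follows; the `layers_block_type` field is an assumption modeled on Jamba-style configs.

```python
# Hedged sketch: a hybrid model picks, per layer, either an attention layer
# (which uses the KV cache) or a Mamba layer (which uses an in-place state),
# typically driven by a per-layer type list in the config.
layers_block_type = ["mamba", "mamba", "attention",
                     "mamba", "mamba", "attention"]

for idx, block_type in enumerate(layers_block_type):
    if block_type == "attention":
        print(f"layer {idx}: attention (KV cache)")
    else:
        print(f"layer {idx}: Mamba (in-place state)")
```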
For case (3), we recommend looking at the implementation of [`MiniMaxText01ForCausalLM`](gh-file:vllm/model_executor/models/minimax_text_01.py) or [`Lfm2ForCausalLM`](gh-file:vllm/model_executor/models/lfm2.py) as a reference, which use custom "mamba-like" layers `MiniMaxText01LinearAttention` and `ShortConv` respectively.
Please follow the same guidelines as case (2) for implementing these models.
We use "mamba-like" to refer to layers that posses a state that is updated in-place, rather than being appended-to (like KV cache for attention).
For implementing new custom mamba-like layers, one should inherit from `MambaBase` and implement the methods `get_state_dtype` and `get_state_shape` to calculate the data types and state shapes at runtime, as well as `mamba_type` and `get_attn_backend`.
It is also necessary to implement the "attention metadata" class, which handles the metadata that is common across all layers.
Please see [`LinearAttentionMetadata`](gh-file:vllm/v1/attention/backends/linear_attn.py) or [`ShortConvAttentionMetadata`](gh-file:vllm/v1/attention/backends/short_conv_attn.py) for examples of this.
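A rough sketch of the interface such a custom layer exposes is shown below; everything in it is illustrative, and in vLLM the class would inherit from `MambaBase`, with `get_attn_backend` returning the backend whose metadata class is used at runtime.

```python
# Hedged sketch of a custom mamba-like layer's state interface. Shapes, names,
# and return types are assumptions; see the ShortConv and
# MiniMaxText01LinearAttention implementations for the real thing.
import torch


class ExampleShortConvLayer:  # in vLLM, this would inherit from MambaBase
    def __init__(self, hidden_size: int, kernel_size: int):
        self.hidden_size = hidden_size
        self.kernel_size = kernel_size

    @property
    def mamba_type(self) -> str:
        # Short identifier for this layer family (cf. "linear_attention").
        return "example_short_conv"

    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
        # Data type(s) of the state that is updated in place.
        return (torch.float32,)

    def get_state_shape(self) -> tuple[tuple[int, ...], ...]:
        # Per-sequence state shape(s); a short convolution keeps the last
        # (kernel_size - 1) hidden states around.
        return ((self.hidden_size, self.kernel_size - 1),)

    def get_attn_backend(self):
        # Would return the backend class whose metadata (e.g.
        # ShortConvAttentionMetadata) is shared across all such layers.
        raise NotImplementedError("stubbed out in this sketch")


layer = ExampleShortConvLayer(hidden_size=1024, kernel_size=4)
print(layer.mamba_type, layer.get_state_dtype(), layer.get_state_shape())
```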
Finally, if one wants to support `torch.compile` and CUDA graphs, it is necessary to wrap the call to the mamba-like layer inside a custom op and register it.
Please see the calls to `direct_register_custom_op` in <gh-file:vllm/model_executor/models/minimax_text_01.py> or <gh-file:vllm/model_executor/layers/mamba/short_conv.py> for examples of this.
The new custom op should then be added to the list `_attention_ops` in <gh-file:vllm/config/compilation.py> to ensure that piecewise CUDA graphs work as intended.
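A hedged sketch of that registration pattern follows; it mirrors the shape of the calls referenced above, but the op name, the argument list, and the exact import path and signature of `direct_register_custom_op` should be checked against the current vLLM source rather than copied verbatim.

```python
# Hedged sketch: wrap the mamba-like layer's forward pass in a custom op so
# torch.compile treats it as an opaque call. All names here are illustrative.
import torch

# Import path and signature may differ across vLLM versions; check vllm/utils.
from vllm.utils import direct_register_custom_op


def example_short_conv(hidden_states: torch.Tensor,
                       output: torch.Tensor,
                       layer_name: str) -> None:
    # Real implementation: look up the layer instance (e.g. via the forward
    # context) and run it, writing into the pre-allocated `output` buffer.
    ...


def example_short_conv_fake(hidden_states: torch.Tensor,
                            output: torch.Tensor,
                            layer_name: str) -> None:
    # Fake (meta) implementation used during tracing; does no real work.
    return


direct_register_custom_op(
    op_name="example_short_conv",
    op_func=example_short_conv,
    mutates_args=["output"],
    fake_impl=example_short_conv_fake,
)

# The layer's forward() would then call
# torch.ops.vllm.example_short_conv(hidden_states, output, layer_name), and
# "example_short_conv" would be added to _attention_ops in
# vllm/config/compilation.py so piecewise CUDA graphs split around it.
```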