
Commit 4178fa6

add falcon mamba support
1 parent e081a82 commit 4178fa6

File tree

5 files changed: +109 / -9 lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 11 additions & 0 deletions
@@ -77,6 +77,7 @@
     DeciLMModelPatcher,
     DeepseekPatcher,
     FalconModelPatcher,
+    FalconMambaPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
     GptBigCodeModelPatcher,
@@ -3010,3 +3011,13 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
         )

         return dummy_inputs
+
+
+@register_in_tasks_manager(
+    "falcon-mamba", *["text-generation", "text-generation-with-past"], library_name="transformers"
+)
+class FalconMambaOpenVINOConfig(MambaOpenVINOConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        return FalconMambaPatcher(self, model, model_kwargs)
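With the config registered for the text-generation and text-generation-with-past tasks, a falcon-mamba checkpoint should go through the standard OpenVINO export entry points. A minimal sketch, assuming the public main_export API and using tiiuae/falcon-mamba-7b purely as an illustrative checkpoint id:

    # sketch: any falcon-mamba checkpoint works; the id below is only an example
    from optimum.exporters.openvino import main_export

    main_export(
        model_name_or_path="tiiuae/falcon-mamba-7b",
        output="falcon_mamba_ov",
        task="text-generation-with-past",
    )

The equivalent CLI call would be optimum-cli export openvino --model <checkpoint> --task text-generation-with-past <output_dir>.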

optimum/exporters/openvino/model_patcher.py

Lines changed: 88 additions & 0 deletions
@@ -4510,6 +4510,75 @@ def mamba_mixer_forward(
     return contextualized_states


+def falcon_mamba_mixer_forward(
+    self,
+    input_states,
+    cache_params=None,
+    cache_position: Optional[torch.LongTensor] = None,
+    attention_mask: Optional[torch.LongTensor] = None,
+):
+    from transformers.models.falcon_mamba.modeling_falcon_mamba import rms_forward
+
+    batch_size, seq_len, _ = input_states.shape
+    dtype = input_states.dtype
+    # 1. Gated MLP's linear projection
+    projected_states = self.in_proj(input_states).transpose(1, 2)  # [batch, 2 * intermediate_size, seq_len]
+    hidden_states, gate = projected_states.chunk(2, dim=1)
+
+    if attention_mask is not None:
+        hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+    # 2. Convolution sequence transformation
+    if cache_params is not None:
+        ssm_state = cache_params.ssm_states[self.layer_idx].clone()
+        ssm_state = ssm_state.to(hidden_states.device)
+        # use `cache_position.shape[0]` to check whether we are in prefill
+        # stage, it's equivalent to check `cache_position[0] == 0`, which
+        # breaks dynamo fullgraph constraints
+        hidden_states, conv_state = self.conv_sequence_transform(
+            hidden_states, cache_position, cache_params.conv_states[self.layer_idx]
+        )
+        cache_params.conv_states[self.layer_idx] = conv_state
+    else:
+        ssm_state = torch.zeros(
+            (batch_size, self.intermediate_size, self.ssm_state_size), device=hidden_states.device, dtype=dtype
+        )
+        hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # [batch, intermediate_size, seq_len]
+
+    if attention_mask is not None:
+        hidden_states = hidden_states * attention_mask.unsqueeze(1)
+
+    # 3. State Space Model sequence transformation
+    # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
+    ssm_parameters = self.x_proj(hidden_states.transpose(1, 2))
+    time_step, B, C = torch.split(
+        ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
+    )
+
+    B = rms_forward(B, variance_epsilon=self.rms_eps)
+    C = rms_forward(C, variance_epsilon=self.rms_eps)
+    time_step = rms_forward(time_step, variance_epsilon=self.rms_eps)
+    discrete_time_step = self.dt_proj(time_step)  # [batch, seq_len, intermediate_size]
+
+    discrete_time_step = torch.nn.functional.softplus(discrete_time_step)  # [batch, intermediate_size, seq_len]
+    A = -torch.exp(self.A_log.float())
+    B = B.float()
+    D = self.D.float()
+
+    scan_output, ssm_state = self.selective_scan(
+        ssm_state, hidden_states.float().transpose(1, 2), discrete_time_step, A, B, C, D
+    )
+    scan_output = scan_output.transpose(1, 2)
+    scan_output = scan_output * self.act(gate)
+
+    if cache_params is not None:
+        cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
+
+    # 4. Final linear projection
+    contextualized_states = self.out_proj(scan_output.transpose(1, 2))  # [batch, seq_len, hidden_size]
+    return contextualized_states
+
+
 class MambaPatcher(ModelPatcher):
     def __init__(
         self,
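Compared with the plain mamba mixer above it, the falcon-mamba forward differs mainly in the rms_forward calls: the B, C and time-step projections are RMS-normalized before the selective scan, which is the architectural change FalconMamba makes on top of Mamba. Roughly what that normalization computes (a paraphrase for illustration, not the transformers source):

    import torch

    def rms_norm_no_weight(x: torch.Tensor, variance_epsilon: float = 1e-6) -> torch.Tensor:
        # scale by the inverse root mean square over the last dimension, no learnable weight
        variance = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(variance + variance_epsilon)).to(x.dtype)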
@@ -4684,3 +4753,22 @@ def __exit__(self, exc_type, exc_value, traceback):
         self._model.forward = self._model.__orig_forward
         for layer in self._model.backbone.layers:
             layer.mixer.forward = layer.mixer._orig_forward
+
+
+class FalconMambaPatcher(MambaPatcher):
+    def __enter__(self):
+        super().__enter__()
+        selective_scan = SelectiveScan()
+
+        for layer in self._model.backbone.layers:
+            layer.mixer.selective_scan = selective_scan
+            layer.mixer._orig_forward = layer.mixer.forward
+            layer.mixer.forward = types.MethodType(falcon_mamba_mixer_forward, layer.mixer)
+            conv_transform = ConvSequenceTransform(
+                layer.mixer.conv_kernel_size,
+                layer.mixer.use_conv_bias,
+                layer.mixer.conv1d,
+                layer.mixer.act,
+                layer.mixer.conv1d.bias,
+            )
+            layer.mixer.conv_sequence_transform = torch.jit.script(conv_transform)
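The patcher keeps the MambaPatcher contract: __enter__ swaps every mixer's forward for falcon_mamba_mixer_forward and attaches the TorchScript selective-scan and convolution-cache helpers, while the inherited __exit__ restores the original forwards. A rough sketch of the intended usage inside the export flow, where ov_config and pt_model are placeholders for the export config and the loaded transformers model:

    # sketch only; in practice patch_model_for_export() builds this patcher for you
    patcher = FalconMambaPatcher(ov_config, pt_model, model_kwargs=None)
    with patcher:
        # while the context is active the patched forwards are used,
        # so tracing the model here captures the exportable scan/conv ops
        ...
    # on exit the original mixer forwards are restored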

optimum/exporters/openvino/stateful.py

Lines changed: 4 additions & 1 deletion
@@ -297,10 +297,13 @@ def patch_stateful_ssm(config, ov_model):
     build_state_initializer(ov_model, batch_dim)


+SSM_MODELS = ["mamba", "falcon-mamba"]
+
+
 def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
     if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"):
         return patch_stateful_encoder_decoder(config, ov_model)
-    if config.model_type == "mamba":
+    if config.model_type.replace("_", "-") in SSM_MODELS:
         return patch_stateful_ssm(config, ov_model)
     return patch_stateful_decoder(config, ov_model)

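The dispatch now normalizes model_type before the lookup because transformers reports the type with an underscore while the export side registers it with a dash; a quick illustration:

    # FalconMamba configs report model_type "falcon_mamba" in transformers
    "falcon_mamba".replace("_", "-") in SSM_MODELS  # True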

optimum/exporters/openvino/utils.py

Lines changed: 2 additions & 0 deletions
@@ -229,6 +229,8 @@ def get_submodels(model):
     "qwen2-5-vl",
 ]

+SSM_MODELS = ["mamba", "falcon-mamba"]
+

 def save_config(config, save_dir):
     try:

optimum/intel/openvino/modeling_decoder.py

Lines changed: 4 additions & 8 deletions
@@ -14,9 +14,9 @@
 import copy
 import logging
 import os
-from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from dataclasses import dataclass

 import numpy as np
 import openvino
@@ -37,6 +37,7 @@

 from ...exporters.openvino import ensure_stateful_is_available, main_export, patch_stateful
 from ...exporters.openvino.stateful import model_has_state
+from ...exporters.openvino.utils import SSM_MODELS
 from ..utils.import_utils import compare_versions, is_nncf_available, is_transformers_version
 from ..utils.modeling_utils import MULTI_QUERY_ATTN_MODELS
 from .configuration import (
@@ -59,7 +60,7 @@
 if is_transformers_version(">=", "4.43"):
     from transformers.cache_utils import MambaCache
 else:
-    MambaCache = object()
+    MambaCache = object

 if TYPE_CHECKING:
     try:
@@ -858,7 +859,7 @@ def _from_pretrained(
             init_cls = OVBloomForCausalLM
         elif model_type == "gpt-bigcode":
             init_cls = OVGPTBigCodeForCausalLM
-        elif model_type == "mamba":
+        elif model_type in SSM_MODELS:
             init_cls = OVMambaForCausalLM
         else:
             init_cls = cls
@@ -1138,17 +1139,13 @@ def forward(
             self._past_length = 0

         ssm_states, conv_states = [], []
-        print(inputs.keys())
-
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
         logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)

         if self.stateful:
             self._past_length += input_ids.shape[1]
         else:
-            print(self.ssm_cache_output_names)
-            print(self.conv_cache_output_names)
             ssm_states = [self.request.get_tensor(key).data for key in self.ssm_cache_output_names]
             conv_states = [self.request.get_tensor(key).data for key in self.conv_cache_output_names]
         cache_params = OVMambaCache(self.config, input_ids.shape[0], conv_states=conv_states, ssm_states=ssm_states)
@@ -1159,7 +1156,6 @@ def _update_model_kwargs_for_generation(
         self, outputs: ModelOutput, model_kwargs: Dict[str, Any], num_new_tokens: int = 1, **kwargs
     ) -> Dict[str, Any]:
         model_kwargs["cache_params"] = outputs.get("cache_params", None)
-        print(model_kwargs["cache_params"])
         if (
             model_kwargs.get("use_cache", True)
             and "cache_position" in model_kwargs
