
Commit 8dd9fd2

rkazants, eaidova and echarlaix authored and committed
[OpenVINO] Add support for Mamba and Falcon-mamba (#1360)
* support mamba models
* add falcon mamba support
* Apply suggestions from code review
* Update optimum/exporters/openvino/model_patcher.py
* Fix after merge
* Add tests for export_cli
* Add test_export for mamba
* Reformat modeling_decoder.py
* Fix format for modeling_decoder.py
* Add import of PretrainedConfig
* Run mamba test only for transformers greater than 4.39
* Add test_modeling for mamba
* Update documentation
* Apply review comments: SSM_MODELS from utils, docs
* Remove commented-out code
* Have own constructor for Mamba models
* Remove unneeded code
* Update optimum/exporters/openvino/model_patcher.py
* Update docs/source/openvino/models.mdx
* Apply suggestions from code review
* Fix formatting
* Apply suggestions from code review
* Update optimum/exporters/openvino/model_configs.py
* Update optimum/exporters/openvino/model_configs.py
* Test falcon-mamba
* Have just one patched forward for both mamba and falcon-mamba
* Use a single config and patcher for both mamba and falcon-mamba
* Update tests/openvino/test_modeling.py
* Fix SelectiveScan
* Update optimum/exporters/openvino/model_patcher.py
* Simplify MambaPatcher
* Fix patched_forward
* Add comments for patching
* Initialize ssm_rms_normalization in constructor
* Add a comment explaining the reason for the overridden generate_dummy_inputs method
* Move ssm_rms_normalization into ModelPatcher
* Avoid jit scripting, which fixes a part of the functional tests
* Update tests/openvino/utils_tests.py
* Add attention_mask to inputs
* Apply review comments: attention mask for decoding steps and no replace needed for model type
* Revert replacement for model type
* Re-write patching for convsequencetransform
* Use original model type without replacement
* Update tests/openvino/utils_tests.py
* Support use_cache=False
* Update tests/openvino/utils_tests.py
* Update optimum/exporters/openvino/model_patcher.py
* Update tests/openvino/test_modeling.py
* Apply review comment: no need to override the update_conv_state method
* Add additional checks in tests

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
Co-authored-by: eaidova <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
1 parent 8edea4b commit 8dd9fd2

File tree

10 files changed: +777 -9 lines changed


docs/source/openvino/models.mdx

Lines changed: 2 additions & 0 deletions
@@ -52,6 +52,7 @@ Here is the list of the supported architectures :
 - ESM
 - Exaone
 - Falcon
+- Falcon-Mamba
 - Flaubert
 - GLM-4
 - GLM-Edge
@@ -82,6 +83,7 @@ Here is the list of the supported architectures :
 - Llava-Next-Video
 - M2-M100
 - MAIRA-2
+- Mamba
 - MBart
 - MPNet
 - MPT
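
With these two entries in the supported-architectures list, an exported checkpoint can be driven through the standard optimum-intel path, either via the CLI (optimum-cli export openvino --model <model_id> <output_dir>) or by exporting on the fly from Python. A minimal usage sketch follows; the checkpoint ID state-spaces/mamba-130m-hf is illustrative and not taken from this commit.

# Minimal sketch: convert a Mamba checkpoint to OpenVINO IR on the fly and generate text.
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "state-spaces/mamba-130m-hf"  # illustrative checkpoint ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # export=True triggers conversion

inputs = tokenizer("OpenVINO is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(generated[0], skip_special_tokens=True))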

optimum/exporters/openvino/model_configs.py

Lines changed: 126 additions & 0 deletions
@@ -117,6 +117,7 @@
     LlavaNextVideoImageEmbeddingModelPatcher,
     LlavaQwen2ImageEmbeddingsModelPatcher,
     MairaImageEmbeddingModelPatcher,
+    MambaPatcher,
     MarianModelPatcher,
     MarianStatefulSeq2SeqDecoderPatcher,
     MiniCPM3Patcher,
@@ -4382,3 +4383,128 @@ def patch_model_for_export(
         if self._behavior != VLMConfigBehavior.VISION_EMBEDDINGS:
             return super().patch_model_for_export(model, model_kwargs)
         return Llama4ImageEmbeddingsModelPatcher(self, model, model_kwargs)
+
+
+class MambaCacheDummyInputGenerator(DummyInputGenerator):
+    """
+    Generates dummy past_ssm_states, past_conv_states and cache_position inputs for Mamba architectures.
+    """
+
+    SUPPORTED_INPUT_NAMES = ("cache_params", "cache_position")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        self.normalized_config = normalized_config
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+        self.intermediate_size = self.normalized_config.config.intermediate_size
+        self.ssm_state_size = self.normalized_config.config.state_size
+        self.conv_kernel_size = self.normalized_config.config.conv_kernel
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "cache_params":
+            ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
+            conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
+            return [
+                (
+                    self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype),
+                    self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype),
+                )
+                for _ in range(self.normalized_config.num_layers)
+            ]
+        elif input_name == "cache_position":
+            return self.random_int_tensor(
+                shape=[self.conv_kernel_size],
+                max_value=self.sequence_length,
+                framework=framework,
+                dtype=int_dtype,
+            )
+
+        raise ValueError(f"Unsupported input name {input_name}")
+
+
+@register_in_tasks_manager(
+    "falcon_mamba", *["text-generation", "text-generation-with-past"], library_name="transformers"
+)
+@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class MambaOpenVINOConfig(TextDecoderOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    MIN_TRANSFORMERS_VERSION = version.parse("4.43.0")
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "batch_size", 1: "sequence_length"},
+            "cache_position": {0: "cache_sequence_length"},
+        }
+        if self.use_past_in_inputs:
+            self.add_past_key_values(common_inputs, direction="inputs")
+        return common_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills the `inputs_or_outputs` mapping with past_key_values dynamic axes considering the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
+                The mapping to fill.
+            direction (`str`):
+                Either "inputs" or "outputs"; it specifies whether `inputs_or_outputs` is the input mapping or
+                the output mapping, which is important for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            ssm_conv_states_name = "cache_params.past"
+        else:
+            ssm_conv_states_name = "cache_params.present"
+
+        for i in range(self._normalized_config.num_layers):
+            # [batch_size, d_state, d_model]
+            inputs_or_outputs[f"{ssm_conv_states_name}.ssm.{i}"] = {0: "batch_size"}
+            # [batch_size, conv_kernel_size - 1, d_model]
+            inputs_or_outputs[f"{ssm_conv_states_name}.conv.{i}"] = {0: "batch_size"}
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        return MambaPatcher(self, model, model_kwargs)
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        # We need to override `generate_dummy_inputs` since the Mamba model has extra states (ssm_states and
+        # conv_states), which we separate and expose as past_ssm_states and past_conv_states.
+        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
+
+        dummy_inputs = {}
+        input_names = [key for key in self.inputs.keys() if not key.startswith("cache_params")]
+        if self.use_past_in_inputs:
+            input_names.extend(["cache_params"])
+
+        for input_name in input_names:
+            input_was_inserted = False
+            for dummy_input_gen in dummy_inputs_generators:
+                if dummy_input_gen.supports_input(input_name):
+                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
+                        dummy_input_gen,
+                        input_name,
+                        framework,
+                        input_shapes=kwargs,
+                    )
+                    input_was_inserted = True
+                    break
+            if not input_was_inserted:
+                raise RuntimeError(
+                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
+                )
+
+        return dummy_inputs
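
To make the exported cache layout concrete, here is a small self-contained sketch in plain PyTorch (not code from this commit; the sizes are illustrative stand-ins for config.intermediate_size, config.state_size and config.conv_kernel). It builds the same per-layer (ssm_state, conv_state) pairs that MambaCacheDummyInputGenerator returns for cache_params, and prints the flattened tensor names registered by MambaOpenVINOConfig.add_past_key_values.

import torch

# Illustrative toy sizes, not taken from a real checkpoint.
num_layers = 2
batch_size = 1
intermediate_size = 1536  # config.intermediate_size
ssm_state_size = 16       # config.state_size
conv_kernel_size = 4      # config.conv_kernel

# One (ssm_state, conv_state) pair per layer, mirroring what
# MambaCacheDummyInputGenerator.generate("cache_params") produces.
cache_params = [
    (
        torch.rand(batch_size, intermediate_size, ssm_state_size),    # SSM state
        torch.rand(batch_size, intermediate_size, conv_kernel_size),  # conv state
    )
    for _ in range(num_layers)
]

# Flattened names as registered by add_past_key_values: direction="inputs" uses
# the "cache_params.past" prefix, while outputs use "cache_params.present".
for i, (ssm_state, conv_state) in enumerate(cache_params):
    print(f"cache_params.past.ssm.{i}:  {tuple(ssm_state.shape)}")
    print(f"cache_params.past.conv.{i}: {tuple(conv_state.shape)}")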
