Commit 1b2b821

committed
support mamba models
1 parent 755a833 commit 1b2b821

File tree

4 files changed: +622 -3 lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 126 additions & 0 deletions
@@ -15,6 +15,8 @@
 import enum
 from copy import deepcopy
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+import inspect
+import re
 
 from packaging import version
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
@@ -95,6 +97,7 @@
     LlamaModelPatcher,
     LlavaImageEmbeddingModelPatcher,
     LlavaQwen2ImageEmbeddingsModelPatcher,
+    MambaPatcher,
     MiniCPM3Patcher,
     MiniCPMModelPatcher,
     MiniCPMVImageEmbeddingsModelPatcher,
@@ -2880,3 +2883,126 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
+
+
+class MambaCacheDummyInputGenerator(DummyInputGenerator):
+    """
+    Generates dummy cache inputs (per-layer SSM and convolution states) for Mamba architectures.
+    """
+
+    SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        self.normalized_config = normalized_config
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+        self.intermediate_size = self.normalized_config.config.intermediate_size
+        self.ssm_state_size = self.normalized_config.config.state_size
+        self.conv_kernel_size = self.normalized_config.config.conv_kernel
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "past_ssm_states":
+            ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
+            return [
+                self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype)
+                for _ in range(self.normalized_config.num_layers)
+            ]
+
+        elif input_name == "past_conv_states":
+            conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
+            return [
+                self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype)
+                for _ in range(self.normalized_config.num_layers)
+            ]
+
+        elif input_name == "cache_position":
+            return self.random_int_tensor(
+                shape=[self.conv_kernel_size],
+                max_value=self.sequence_length,
+                framework=framework,
+                dtype=int_dtype,
+            )
+
+        raise ValueError(f"Unsupported input name {input_name}")
+
+
+@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers")
+class MambaOpenVINOConfig(TextDecoderOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if self.use_past_in_inputs:
+            common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
+            self.add_past_key_values(common_inputs, direction="inputs")
+            # common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
+            common_inputs["cache_position"] = {0: "cache_sequence_length"}
+        else:
+            common_inputs = {
+                "input_ids": {0: "batch_size", 1: "sequence_length"},
+                # "attention_mask": {0: "batch_size", 1: "sequence_length"},
+                "cache_position": {0: "cache_sequence_length"},
+            }
+        return common_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills the `inputs_or_outputs` mapping with dynamic axes for the past/present SSM and convolution
+        states, considering the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
+                The mapping to fill.
+            direction (`str`):
+                Either "inputs" or "outputs"; it specifies whether `inputs_or_outputs` is the input mapping
+                or the output mapping, which is important for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            ssm_name = "past_ssm_states"
+            conv_name = "past_conv_states"
+        else:
+            ssm_name = "present_ssm_states"
+            conv_name = "present_conv_states"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"}
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"}
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        return MambaPatcher(self, model, model_kwargs)
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
+
+        dummy_inputs = {}
+        input_names = [key for key in self.inputs.keys() if not key.startswith("past_")]
+        if self.use_past_in_inputs and self.use_cache_branch is not False:
+            input_names.extend(["past_ssm_states", "past_conv_states"])
+
+        for input_name in input_names:
+            input_was_inserted = False
+            for dummy_input_gen in dummy_inputs_generators:
+                if dummy_input_gen.supports_input(input_name):
+                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
+                        dummy_input_gen,
+                        input_name,
+                        framework,
+                        input_shapes=kwargs,
+                    )
+                    input_was_inserted = True
+                    break
+            if not input_was_inserted:
+                raise RuntimeError(
+                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
+                )
+
+        return dummy_inputs
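
For reference, the cache tensors the new MambaCacheDummyInputGenerator produces map one-to-one onto fields of the transformers Mamba config. A minimal standalone sketch of the same shapes (assuming the stock MambaConfig field names; not part of this commit):

    from transformers import MambaConfig
    import torch

    config = MambaConfig()  # exposes intermediate_size, state_size, conv_kernel, num_hidden_layers
    batch_size = 2

    # One SSM state and one conv state per layer, matching the generator above:
    # [batch, intermediate_size, state_size] and [batch, intermediate_size, conv_kernel]
    ssm_states = [
        torch.rand(batch_size, config.intermediate_size, config.state_size)
        for _ in range(config.num_hidden_layers)
    ]
    conv_states = [
        torch.rand(batch_size, config.intermediate_size, config.conv_kernel)
        for _ in range(config.num_hidden_layers)
    ]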

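With the "mamba" architecture registered for the text-generation tasks as above, a Mamba checkpoint should be exportable through the usual optimum-intel entry points. A hypothetical invocation (the model id is a placeholder; whether stateful Mamba export is routed through this config depends on the rest of the commit):

    from optimum.intel import OVModelForCausalLM

    # export=True converts the checkpoint to OpenVINO IR at load time
    model = OVModelForCausalLM.from_pretrained("state-spaces/mamba-130m-hf", export=True)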