Skip to content

Commit e081a82

Browse files
committed
support mamba models
1 parent 755a833 commit e081a82

File tree

4 files changed

+685
-3
lines changed

4 files changed

+685
-3
lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@
9595
LlamaModelPatcher,
9696
LlavaImageEmbeddingModelPatcher,
9797
LlavaQwen2ImageEmbeddingsModelPatcher,
98+
MambaPatcher,
9899
MiniCPM3Patcher,
99100
MiniCPMModelPatcher,
100101
MiniCPMVImageEmbeddingsModelPatcher,
@@ -2880,3 +2881,132 @@ def patch_model_for_export(
28802881
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
28812882
) -> "ModelPatcher":
28822883
return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
2884+
2885+
2886+
class MambaCacheDummyInputGenerator(DummyInputGenerator):
2887+
"""
2888+
Generates dummy past_key_values inputs for seq2seq architectures.
2889+
"""
2890+
2891+
SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position")
2892+
2893+
def __init__(
2894+
self,
2895+
task: str,
2896+
normalized_config,
2897+
batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
2898+
sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
2899+
**kwargs,
2900+
):
2901+
self.normalized_config = normalized_config
2902+
self.batch_size = batch_size
2903+
self.sequence_length = sequence_length
2904+
self.intermediate_size = self.normalized_config.config.intermediate_size
2905+
self.ssm_state_size = self.normalized_config.config.state_size
2906+
self.conv_kernel_size = self.normalized_config.config.conv_kernel
2907+
2908+
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
2909+
if input_name == "past_ssm_states":
2910+
ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
2911+
return [
2912+
self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype)
2913+
for _ in range(self.normalized_config.num_layers)
2914+
]
2915+
2916+
elif input_name == "past_conv_states":
2917+
conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
2918+
return [
2919+
self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype)
2920+
for _ in range(self.normalized_config.num_layers)
2921+
]
2922+
2923+
elif input_name == "cache_position":
2924+
return self.random_int_tensor(
2925+
shape=[self.conv_kernel_size],
2926+
max_value=self.sequence_length,
2927+
framework=framework,
2928+
dtype=int_dtype,
2929+
)
2930+
2931+
raise ValueError(f"Unsupported input name {input_name}")
2932+
2933+
2934+
@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers")
2935+
class MambaOpenVINOConfig(TextDecoderOnnxConfig):
2936+
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
2937+
DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
2938+
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
2939+
2940+
@property
2941+
def inputs(self) -> Dict[str, Dict[int, str]]:
2942+
if self.use_past_in_inputs:
2943+
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
2944+
self.add_past_key_values(common_inputs, direction="inputs")
2945+
# common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
2946+
common_inputs["cache_position"] = {0: "cache_sequence_length"}
2947+
else:
2948+
common_inputs = {
2949+
"input_ids": {0: "batch_size", 1: "sequence_length"},
2950+
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
2951+
"cache_position": {0: "cache_sequence_length"},
2952+
}
2953+
return common_inputs
2954+
2955+
def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
2956+
"""
2957+
Fills `input_or_outputs` mapping with past_key_values dynamic axes considering the direction.
2958+
2959+
Args:
2960+
inputs_or_outputs (`Dict[str, Dict[int, str]]`):
2961+
The mapping to fill.
2962+
direction (`str`):
2963+
either "inputs" or "outputs", it specifies whether `input_or_outputs` is the input mapping or the
2964+
output mapping, this is important for axes naming.
2965+
"""
2966+
if direction not in ["inputs", "outputs"]:
2967+
raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
2968+
2969+
if direction == "inputs":
2970+
ssm_name = "past_ssm_states"
2971+
conv_name = "past_conv_states"
2972+
else:
2973+
ssm_name = "present_ssm_states"
2974+
conv_name = "present_conv_states"
2975+
2976+
for i in range(self._normalized_config.num_layers):
2977+
inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"}
2978+
2979+
for i in range(self._normalized_config.num_layers):
2980+
inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"}
2981+
2982+
def patch_model_for_export(
2983+
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
2984+
):
2985+
return MambaPatcher(self, model, model_kwargs)
2986+
2987+
def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
2988+
dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
2989+
2990+
dummy_inputs = {}
2991+
input_names = [key for key in self.inputs.keys() if not key.startswith("past_")]
2992+
if self.use_past_in_inputs and self.use_cache_branch is not False:
2993+
input_names.extend(["past_ssm_states", "past_conv_states"])
2994+
2995+
for input_name in input_names:
2996+
input_was_inserted = False
2997+
for dummy_input_gen in dummy_inputs_generators:
2998+
if dummy_input_gen.supports_input(input_name):
2999+
dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
3000+
dummy_input_gen,
3001+
input_name,
3002+
framework,
3003+
input_shapes=kwargs,
3004+
)
3005+
input_was_inserted = True
3006+
break
3007+
if not input_was_inserted:
3008+
raise RuntimeError(
3009+
f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
3010+
)
3011+
3012+
return dummy_inputs

0 commit comments

Comments
 (0)