
Commit e26d7c0

support mamba models
1 parent 755a833 commit e26d7c0

4 files changed: +620 −3 lines changed

optimum/exporters/openvino/model_configs.py

Lines changed: 124 additions & 0 deletions
@@ -95,6 +95,7 @@
     LlamaModelPatcher,
     LlavaImageEmbeddingModelPatcher,
     LlavaQwen2ImageEmbeddingsModelPatcher,
+    MambaPatcher,
     MiniCPM3Patcher,
     MiniCPMModelPatcher,
     MiniCPMVImageEmbeddingsModelPatcher,
@@ -2880,3 +2881,126 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return DeepseekPatcher(self, model, model_kwargs=model_kwargs)
+
+
+class MambaCacheDummyInputGenerator(DummyInputGenerator):
+    """
+    Generates dummy past ssm and conv state inputs for Mamba architectures.
+    """
+
+    SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position")
+
+    def __init__(
+        self,
+        task: str,
+        normalized_config,
+        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
+        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
+        **kwargs,
+    ):
+        self.normalized_config = normalized_config
+        self.batch_size = batch_size
+        self.sequence_length = sequence_length
+        self.intermediate_size = self.normalized_config.config.intermediate_size
+        self.ssm_state_size = self.normalized_config.config.state_size
+        self.conv_kernel_size = self.normalized_config.config.conv_kernel
+
+    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
+        if input_name == "past_ssm_states":
+            ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
+            return [
+                self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype)
+                for _ in range(self.normalized_config.num_layers)
+            ]
+
+        elif input_name == "past_conv_states":
+            conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
+            return [
+                self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype)
+                for _ in range(self.normalized_config.num_layers)
+            ]
+
+        elif input_name == "cache_position":
+            return self.random_int_tensor(
+                shape=[self.conv_kernel_size],
+                max_value=self.sequence_length,
+                framework=framework,
+                dtype=int_dtype,
+            )
+
+        raise ValueError(f"Unsupported input name {input_name}")
+
+
+@register_in_tasks_manager(
+    "mamba", *["text-generation", "text-generation-with-past"], library_name="transformers"
+)
+class MambaOpenVINOConfig(TextDecoderOnnxConfig):
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
+    DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if self.use_past_in_inputs:
+            common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
+            self.add_past_key_values(common_inputs, direction="inputs")
+            # common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
+            common_inputs["cache_position"] = {0: "cache_sequence_length"}
+        else:
+            common_inputs = {
+                "input_ids": {0: "batch_size", 1: "sequence_length"},
+                # "attention_mask": {0: "batch_size", 1: "sequence_length"},
+                "cache_position": {0: "cache_sequence_length"},
+            }
+        return common_inputs
+
+    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
+        """
+        Fills `inputs_or_outputs` mapping with past ssm/conv states dynamic axes considering the direction.
+
+        Args:
+            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
+                The mapping to fill.
+            direction (`str`):
+                Either "inputs" or "outputs"; it specifies whether `inputs_or_outputs` is the input mapping or the
+                output mapping, which is important for axes naming.
+        """
+        if direction not in ["inputs", "outputs"]:
+            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
+
+        if direction == "inputs":
+            ssm_name = "past_ssm_states"
+            conv_name = "past_conv_states"
+        else:
+            ssm_name = "present_ssm_states"
+            conv_name = "present_conv_states"
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"}
+
+        for i in range(self._normalized_config.num_layers):
+            inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"}
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        return MambaPatcher(self, model, model_kwargs)
+
+    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
+        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)
+
+        dummy_inputs = {}
+        input_names = [key for key in self.inputs.keys() if not key.startswith("past_")]
+        if self.use_past_in_inputs and self.use_cache_branch is not False:
+            input_names.extend(["past_ssm_states", "past_conv_states"])
+
+        for input_name in input_names:
+            input_was_inserted = False
+            for dummy_input_gen in dummy_inputs_generators:
+                if dummy_input_gen.supports_input(input_name):
+                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
+                        dummy_input_gen,
+                        input_name,
+                        framework,
+                        input_shapes=kwargs,
+                    )
+                    input_was_inserted = True
+                    break
+            if not input_was_inserted:
+                raise RuntimeError(
+                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
+                )
+
+        return dummy_inputs
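
For context, the sketch below shows how a Mamba checkpoint could be exported and run once this config is registered. It is not part of the diff: the checkpoint name and the OVModelForCausalLM entry point from optimum-intel are assumptions about typical usage, not something this commit adds.

# Hypothetical usage sketch (assumed checkpoint and entry point, not from this commit).
from transformers import AutoTokenizer

from optimum.intel import OVModelForCausalLM

model_id = "state-spaces/mamba-130m-hf"  # example Mamba checkpoint (assumption)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True converts the PyTorch model to OpenVINO IR on the fly, using the
# "text-generation-with-past" task that MambaOpenVINOConfig registers above.
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

inputs = tokenizer("Mamba is a state space model that", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))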
