141 changes: 141 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -77,6 +77,7 @@
    DeciLMModelPatcher,
    DeepseekPatcher,
    FalconModelPatcher,
    FalconMambaPatcher,
    FluxTransfromerModelPatcher,
    Gemma2ModelPatcher,
    GptBigCodeModelPatcher,
@@ -95,6 +96,7 @@
    LlamaModelPatcher,
    LlavaImageEmbeddingModelPatcher,
    LlavaQwen2ImageEmbeddingsModelPatcher,
    MambaPatcher,
    MiniCPM3Patcher,
    MiniCPMModelPatcher,
    MiniCPMVImageEmbeddingsModelPatcher,
@@ -2880,3 +2882,142 @@ def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return DeepseekPatcher(self, model, model_kwargs=model_kwargs)


class MambaCacheDummyInputGenerator(DummyInputGenerator):
    """
    Generates dummy cache inputs (per-layer SSM and convolution states, plus cache positions) for Mamba architectures.
    """

    SUPPORTED_INPUT_NAMES = ("past_ssm_states", "past_conv_states", "cache_position")

    def __init__(
        self,
        task: str,
        normalized_config,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"],
        **kwargs,
    ):
        self.task = task
        self.normalized_config = normalized_config
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.intermediate_size = self.normalized_config.config.intermediate_size
        self.ssm_state_size = self.normalized_config.config.state_size
        self.conv_kernel_size = self.normalized_config.config.conv_kernel

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        if input_name == "past_ssm_states":
            # One SSM state tensor per layer, shaped (batch_size, intermediate_size, ssm_state_size).
            ssm_shape = [self.batch_size, self.intermediate_size, self.ssm_state_size]
            return [
                self.random_float_tensor(ssm_shape, framework=framework, dtype=float_dtype)
                for _ in range(self.normalized_config.num_layers)
            ]

        elif input_name == "past_conv_states":
            # One rolling convolution state per layer, shaped (batch_size, intermediate_size, conv_kernel_size).
            conv_shape = [self.batch_size, self.intermediate_size, self.conv_kernel_size]
            return [
                self.random_float_tensor(conv_shape, framework=framework, dtype=float_dtype)
                for _ in range(self.normalized_config.num_layers)
            ]

        elif input_name == "cache_position":
            return self.random_int_tensor(
                shape=[self.conv_kernel_size],
                max_value=self.sequence_length,
                framework=framework,
                dtype=int_dtype,
            )

        raise ValueError(f"Unsupported input name {input_name}")
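The generator can also be exercised in isolation. Below is a minimal sketch (not part of this diff), assuming a stock `transformers.MambaConfig` wrapped in `NormalizedTextConfig`, which is roughly how the tasks manager supplies the config during export:

# Hypothetical standalone usage; the exporter normally instantiates the generator itself.
from transformers import MambaConfig
from optimum.utils import NormalizedTextConfig

generator = MambaCacheDummyInputGenerator(
    task="text-generation",
    normalized_config=NormalizedTextConfig(MambaConfig()),
    batch_size=2,
)

# One tensor per layer, shaped (batch_size, intermediate_size, state_size).
ssm_states = generator.generate("past_ssm_states")
# One tensor per layer, shaped (batch_size, intermediate_size, conv_kernel).
conv_states = generator.generate("past_conv_states")
# A 1-D int64 tensor of length conv_kernel.
cache_position = generator.generate("cache_position")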


@register_in_tasks_manager("mamba", *["text-generation", "text-generation-with-past"], library_name="transformers")
class MambaOpenVINOConfig(TextDecoderOnnxConfig):
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, MambaCacheDummyInputGenerator)
DUMMY_PKV_GENERATOR_CLASS = MambaCacheDummyInputGenerator
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
if self.use_past_in_inputs:
common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}}
self.add_past_key_values(common_inputs, direction="inputs")
# common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"}
common_inputs["cache_position"] = {0: "cache_sequence_length"}
else:
common_inputs = {
"input_ids": {0: "batch_size", 1: "sequence_length"},
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
"cache_position": {0: "cache_sequence_length"},
}
return common_inputs

    def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str):
        """
        Fills the `inputs_or_outputs` mapping with dynamic axes for the past SSM and convolution states,
        depending on the direction.

        Args:
            inputs_or_outputs (`Dict[str, Dict[int, str]]`):
                The mapping to fill.
            direction (`str`):
                Either "inputs" or "outputs"; specifies whether `inputs_or_outputs` is the input mapping or
                the output mapping. This matters for axes naming.
        """
        if direction not in ["inputs", "outputs"]:
            raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')

        if direction == "inputs":
            ssm_name = "past_ssm_states"
            conv_name = "past_conv_states"
        else:
            ssm_name = "present_ssm_states"
            conv_name = "present_conv_states"

        # All SSM states first, then all convolution states, keeping the grouped ordering of the signature.
        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{ssm_name}.{i}"] = {0: "batch_size"}

        for i in range(self._normalized_config.num_layers):
            inputs_or_outputs[f"{conv_name}.{i}"] = {0: "batch_size"}

    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return MambaPatcher(self, model, model_kwargs=model_kwargs)

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs)

        dummy_inputs = {}
        # `self.inputs` lists the cache entries per layer (e.g. `past_ssm_states.0`), while the dummy
        # generator is keyed on the bare names, so the per-layer keys are filtered out here and the
        # bare names appended when the cache branch is in use.
        input_names = [key for key in self.inputs.keys() if not key.startswith("past_")]
        if self.use_past_in_inputs and self.use_cache_branch is not False:
            input_names.extend(["past_ssm_states", "past_conv_states"])

        for input_name in input_names:
            input_was_inserted = False
            for dummy_input_gen in dummy_inputs_generators:
                if dummy_input_gen.supports_input(input_name):
                    dummy_inputs[input_name] = self.overwrite_shape_and_generate_input(
                        dummy_input_gen,
                        input_name,
                        framework,
                        input_shapes=kwargs,
                    )
                    input_was_inserted = True
                    break
            if not input_was_inserted:
                raise RuntimeError(
                    f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.'
                )

        return dummy_inputs
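With the config registered under the `mamba` model type, the standard OpenVINO export entry point should resolve it automatically. A minimal sketch, assuming `state-spaces/mamba-130m-hf` as an example checkpoint (any transformers-format Mamba checkpoint should behave the same):

from optimum.exporters.openvino import main_export

# The "-with-past" task exercises the past_ssm_states / past_conv_states inputs defined above.
main_export(
    model_name_or_path="state-spaces/mamba-130m-hf",  # example checkpoint, not prescribed by this diff
    output="mamba_ov",
    task="text-generation-with-past",
)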


@register_in_tasks_manager(
    "falcon-mamba", "text-generation", "text-generation-with-past", library_name="transformers"
)
class FalconMambaOpenVINOConfig(MambaOpenVINOConfig):
    def patch_model_for_export(
        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
    ) -> "ModelPatcher":
        return FalconMambaPatcher(self, model, model_kwargs=model_kwargs)
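Since the Falcon-Mamba config only swaps the patcher, the high-level API should cover it as well. A minimal sketch, assuming `tiiuae/falcon-mamba-7b` as an example checkpoint:

from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "tiiuae/falcon-mamba-7b"  # example checkpoint, not taken from this diff
model = OVModelForCausalLM.from_pretrained(model_id, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("The state-space update rule is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0], skip_special_tokens=True))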