Skip to content

Commit 7c2537f

Browse files
Update dummy input generator
1 parent 10fb672 commit 7c2537f

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

optimum/exporters/onnx/model_configs.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -498,13 +498,28 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
498498
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
499499
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
500500

501-
@register_tasks_manager_onnx("gemma3", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
501+
502+
@register_tasks_manager_onnx(
503+
"gemma3", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"]
504+
)
502505
class Gemma3OnnxConfig(LlamaOnnxConfig):
503-
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
506+
DUMMY_INPUT_GENERATOR_CLASSES = (
507+
DummyTextInputGenerator,
508+
DummyVisionInputGenerator,
509+
)
504510
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
505511
NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gemma3")
506512
MIN_TRANSFORMERS_VERSION = version.parse("4.52.0.dev0")
507513

514+
# TODO: check if we need this
515+
# @property
516+
# def inputs(self) -> dict[str, dict[int, str]]:
517+
# return {
518+
# "input_ids": {0: "batch_size", 1: "sequence_length"},
519+
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
520+
# "pixel_values": {0: "batch_size", 1: "num_channels", 2: "image_size", 3: "image_size"},
521+
# }
522+
508523

509524
@register_tasks_manager_onnx("granite", *COMMON_TEXT_GENERATION_TASKS)
510525
class GraniteOnnxConfig(LlamaOnnxConfig):

0 commit comments

Comments
 (0)