Skip to content

Commit 97c09b9

Browse files
Update dummy input generator
1 parent cd831b6 commit 97c09b9

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

optimum/exporters/onnx/model_configs.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,13 +506,28 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
506506
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
507507
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
508508

509-
@register_tasks_manager_onnx("gemma3", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
509+
510+
@register_tasks_manager_onnx(
511+
"gemma3", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"]
512+
)
510513
class Gemma3OnnxConfig(LlamaOnnxConfig):
511-
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
514+
DUMMY_INPUT_GENERATOR_CLASSES = (
515+
DummyTextInputGenerator,
516+
DummyVisionInputGenerator,
517+
)
512518
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
513519
NORMALIZED_CONFIG_CLASS = NormalizedConfigManager.get_normalized_config_class("gemma3")
514520
MIN_TRANSFORMERS_VERSION = version.parse("4.52.0.dev0")
515521

522+
# TODO: check if we need this
523+
# @property
524+
# def inputs(self) -> dict[str, dict[int, str]]:
525+
# return {
526+
# "input_ids": {0: "batch_size", 1: "sequence_length"},
527+
# "attention_mask": {0: "batch_size", 1: "sequence_length"},
528+
# "pixel_values": {0: "batch_size", 1: "num_channels", 2: "image_size", 3: "image_size"},
529+
# }
530+
516531

517532
@register_tasks_manager_onnx("nemotron", *COMMON_TEXT_GENERATION_TASKS)
518533
class NemotronOnnxConfig(GemmaOnnxConfig):

0 commit comments

Comments
 (0)