@@ -506,13 +506,28 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
506506 DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
507507 MIN_TRANSFORMERS_VERSION = version .parse ("4.38.0" )
508508
509- @register_tasks_manager_onnx ("gemma3" , * [* COMMON_TEXT_GENERATION_TASKS , "text-classification" ])
509+
510+ @register_tasks_manager_onnx (
511+ "gemma3" , * [* COMMON_TEXT_GENERATION_TASKS , "text-classification" ]
512+ )
510513class Gemma3OnnxConfig (LlamaOnnxConfig ):
511- DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , GemmaDummyPastKeyValuesGenerator )
514+ DUMMY_INPUT_GENERATOR_CLASSES = (
515+ DummyTextInputGenerator ,
516+ DummyVisionInputGenerator ,
517+ )
512518 DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
513519 NORMALIZED_CONFIG_CLASS = NormalizedConfigManager .get_normalized_config_class ("gemma3" )
514520 MIN_TRANSFORMERS_VERSION = version .parse ("4.52.0.dev0" )
515521
522+ # TODO: check if we need this
523+ # @property
524+ # def inputs(self) -> dict[str, dict[int, str]]:
525+ # return {
526+ # "input_ids": {0: "batch_size", 1: "sequence_length"},
527+ # "attention_mask": {0: "batch_size", 1: "sequence_length"},
528+ # "pixel_values": {0: "batch_size", 1: "num_channels", 2: "image_size", 3: "image_size"},
529+ # }
530+
516531
517532@register_tasks_manager_onnx ("nemotron" , * COMMON_TEXT_GENERATION_TASKS )
518533class NemotronOnnxConfig (GemmaOnnxConfig ):
0 commit comments