@@ -498,13 +498,28 @@ class GemmaOnnxConfig(LlamaOnnxConfig):
498498 DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
499499 MIN_TRANSFORMERS_VERSION = version .parse ("4.38.0" )
500500
501- @register_tasks_manager_onnx ("gemma3" , * [* COMMON_TEXT_GENERATION_TASKS , "text-classification" ])
501+
502+ @register_tasks_manager_onnx (
503+ "gemma3" , * [* COMMON_TEXT_GENERATION_TASKS , "text-classification" ]
504+ )
502505class Gemma3OnnxConfig (LlamaOnnxConfig ):
503- DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator , GemmaDummyPastKeyValuesGenerator )
506+ DUMMY_INPUT_GENERATOR_CLASSES = (
507+ DummyTextInputGenerator ,
508+ DummyVisionInputGenerator ,
509+ )
504510 DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
505511 NORMALIZED_CONFIG_CLASS = NormalizedConfigManager .get_normalized_config_class ("gemma3" )
506512 MIN_TRANSFORMERS_VERSION = version .parse ("4.52.0.dev0" )
507513
514+ # TODO: check if we need this
515+ # @property
516+ # def inputs(self) -> dict[str, dict[int, str]]:
517+ # return {
518+ # "input_ids": {0: "batch_size", 1: "sequence_length"},
519+ # "attention_mask": {0: "batch_size", 1: "sequence_length"},
520+ # "pixel_values": {0: "batch_size", 1: "num_channels", 2: "image_size", 3: "image_size"},
521+ # }
522+
508523
509524@register_tasks_manager_onnx ("granite" , * COMMON_TEXT_GENERATION_TASKS )
510525class GraniteOnnxConfig (LlamaOnnxConfig ):
0 commit comments