diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 92d0dce..af4fcce 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -517,6 +517,15 @@ class Gemma2OnnxConfig(TextDecoderOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.53.0") +@register_tasks_manager_onnx("gemma3", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"]) +class Gemma3OnnxConfig(TextDecoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator + # Gemma 3 was added in transformers v4.50.0 + MIN_TRANSFORMERS_VERSION = version.parse("4.50.0") + + @register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS) class GPTOssOnnxConfig(GemmaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.55.0") diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 4d0b841..3b7e010 100644 --- a/optimum/exporters/onnx/utils.py +++ b/optimum/exporters/onnx/utils.py @@ -38,6 +38,8 @@ "cohere", "falcon", "gemma", + "gemma2", + "gemma3", "glm", "gpt2", "gpt_bigcode", diff --git a/tests/exporters/onnx/utils_tests.py b/tests/exporters/onnx/utils_tests.py index 705921d..d07a701 100644 --- a/tests/exporters/onnx/utils_tests.py +++ b/tests/exporters/onnx/utils_tests.py @@ -103,6 +103,7 @@ "flaubert": "hf-internal-testing/tiny-random-flaubert", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM", + "gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM", "glm": "hf-internal-testing/tiny-random-GlmForCausalLM", "glpn": "hf-internal-testing/tiny-random-GLPNModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", @@ -270,6 +271,7 @@ "encoder-decoder": "patrickvonplaten/bert2bert_cnn_daily_mail", "flaubert": "flaubert/flaubert_small_cased", "gemma": 
"google/gemma-2b", + "gemma3": "google/gemma-3-1b-it", "gpt2": "gpt2", "gpt_neo": "EleutherAI/gpt-neo-125M", "gpt_neox": "EleutherAI/gpt-neox-20b", diff --git a/tests/onnxruntime/test_decoder.py b/tests/onnxruntime/test_decoder.py index 317085e..59cf216 100644 --- a/tests/onnxruntime/test_decoder.py +++ b/tests/onnxruntime/test_decoder.py @@ -33,6 +33,7 @@ CohereOnnxConfig, DeepSeekV3OnnxConfig, Gemma2OnnxConfig, + Gemma3OnnxConfig, GemmaOnnxConfig, GLMOnnxConfig, GPTOssOnnxConfig, @@ -118,6 +119,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES.append("gemma") if is_transformers_version(">=", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)): SUPPORTED_ARCHITECTURES.append("gemma2") + if is_transformers_version(">=", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)): + SUPPORTED_ARCHITECTURES.append("gemma3") if is_transformers_version(">=", str(GLMOnnxConfig.MIN_TRANSFORMERS_VERSION)): SUPPORTED_ARCHITECTURES.append("glm") if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)): diff --git a/tests/onnxruntime/testing_utils.py b/tests/onnxruntime/testing_utils.py index e0b2577..1fd3842 100644 --- a/tests/onnxruntime/testing_utils.py +++ b/tests/onnxruntime/testing_utils.py @@ -66,6 +66,7 @@ "flux": "optimum-internal-testing/tiny-random-flux", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM", + "gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM", "glm": "hf-internal-testing/tiny-random-GlmForCausalLM", "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",