From 7c662c8a0f965e14a7f5ecc545b57ad4190d4cd8 Mon Sep 17 00:00:00 2001 From: Fghjjjjjk Date: Sat, 18 Oct 2025 19:41:25 +0530 Subject: [PATCH] feat: Add native Gemma3 support for ONNX export - Add Gemma3OnnxConfig class with proper configuration - Register gemma3 model type for text generation and classification tasks - Add Gemma3 to supported architectures and test mappings - Set minimum transformers version requirement to 4.50.0 - Follow same pattern as existing Gemma/Gemma2 implementations Fixes: ValueError when exporting Gemma3 models to ONNX format Resolves: 'gemma3 model, that is a custom or unsupported architecture' error --- optimum/exporters/onnx/model_configs.py | 9 +++++++++ optimum/exporters/onnx/utils.py | 2 ++ tests/exporters/onnx/utils_tests.py | 2 ++ tests/onnxruntime/test_decoder.py | 3 +++ tests/onnxruntime/testing_utils.py | 1 + 5 files changed, 17 insertions(+) diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index 92d0dce2..af4fcce3 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -517,6 +517,15 @@ class Gemma2OnnxConfig(TextDecoderOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.53.0") +@register_tasks_manager_onnx("gemma3", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"]) +class Gemma3OnnxConfig(TextDecoderOnnxConfig): + NORMALIZED_CONFIG_CLASS = NormalizedTextConfigWithGQA + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator + # Gemma 3 was added in transformers v4.50.0 + MIN_TRANSFORMERS_VERSION = version.parse("4.50.0") + + @register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS) class GPTOssOnnxConfig(GemmaOnnxConfig): MIN_TRANSFORMERS_VERSION = version.parse("4.55.0") diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py index 4d0b8413..3b7e0105 100644 --- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py @@ -38,6 +38,8 @@ "cohere", "falcon", "gemma", + "gemma2", + "gemma3", "glm", "gpt2", "gpt_bigcode", diff --git a/tests/exporters/onnx/utils_tests.py b/tests/exporters/onnx/utils_tests.py index 705921df..d07a7019 100644 --- a/tests/exporters/onnx/utils_tests.py +++ b/tests/exporters/onnx/utils_tests.py @@ -103,6 +103,7 @@ "flaubert": "hf-internal-testing/tiny-random-flaubert", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM", + "gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM", "glm": "hf-internal-testing/tiny-random-GlmForCausalLM", "glpn": "hf-internal-testing/tiny-random-GLPNModel", "gpt2": "hf-internal-testing/tiny-random-gpt2", @@ -270,6 +271,7 @@ "encoder-decoder": "patrickvonplaten/bert2bert_cnn_daily_mail", "flaubert": "flaubert/flaubert_small_cased", "gemma": "google/gemma-2b", + "gemma3": "google/gemma-3-1b-it", "gpt2": "gpt2", "gpt_neo": "EleutherAI/gpt-neo-125M", "gpt_neox": "EleutherAI/gpt-neox-20b", diff --git a/tests/onnxruntime/test_decoder.py b/tests/onnxruntime/test_decoder.py index 317085ed..59cf216d 100644 --- a/tests/onnxruntime/test_decoder.py +++ b/tests/onnxruntime/test_decoder.py @@ -33,6 +33,7 @@ CohereOnnxConfig, DeepSeekV3OnnxConfig, Gemma2OnnxConfig, + Gemma3OnnxConfig, GemmaOnnxConfig, GLMOnnxConfig, GPTOssOnnxConfig, @@ -118,6 +119,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin): SUPPORTED_ARCHITECTURES.append("gemma") if is_transformers_version(">=", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)): SUPPORTED_ARCHITECTURES.append("gemma2") + if is_transformers_version(">=", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)): + SUPPORTED_ARCHITECTURES.append("gemma3") if is_transformers_version(">=", str(GLMOnnxConfig.MIN_TRANSFORMERS_VERSION)): SUPPORTED_ARCHITECTURES.append("glm") if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)): SUPPORTED_ARCHITECTURES.append("mpt") diff --git a/tests/onnxruntime/testing_utils.py
b/tests/onnxruntime/testing_utils.py index e0b2577e..1fd3842f 100644 --- a/tests/onnxruntime/testing_utils.py +++ b/tests/onnxruntime/testing_utils.py @@ -66,6 +66,7 @@ "flux": "optimum-internal-testing/tiny-random-flux", "gemma": "fxmarty/tiny-random-GemmaForCausalLM", "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM", + "gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM", "glm": "hf-internal-testing/tiny-random-GlmForCausalLM", "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel", "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",