diff --git a/README.md b/README.md
index 31bde36..5c39a6d 100644
--- a/README.md
+++ b/README.md
@@ -56,3 +56,7 @@ Once the model is exported to the ONNX format, we provide Python classes enablin
 ```
 
 More details on how to run ONNX models with `ORTModelForXXX` classes [here](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models).
+
+### Examples
+
+Check out the [examples folder](./examples) for more usage examples, including optimization, quantization, and model-specific demonstrations.
diff --git a/examples/gemma3.py b/examples/gemma3.py
new file mode 100644
index 0000000..80b81af
--- /dev/null
+++ b/examples/gemma3.py
@@ -0,0 +1,25 @@
+"""Simple example: Export Gemma3 270M to ONNX and generate text.
+
+Usage:
+    uv pip install onnxruntime
+    uv run examples/gemma3.py
+"""
+
+from transformers import AutoTokenizer
+
+from optimum.onnxruntime import ORTModelForCausalLM
+
+
+model_id = "google/gemma-3-270m-it"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
+
+# Chat with instruction-tuned model
+conversation = [{"role": "user", "content": "Hello! How are you?"}]
+prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+inputs = tokenizer(prompt, return_tensors="pt")
+
+outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
+response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+print(response)
diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 92d0dce..b5be222 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -501,19 +501,25 @@ class Qwen3MoeOnnxConfig(LlamaOnnxConfig):
 @register_tasks_manager_onnx("gemma", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
-class GemmaOnnxConfig(LlamaOnnxConfig):
+class GemmaOnnxConfig(TextDecoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
     MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
 
 
 @register_tasks_manager_onnx("gemma2", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
-class Gemma2OnnxConfig(TextDecoderOnnxConfig):
-    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
-    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
-    DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
+class Gemma2OnnxConfig(GemmaOnnxConfig):
     # Gemma 2 was added in transformers v4.42 using HybridCache
-    # (tuple of past_key_values never supported), DynamicCache since v4.53
+    # DynamicCache support was added in v4.53
     MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
+
+
+@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
+@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
+class Gemma3OnnxConfig(GemmaOnnxConfig):
+    # Gemma 3 was added in transformers v4.50 using HybridCache
+    # DynamicCache support was added in v4.53
+    MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
diff --git a/optimum/exporters/onnx/utils.py b/optimum/exporters/onnx/utils.py
index 4d0b841..65d000a 100644
--- a/optimum/exporters/onnx/utils.py
+++ b/optimum/exporters/onnx/utils.py
@@ -37,13 +37,11 @@
     "deepseek_v3",
     "cohere",
     "falcon",
-    "gemma",
     "glm",
     "gpt2",
     "gpt_bigcode",
     "gpt_neo",
     "gpt_neox",
-    "gpt_oss",
     "gptj",
     "granite",
     "helium",
@@ -51,7 +49,6 @@
     "internlm2",
     "llama",
     "mistral",
-    "nemotron",
     "phi",
     "phi3",
     "qwen2",
diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py
index 6410885..60392d9 100644
--- a/optimum/onnxruntime/modeling_decoder.py
+++ b/optimum/onnxruntime/modeling_decoder.py
@@ -185,7 +185,7 @@ def __init__(
                 "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
             )
 
-        if self.config.model_type in {"gemma", "gpt_oss", "nemotron"}:
+        if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
             self.embed_size_per_head = self.config.head_dim
         elif self.old_gpt_bigcode_modeling:
             # (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
@@ -202,6 +202,8 @@ def __init__(
             "deepseek_v3",
             "cohere",
             "gemma",
+            "gemma3",
+            "gemma3_text",
             "glm",
             "granite",
             "gpt_oss",
diff --git a/tests/exporters/onnx/utils_tests.py b/tests/exporters/onnx/utils_tests.py
index 705921d..be053a0 100644
--- a/tests/exporters/onnx/utils_tests.py
+++ b/tests/exporters/onnx/utils_tests.py
@@ -103,6 +103,8 @@
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
+    "gemma3": "hf-internal-testing/tiny-random-Gemma3ForConditionalGeneration",
+    "gemma3_text": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",
diff --git a/tests/onnxruntime/test_decoder.py b/tests/onnxruntime/test_decoder.py
index 317085e..38274d3 100644
--- a/tests/onnxruntime/test_decoder.py
+++ b/tests/onnxruntime/test_decoder.py
@@ -33,6 +33,7 @@
     CohereOnnxConfig,
     DeepSeekV3OnnxConfig,
     Gemma2OnnxConfig,
+    Gemma3OnnxConfig,
     GemmaOnnxConfig,
     GLMOnnxConfig,
     GPTOssOnnxConfig,
@@ -118,6 +119,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         SUPPORTED_ARCHITECTURES.append("gemma")
     if is_transformers_version(">=", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("gemma2")
+    if is_transformers_version(">=", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
+        SUPPORTED_ARCHITECTURES.extend(["gemma3", "gemma3_text"])
     if is_transformers_version(">=", str(GLMOnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("glm")
     if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)):
@@ -306,9 +309,17 @@ def test_find_untested_architectures(self):
         if "gemma2" in supported_architectures and is_transformers_version(
             "<", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)
         ):
-            # Gemma 2 was added in transformers v4.42 using HybridCache (tuple of past_key_values never supported), DynamicCache since v4.53
+            # Gemma 2 was added in transformers v4.42 supporting HybridCache only,
+            # DynamicCache support was added in v4.53
             supported_architectures.remove("gemma2")
 
+        if "gemma3" in supported_architectures and is_transformers_version(
+            "<", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)
+        ):
+            # Gemma 3 was added in transformers v4.50 supporting HybridCache only,
+            # DynamicCache support was added in v4.53
+            supported_architectures.remove("gemma3")
+
         untested_architectures = supported_architectures - tested_architectures
         if len(untested_architectures) > 0:
diff --git a/tests/onnxruntime/testing_utils.py b/tests/onnxruntime/testing_utils.py
index e0b2577..03b0b7e 100644
--- a/tests/onnxruntime/testing_utils.py
+++ b/tests/onnxruntime/testing_utils.py
@@ -66,6 +66,8 @@
     "flux": "optimum-internal-testing/tiny-random-flux",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
+    "gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
+    "gemma3_text": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
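
As a quick smoke test of the newly registered `gemma3_text` config, a sketch along these lines should exercise export and generation end to end. It assumes the tiny test checkpoint referenced in the diff, `hf-internal-testing/tiny-random-Gemma3ForCausalLM`, is publicly available; its random weights produce meaningless text, so this only verifies that the export path runs:

```python
from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM

# Tiny random Gemma3 text-only checkpoint used by the test suites in this diff.
model_id = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the checkpoint to ONNX through the new Gemma3OnnxConfig;
# use_cache=True keeps past key/values wired up, matching the modeling_decoder.py changes.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

inputs = tokenizer("Hello!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=5, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```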
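The new README section also points to quantization examples. Below is a minimal dynamic-quantization sketch for the exported model, assuming the standard `ORTQuantizer`/`AutoQuantizationConfig` API from `optimum.onnxruntime` (not part of this diff) and the `google/gemma-3-270m-it` checkpoint from `examples/gemma3.py`:

```python
from optimum.onnxruntime import ORTModelForCausalLM, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Export the model to ONNX and save it so the quantizer can read the .onnx file from disk.
model = ORTModelForCausalLM.from_pretrained("google/gemma-3-270m-it", export=True)
model.save_pretrained("gemma3_onnx")

# Dynamic int8 quantization: weights are quantized ahead of time, activations at runtime.
quantizer = ORTQuantizer.from_pretrained("gemma3_onnx")
qconfig = AutoQuantizationConfig.avx512_vnni(is_static=False, per_channel=False)
quantizer.quantize(save_dir="gemma3_onnx_quantized", quantization_config=qconfig)

# Load the quantized graph back through the same ORT class (the file name below is an
# assumption based on the quantizer's default output naming).
quantized = ORTModelForCausalLM.from_pretrained(
    "gemma3_onnx_quantized", file_name="model_quantized.onnx"
)
```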