4 changes: 4 additions & 0 deletions README.md
@@ -56,3 +56,7 @@ Once the model is exported to the ONNX format, we provide Python classes enabling
```

More details on how to run ONNX models with `ORTModelForXXX` classes [here](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models).

### Examples

Check out the [examples folder](./examples) for more usage examples including optimization, quantization, and model-specific demonstrations.
25 changes: 25 additions & 0 deletions examples/gemma3.py
@@ -0,0 +1,25 @@
"""Simple example: Export Gemma3 270M to ONNX and generate text.

Usage:
uv pip install onnxruntime
uv run examples/gemma3.py
"""

from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM


model_id = "google/gemma-3-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

# Chat with instruction-tuned model
conversation = [{"role": "user", "content": "Hello! How are you?"}]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)
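Because `export=True` re-converts the checkpoint on every run, a natural extension of this example is to save the exported model once and reload it from disk afterwards. A minimal sketch, assuming a hypothetical local directory `gemma3_onnx` (not part of the file above):

```python
# Sketch: persist the ONNX export so later runs skip the conversion step.
# `model` and `tokenizer` are the objects created in examples/gemma3.py above;
# the save directory name is a placeholder.
save_dir = "gemma3_onnx"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

# Subsequent runs can load the already-exported model directly, without export=True.
from optimum.onnxruntime import ORTModelForCausalLM

model = ORTModelForCausalLM.from_pretrained(save_dir)
```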
18 changes: 12 additions & 6 deletions optimum/exporters/onnx/model_configs.py
@@ -501,19 +501,25 @@ class Qwen3MoeOnnxConfig(LlamaOnnxConfig):


@register_tasks_manager_onnx("gemma", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
class GemmaOnnxConfig(LlamaOnnxConfig):
class GemmaOnnxConfig(TextDecoderOnnxConfig):
Review comment (Member): I discovered that gemma models in general don't need the position ids argument

Review comment (@IlyasMoutawwakil, Member, Oct 22, 2025): @echarlaix wdyt? this also removes the need for position ids from gpt_oss and nemotron

NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")


@register_tasks_manager_onnx("gemma2", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
class Gemma2OnnxConfig(TextDecoderOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
class Gemma2OnnxConfig(GemmaOnnxConfig):
# Gemma 2 was added in transformers v4.42 using HybridCache
# (tuple of past_key_values never supported), DynamicCache since v4.53
# DynamicCache support was added since v4.53
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")


@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
class Gemma3OnnxConfig(GemmaOnnxConfig):
# Gemma 3 was added in transformers v4.50 using HybridCache
# DynamicCache support was added since v4.53
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")


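The review comments above argue that Gemma models do not need the `position_ids` input, which is why `GemmaOnnxConfig` now inherits from `TextDecoderOnnxConfig` rather than `LlamaOnnxConfig`. One way to observe the effect is to inspect the input names of an exported graph; a minimal sketch, assuming a Gemma model has already been exported to the hypothetical path `gemma_onnx/model.onnx`:

```python
# Sketch: check that the exported Gemma decoder no longer declares a
# position_ids graph input. The path below is a placeholder, not from this PR.
import onnx

onnx_model = onnx.load("gemma_onnx/model.onnx")
input_names = [inp.name for inp in onnx_model.graph.input]
print(input_names)  # e.g. input_ids, attention_mask, past_key_values.* entries
assert "position_ids" not in input_names
```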
3 changes: 0 additions & 3 deletions optimum/exporters/onnx/utils.py
@@ -37,21 +37,18 @@
"deepseek_v3",
"cohere",
"falcon",
"gemma",
"glm",
"gpt2",
"gpt_bigcode",
"gpt_neo",
"gpt_neox",
"gpt_oss",
"gptj",
"granite",
"helium",
"imagegpt",
"internlm2",
"llama",
"mistral",
"nemotron",
"phi",
"phi3",
"qwen2",
4 changes: 3 additions & 1 deletion optimum/onnxruntime/modeling_decoder.py
@@ -185,7 +185,7 @@ def __init__(
"To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
)

if self.config.model_type in {"gemma", "gpt_oss", "nemotron"}:
if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
self.embed_size_per_head = self.config.head_dim
elif self.old_gpt_bigcode_modeling:
# (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
@@ -202,6 +202,8 @@ def __init__(
"deepseek_v3",
"cohere",
"gemma",
"gemma3",
"gemma3_text",
"glm",
"granite",
"gpt_oss",
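The `modeling_decoder.py` change routes the Gemma 3 model types through the branch that reads `self.config.head_dim` directly. Gemma-style configs define `head_dim` explicitly, and it need not equal `hidden_size // num_attention_heads`, so deriving the per-head KV-cache size from the hidden size would be wrong for these models. A minimal sketch of the distinction, using the checkpoint from the example above (exact values depend on the checkpoint and its config layout):

```python
# Sketch: Gemma-family configs expose head_dim directly; it can differ from
# hidden_size // num_attention_heads, which is why embed_size_per_head is
# taken from config.head_dim for these model types.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-3-270m-it")
print(config.head_dim)                                    # explicit per-head size
print(config.hidden_size // config.num_attention_heads)   # may differ from head_dim
```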
2 changes: 2 additions & 0 deletions tests/exporters/onnx/utils_tests.py
@@ -103,6 +103,8 @@
"flaubert": "hf-internal-testing/tiny-random-flaubert",
"gemma": "fxmarty/tiny-random-GemmaForCausalLM",
"gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
"gemma3": "hf-internal-testing/tiny-random-Gemma3ForConditionalGeneration",
"gemma3_text": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
"glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
"glpn": "hf-internal-testing/tiny-random-GLPNModel",
"gpt2": "hf-internal-testing/tiny-random-gpt2",
13 changes: 12 additions & 1 deletion tests/onnxruntime/test_decoder.py
@@ -33,6 +33,7 @@
CohereOnnxConfig,
DeepSeekV3OnnxConfig,
Gemma2OnnxConfig,
Gemma3OnnxConfig,
GemmaOnnxConfig,
GLMOnnxConfig,
GPTOssOnnxConfig,
@@ -118,6 +119,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES.append("gemma")
if is_transformers_version(">=", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("gemma2")
if is_transformers_version(">=", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.extend(["gemma3", "gemma3_text"])
if is_transformers_version(">=", str(GLMOnnxConfig.MIN_TRANSFORMERS_VERSION)):
SUPPORTED_ARCHITECTURES.append("glm")
if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)):
@@ -306,9 +309,17 @@ def test_find_untested_architectures(self):
if "gemma2" in supported_architectures and is_transformers_version(
"<", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)
):
# Gemma 2 was added in transformers v4.42 using HybridCache (tuple of past_key_values never supported), DynamicCache since v4.53
# Gemma 2 was added in transformers v4.42 supporting HybridCache only,
# DynamicCache support was added since v4.53
supported_architectures.remove("gemma2")

if "gemma3" in supported_architectures and is_transformers_version(
"<", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)
):
# Gemma 3 was added in transformers v4.50 supporting HybridCache only,
# DynamicCache support was added since v4.53
supported_architectures.remove("gemma3")

untested_architectures = supported_architectures - tested_architectures

if len(untested_architectures) > 0:
2 changes: 2 additions & 0 deletions tests/onnxruntime/testing_utils.py
@@ -66,6 +66,8 @@
"flux": "optimum-internal-testing/tiny-random-flux",
"gemma": "fxmarty/tiny-random-GemmaForCausalLM",
"gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
"gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
"gemma3_text": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
"glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
"gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel",
"gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",