Commit f5df6b5

Add support for gemma3-text (#70)
Added support for gemma3-text, following the code in #50, and added a working example with `gemma3-270m-instruct`; it will be updated and improved as needed.

Related: #69, #49, #45, huggingface/optimum#1724, #56

Co-authored-by: Ilyas Moutawwakil <[email protected]>
Co-authored-by: IlyasMoutawwakil <[email protected]>
1 parent cd83dcc commit f5df6b5
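For context, this is roughly how the newly supported architecture can be exported programmatically once this commit is in place. It is a minimal sketch, not taken from the commit: it assumes the `main_export` entry point and the `text-generation-with-past` task name from optimum's ONNX exporter are available in this repository, and the output directory is illustrative.

```python
# Minimal sketch (assumptions noted above): export the Gemma 3 text checkpoint to ONNX.
from optimum.exporters.onnx import main_export  # assumed to be exposed as in optimum's exporter

main_export(
    model_name_or_path="google/gemma-3-270m-it",
    output="gemma3_onnx",              # illustrative output directory
    task="text-generation-with-past",  # decoder with KV cache, matching the config registered in this commit
)
```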

File tree

8 files changed: +60 −11 lines changed

- README.md
- examples/gemma3.py
- optimum/exporters/onnx/model_configs.py
- optimum/exporters/onnx/utils.py
- optimum/onnxruntime/modeling_decoder.py
- tests/exporters/onnx/utils_tests.py
- tests/onnxruntime/test_decoder.py
- tests/onnxruntime/testing_utils.py


README.md

Lines changed: 4 additions & 0 deletions
@@ -56,3 +56,7 @@ Once the model is exported to the ONNX format, we provide Python classes enablin
 ```
 
 More details on how to run ONNX models with `ORTModelForXXX` classes [here](https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/models).
+
+### Examples
+
+Check out the [examples folder](./examples) for more usage examples including optimization, quantization, and model-specific demonstrations.

examples/gemma3.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+"""Simple example: Export Gemma3 270M to ONNX and generate text.
+
+Usage:
+    uv pip install onnxruntime
+    uv run examples/gemma3.py
+"""
+
+from transformers import AutoTokenizer
+
+from optimum.onnxruntime import ORTModelForCausalLM
+
+
+model_id = "google/gemma-3-270m-it"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
+
+# Chat with instruction-tuned model
+conversation = [{"role": "user", "content": "Hello! How are you?"}]
+prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+inputs = tokenizer(prompt, return_tensors="pt")
+
+outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
+response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+print(response)
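A natural follow-up to the example above (not part of this commit) is persisting the exported model so later runs skip the export step. A hedged sketch, assuming the standard `save_pretrained`/`from_pretrained` round trip and an illustrative local directory:

```python
# Sketch only: save the ONNX export once, then reload it without re-exporting.
save_dir = "gemma3_onnx"  # illustrative path
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)

# Later runs can load the exported model directly (no export=True needed).
model = ORTModelForCausalLM.from_pretrained(save_dir)
```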

optimum/exporters/onnx/model_configs.py

Lines changed: 12 additions & 6 deletions
@@ -501,19 +501,25 @@ class Qwen3MoeOnnxConfig(LlamaOnnxConfig):
 
 
 @register_tasks_manager_onnx("gemma", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
-class GemmaOnnxConfig(LlamaOnnxConfig):
+class GemmaOnnxConfig(TextDecoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
     DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
     DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
     MIN_TRANSFORMERS_VERSION = version.parse("4.38.0")
 
 
 @register_tasks_manager_onnx("gemma2", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
-class Gemma2OnnxConfig(TextDecoderOnnxConfig):
-    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
-    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, GemmaDummyPastKeyValuesGenerator)
-    DUMMY_PKV_GENERATOR_CLASS = GemmaDummyPastKeyValuesGenerator
+class Gemma2OnnxConfig(GemmaOnnxConfig):
     # Gemma 2 was added in transformers v4.42 using HybridCache
-    # (tuple of past_key_values never supported), DynamicCache since v4.53
+    # DynamicCache support was added since v4.53
+    MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
+
+
+@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
+@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS, "text-classification")
+class Gemma3OnnxConfig(GemmaOnnxConfig):
+    # Gemma 3 was added in transformers v4.50 using HybridCache
+    # DynamicCache support was added since v4.53
     MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
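The diff above follows the exporter's registration pattern: a `TextDecoderOnnxConfig` subclass declares its normalized config, dummy input generators, and minimum transformers version, and `register_tasks_manager_onnx` maps one or more model types to it. A hedged sketch of the same pattern for a hypothetical decoder-only architecture (the model type, version floor, and generator choices below are illustrative, not from this commit):

```python
# Hypothetical registration following the pattern in this diff; names are illustrative.
@register_tasks_manager_onnx("my_decoder", *COMMON_TEXT_GENERATION_TASKS)
class MyDecoderOnnxConfig(TextDecoderOnnxConfig):
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, DummyPastKeyValuesGenerator)
    DUMMY_PKV_GENERATOR_CLASS = DummyPastKeyValuesGenerator
    MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")  # illustrative version floor
```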

optimum/exporters/onnx/utils.py

Lines changed: 0 additions & 3 deletions
@@ -37,21 +37,18 @@
     "deepseek_v3",
     "cohere",
     "falcon",
-    "gemma",
     "glm",
     "gpt2",
     "gpt_bigcode",
     "gpt_neo",
     "gpt_neox",
-    "gpt_oss",
     "gptj",
     "granite",
     "helium",
     "imagegpt",
     "internlm2",
     "llama",
     "mistral",
-    "nemotron",
     "phi",
     "phi3",
     "qwen2",

optimum/onnxruntime/modeling_decoder.py

Lines changed: 3 additions & 1 deletion
@@ -185,7 +185,7 @@ def __init__(
                 "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
             )
 
-        if self.config.model_type in {"gemma", "gpt_oss", "nemotron"}:
+        if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
             self.embed_size_per_head = self.config.head_dim
         elif self.old_gpt_bigcode_modeling:
             # (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
@@ -202,6 +202,8 @@ def __init__(
             "deepseek_v3",
             "cohere",
             "gemma",
+            "gemma3",
+            "gemma3_text",
             "glm",
             "granite",
             "gpt_oss",

tests/exporters/onnx/utils_tests.py

Lines changed: 2 additions & 0 deletions
@@ -103,6 +103,8 @@
     "flaubert": "hf-internal-testing/tiny-random-flaubert",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
+    "gemma3": "hf-internal-testing/tiny-random-Gemma3ForConditionalGeneration",
+    "gemma3_text": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "glpn": "hf-internal-testing/tiny-random-GLPNModel",
     "gpt2": "hf-internal-testing/tiny-random-gpt2",

tests/onnxruntime/test_decoder.py

Lines changed: 12 additions & 1 deletion
@@ -33,6 +33,7 @@
     CohereOnnxConfig,
     DeepSeekV3OnnxConfig,
     Gemma2OnnxConfig,
+    Gemma3OnnxConfig,
     GemmaOnnxConfig,
     GLMOnnxConfig,
     GPTOssOnnxConfig,
@@ -118,6 +119,8 @@ class ORTModelForCausalLMIntegrationTest(ORTModelTestMixin):
         SUPPORTED_ARCHITECTURES.append("gemma")
     if is_transformers_version(">=", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("gemma2")
+    if is_transformers_version(">=", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)):
+        SUPPORTED_ARCHITECTURES.extend(["gemma3", "gemma3_text"])
     if is_transformers_version(">=", str(GLMOnnxConfig.MIN_TRANSFORMERS_VERSION)):
         SUPPORTED_ARCHITECTURES.append("glm")
     if is_transformers_version(">=", str(MPTOnnxConfig.MIN_TRANSFORMERS_VERSION)):
@@ -306,9 +309,17 @@ def test_find_untested_architectures(self):
         if "gemma2" in supported_architectures and is_transformers_version(
             "<", str(Gemma2OnnxConfig.MIN_TRANSFORMERS_VERSION)
         ):
-            # Gemma 2 was added in transformers v4.42 using HybridCache (tuple of past_key_values never supported), DynamicCache since v4.53
+            # Gemma 2 was added in transformers v4.42 supporting HybridCache only,
+            # DynamicCache support was added since v4.53
             supported_architectures.remove("gemma2")
 
+        if "gemma3" in supported_architectures and is_transformers_version(
+            "<", str(Gemma3OnnxConfig.MIN_TRANSFORMERS_VERSION)
+        ):
+            # Gemma 3 was added in transformers v4.50 supporting HybridCache only,
+            # DynamicCache support was added since v4.53
+            supported_architectures.remove("gemma3")
+
         untested_architectures = supported_architectures - tested_architectures
 
         if len(untested_architectures) > 0:

tests/onnxruntime/testing_utils.py

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@
     "flux": "optimum-internal-testing/tiny-random-flux",
     "gemma": "fxmarty/tiny-random-GemmaForCausalLM",
     "gemma2": "hf-internal-testing/tiny-random-Gemma2ForCausalLM",
+    "gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
+    "gemma3_text": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",
     "glm": "hf-internal-testing/tiny-random-GlmForCausalLM",
     "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel",
     "gpt_bigcode": "hf-internal-testing/tiny-random-GPTBigCodeModel",
