Skip to content

Commit 3f08a26

Browse files
committed
feat: add support for gemma3-text
simplify the example
1 parent 671b84f commit 3f08a26

File tree

5 files changed

+3221
-2
lines changed

5 files changed

+3221
-2
lines changed

examples/gemma3.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Simple example: Export Gemma3 270M to ONNX and generate text.
2+
3+
Usage:
4+
uv pip install onnxruntime
5+
uv run examples/gemma3.py
6+
"""
7+
8+
from transformers import AutoTokenizer
9+
10+
from optimum.onnxruntime import ORTModelForCausalLM
11+
12+
13+
model_id = "google/gemma-3-270m-it"
14+
tokenizer = AutoTokenizer.from_pretrained(model_id)
15+
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
16+
17+
# Chat with instruction-tuned model
18+
conversation = [{"role": "user", "content": "Hello! How are you?"}]
19+
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
20+
inputs = tokenizer(prompt, return_tensors="pt")
21+
22+
outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
23+
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
24+
25+
print(response)

optimum/exporters/onnx/model_configs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
CLIPModelPatcher,
4343
CohereModelPatcher,
4444
FluxTransformerModelPatcher,
45+
Gemma3LMModelPatcher,
4546
MgpstrModelPatcher,
4647
MoonshineModelPatcher,
4748
MusicgenModelPatcher,
@@ -516,6 +517,14 @@ class Gemma2OnnxConfig(TextDecoderOnnxConfig):
516517
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
517518

518519

520+
@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS)
521+
@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS)
522+
class Gemma3OnnxConfig(GemmaOnnxConfig):
523+
"""ONNX config for Gemma3 text-only models."""
524+
MIN_TRANSFORMERS_VERSION = version.parse("4.52.0")
525+
_MODEL_PATCHER = Gemma3LMModelPatcher
526+
527+
519528
@register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS)
520529
class GPTOssOnnxConfig(GemmaOnnxConfig):
521530
MIN_TRANSFORMERS_VERSION = version.parse("4.55.0")

optimum/exporters/onnx/model_patcher.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import inspect
1919
import sys
2020
import types
21-
from typing import TYPE_CHECKING, Any, Callable
21+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
2222

2323
import torch
2424
import transformers
@@ -30,6 +30,7 @@
3030
jit_utils,
3131
symbolic_helper,
3232
)
33+
from transformers import PreTrainedModel, TFPreTrainedModel
3334
from transformers.modeling_outputs import BaseModelOutput
3435
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
3536

@@ -1418,3 +1419,57 @@ def __exit__(self, exc_type, exc_value, traceback):
14181419
from transformers.models.cohere.modeling_cohere import CohereRotaryEmbedding
14191420

14201421
CohereRotaryEmbedding.forward = self.original_forward
1422+
1423+
1424+
class Gemma3LMModelPatcher(DecoderModelPatcher):
1425+
"""Patcher for Gemma3 language model to handle cache conversion for ONNX export."""
1426+
1427+
def __init__(
1428+
self,
1429+
config,
1430+
model: Union[PreTrainedModel, TFPreTrainedModel],
1431+
model_kwargs: Optional[Dict[str, Any]] = None,
1432+
):
1433+
def forward(
1434+
self,
1435+
attention_mask,
1436+
position_ids,
1437+
past_key_values,
1438+
inputs_embeds,
1439+
use_cache=True,
1440+
):
1441+
from transformers.cache_utils import DynamicCache
1442+
1443+
pkv = DynamicCache.from_legacy_cache(past_key_values)
1444+
1445+
past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values is not None else 0
1446+
cache_position = torch.arange(
1447+
past_seen_tokens,
1448+
past_seen_tokens + inputs_embeds.shape[1],
1449+
device=inputs_embeds.device,
1450+
)
1451+
1452+
result = self.__orig_forward(
1453+
input_ids=None,
1454+
attention_mask=attention_mask,
1455+
position_ids=position_ids,
1456+
cache_position=cache_position,
1457+
past_key_values=pkv,
1458+
inputs_embeds=inputs_embeds,
1459+
use_cache=use_cache,
1460+
)
1461+
upd_pkv = result["past_key_values"]
1462+
result["past_key_values"] = upd_pkv.to_legacy_cache()
1463+
return result
1464+
1465+
if is_transformers_version("<", "4.53.0"):
1466+
model.__orig_forward = model.forward
1467+
model.forward = types.MethodType(forward, model)
1468+
1469+
super().__init__(config, model, model_kwargs)
1470+
1471+
def __exit__(self, exc_type, exc_value, traceback):
1472+
super().__exit__(exc_type, exc_value, traceback)
1473+
1474+
if is_transformers_version("<", "4.53.0"):
1475+
self._model.forward = self._model.__orig_forward

optimum/onnxruntime/modeling_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def __init__(
185185
"To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
186186
)
187187

188-
if self.config.model_type in {"gemma", "gpt_oss", "nemotron"}:
188+
if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
189189
self.embed_size_per_head = self.config.head_dim
190190
elif self.old_gpt_bigcode_modeling:
191191
# (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
@@ -202,6 +202,8 @@ def __init__(
202202
"deepseek_v3",
203203
"cohere",
204204
"gemma",
205+
"gemma3",
206+
"gemma3_text",
205207
"glm",
206208
"granite",
207209
"gpt_oss",

0 commit comments

Comments
 (0)