Skip to content

Commit e663633

Browse files
committed
feat: add support for gemma3-text
simplify the example
1 parent 7226a01 commit e663633

File tree

5 files changed

+3221
-2
lines changed

5 files changed

+3221
-2
lines changed

examples/gemma3.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
"""Minimal demo: convert Gemma3 270M to ONNX and run a short chat generation.

Usage:
    uv pip install onnxruntime
    uv run examples/gemma3.py
"""

from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM


model_id = "google/gemma-3-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# export=True converts the checkpoint to ONNX while loading it.
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

# Single-turn conversation for the instruction-tuned checkpoint.
messages = [{"role": "user", "content": "Hello! How are you?"}]
chat_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
encoded = tokenizer(chat_prompt, return_tensors="pt")

# The EOS token id is reused as the padding id for generation.
generation_kwargs = {"max_new_tokens": 50, "pad_token_id": tokenizer.eos_token_id}
generated = model.generate(**encoded, **generation_kwargs)

# Decode the first (and only) returned sequence and show it.
print(tokenizer.decode(generated[0], skip_special_tokens=True))

optimum/exporters/onnx/model_configs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
CohereModelPatcher,
4444
FluxTransformerModelPatcher,
4545
MetaCLIP2Patcher,
46+
Gemma3LMModelPatcher,
4647
MgpstrModelPatcher,
4748
MoonshineModelPatcher,
4849
MusicgenModelPatcher,
@@ -517,6 +518,14 @@ class Gemma2OnnxConfig(TextDecoderOnnxConfig):
517518
MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
518519

519520

521+
@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS)
@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS)
class Gemma3OnnxConfig(GemmaOnnxConfig):
    """ONNX export configuration for text-only Gemma3 models.

    Registered under both the ``gemma3`` and ``gemma3_text`` model types for the
    common text-generation tasks. Everything not overridden here is inherited
    from ``GemmaOnnxConfig``.
    """

    # Patcher class used for this config (see ``Gemma3LMModelPatcher``).
    _MODEL_PATCHER = Gemma3LMModelPatcher
    # Minimum transformers version declared for this config.
    MIN_TRANSFORMERS_VERSION = version.parse("4.52.0")
527+
528+
520529
@register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS)
521530
class GPTOssOnnxConfig(GemmaOnnxConfig):
522531
MIN_TRANSFORMERS_VERSION = version.parse("4.55.0")

optimum/exporters/onnx/model_patcher.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import inspect
1919
import sys
2020
import types
21-
from typing import TYPE_CHECKING, Any, Callable
21+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
2222

2323
import torch
2424
import transformers
@@ -30,6 +30,7 @@
3030
jit_utils,
3131
symbolic_helper,
3232
)
33+
from transformers import PreTrainedModel, TFPreTrainedModel
3334
from transformers.modeling_outputs import BaseModelOutput
3435
from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
3536

@@ -1444,3 +1445,57 @@ def __exit__(self, exc_type, exc_value, traceback):
14441445
from transformers.models.cohere.modeling_cohere import CohereRotaryEmbedding
14451446

14461447
CohereRotaryEmbedding.forward = self.original_forward
1448+
1449+
1450+
class Gemma3LMModelPatcher(DecoderModelPatcher):
    """Patcher for Gemma3 language models that converts the KV cache format for ONNX export.

    On transformers < 4.53.0 the model's ``forward`` is replaced by a wrapper that
    turns the legacy tuple cache into a ``DynamicCache`` before calling the original
    ``forward``, and turns the returned cache back into the legacy tuple format.
    On newer versions no extra wrapping is installed by this class.
    """

    def __init__(
        self,
        config,
        model: Union[PreTrainedModel, TFPreTrainedModel],
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Install the cache-converting ``forward`` (transformers < 4.53.0 only) and set up the base patcher.

        Args:
            config: Export configuration forwarded to ``DecoderModelPatcher``.
            model: Model whose ``forward`` may be wrapped.
            model_kwargs: Optional extra keyword arguments forwarded to ``DecoderModelPatcher``.
        """

        def forward(
            self,
            input_ids=None,
            attention_mask=None,
            position_ids=None,
            past_key_values=None,
            inputs_embeds=None,
            use_cache=True,
        ):
            # The parameter list mirrors the leading arguments of a causal-LM forward so the
            # wrapper can be driven by either `input_ids` or `inputs_embeds`, with or without
            # a past cache.
            from transformers.cache_utils import DynamicCache

            # `from_legacy_cache` is given the tuple-of-tuples cache (or None).
            pkv = DynamicCache.from_legacy_cache(past_key_values)

            # Whichever input tensor was provided defines the number of new positions.
            new_tokens = inputs_embeds if inputs_embeds is not None else input_ids
            if new_tokens is None:
                raise ValueError("You must specify one of `input_ids` or `inputs_embeds`.")

            past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + new_tokens.shape[1],
                device=new_tokens.device,
            )

            result = self.__orig_forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                cache_position=cache_position,
                past_key_values=pkv,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
            )

            # Convert the cache back to the legacy tuple format only when one was
            # actually returned (there may be none when `use_cache` is False).
            upd_pkv = getattr(result, "past_key_values", None)
            if upd_pkv is not None:
                result["past_key_values"] = upd_pkv.to_legacy_cache()
            return result

        if is_transformers_version("<", "4.53.0"):
            # Keep a handle on the untouched forward so `__exit__` can restore it.
            # NOTE: `__orig_forward` is name-mangled consistently everywhere inside this class body.
            model.__orig_forward = model.forward
            model.forward = types.MethodType(forward, model)

        # The base patcher is initialised after the wrapper is installed, so it sees the wrapped forward.
        super().__init__(config, model, model_kwargs)

    def __exit__(self, exc_type, exc_value, traceback):
        """Run the base patcher's exit, then put back the forward saved in ``__init__``."""
        super().__exit__(exc_type, exc_value, traceback)

        if is_transformers_version("<", "4.53.0"):
            self._model.forward = self._model.__orig_forward

optimum/onnxruntime/modeling_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def __init__(
185185
"To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
186186
)
187187

188-
if self.config.model_type in {"gemma", "gpt_oss", "nemotron"}:
188+
if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
189189
self.embed_size_per_head = self.config.head_dim
190190
elif self.old_gpt_bigcode_modeling:
191191
# (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
@@ -202,6 +202,8 @@ def __init__(
202202
"deepseek_v3",
203203
"cohere",
204204
"gemma",
205+
"gemma3",
206+
"gemma3_text",
205207
"glm",
206208
"granite",
207209
"gpt_oss",

0 commit comments

Comments
 (0)