
Commit 6772a10 (1 parent: 671b84f)

feat: add support for gemma3-text

File tree: 5 files changed, +3249 −2 lines

examples/gemma3.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+"""Simple example script for Gemma3 270M text generation using ONNX.
+
+Installation:
+    uv pip install onnxruntime
+
+Usage:
+    uv run examples/gemma3.py
+"""
+
+from transformers import AutoTokenizer
+
+from optimum.onnxruntime import ORTModelForCausalLM
+
+
+model_id = "google/gemma-3-270m-it"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# Export to ONNX
+model = ORTModelForCausalLM.from_pretrained(
+    model_id,
+    export=True,
+    use_cache=True,
+)
+
+# Inference
+conversation = [
+    {"role": "user", "content": "Hello! How are you?"}
+]
+
+# Apply chat template
+prompt = tokenizer.apply_chat_template(
+    conversation,
+    tokenize=False,
+    add_generation_prompt=True
+)
+
+inputs = tokenizer(prompt, return_tensors="pt")
+
+outputs = model.generate(
+    **inputs,
+    max_new_tokens=100,
+    do_sample=True,
+    temperature=0.7,
+    top_p=0.9,
+    pad_token_id=tokenizer.eos_token_id,
+)
+
+# Decode
+response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+if prompt in response:
+    response = response[len(prompt):].strip()
+
+print(f"Response: {response}\n")
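
The example re-exports the model on every run because of `export=True`. A natural follow-up (not part of this commit) is to save the exported model once and reload it from disk on later runs; a minimal sketch, where `gemma3_onnx` is a hypothetical local directory name:

# Sketch: persist the export, then reload without re-exporting.
# Assumes the example above has run, so `model` and `tokenizer` exist.
model.save_pretrained("gemma3_onnx")
tokenizer.save_pretrained("gemma3_onnx")

# On subsequent runs, load the saved ONNX model directly:
from optimum.onnxruntime import ORTModelForCausalLM
model = ORTModelForCausalLM.from_pretrained("gemma3_onnx", use_cache=True)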

optimum/exporters/onnx/model_configs.py

Lines changed: 9 additions & 0 deletions
@@ -42,6 +42,7 @@
     CLIPModelPatcher,
     CohereModelPatcher,
     FluxTransformerModelPatcher,
+    Gemma3LMModelPatcher,
     MgpstrModelPatcher,
     MoonshineModelPatcher,
     MusicgenModelPatcher,
@@ -516,6 +517,14 @@ class Gemma2OnnxConfig(TextDecoderOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
 
 
+@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS)
+@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS)
+class Gemma3OnnxConfig(GemmaOnnxConfig):
+    """ONNX config for Gemma3 text-only models."""
+    MIN_TRANSFORMERS_VERSION = version.parse("4.52.0")
+    _MODEL_PATCHER = Gemma3LMModelPatcher
+
+
 @register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS)
 class GPTOssOnnxConfig(GemmaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.55.0")
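
With both `gemma3` and `gemma3_text` registered for the common text-generation tasks, the CLI export path should also pick up the new config. A hedged sketch of the invocation (the `--task` value is assumed to be `text-generation-with-past`, mirroring other decoder-only models; it is not shown in this commit):

optimum-cli export onnx \
  --model google/gemma-3-270m-it \
  --task text-generation-with-past \
  gemma3_onnx/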

optimum/exporters/onnx/model_patcher.py

Lines changed: 56 additions & 1 deletion
@@ -18,7 +18,7 @@
 import inspect
 import sys
 import types
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
 
 import torch
 import transformers
@@ -30,6 +30,7 @@
     jit_utils,
     symbolic_helper,
 )
+from transformers import PreTrainedModel, TFPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
 
@@ -1418,3 +1419,57 @@ def __exit__(self, exc_type, exc_value, traceback):
         from transformers.models.cohere.modeling_cohere import CohereRotaryEmbedding
 
         CohereRotaryEmbedding.forward = self.original_forward
+
+
+class Gemma3LMModelPatcher(DecoderModelPatcher):
+    """Patcher for Gemma3 language model to handle cache conversion for ONNX export."""
+
+    def __init__(
+        self,
+        config,
+        model: Union[PreTrainedModel, TFPreTrainedModel],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        def forward(
+            self,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            inputs_embeds,
+            use_cache=True,
+        ):
+            from transformers.cache_utils import DynamicCache
+
+            pkv = DynamicCache.from_legacy_cache(past_key_values)
+
+            past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+            result = self.__orig_forward(
+                input_ids=None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                past_key_values=pkv,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+            )
+            upd_pkv = result["past_key_values"]
+            result["past_key_values"] = upd_pkv.to_legacy_cache()
+            return result
+
+        if is_transformers_version("<", "4.53.0"):
+            model.__orig_forward = model.forward
+            model.forward = types.MethodType(forward, model)
+
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+
+        if is_transformers_version("<", "4.53.0"):
+            self._model.forward = self._model.__orig_forward
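
The heart of the patch is the round trip between the flat tuple cache that the exported ONNX graph exposes and the `DynamicCache` object the transformers modeling code expects. A minimal sketch of that conversion, assuming a transformers version (< 4.53, matching the version gate above) where `DynamicCache` still interoperates with the legacy format; the shapes are hypothetical:

import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch, num_kv_heads, seq_len, head_dim). Two layers here, for brevity.
legacy = tuple(
    (torch.zeros(1, 1, 8, 256), torch.zeros(1, 1, 8, 256)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)  # flat tuples -> Cache object
roundtrip = cache.to_legacy_cache()             # Cache object -> flat tuples

assert roundtrip[0][0].shape == legacy[0][0].shape  # (1, 1, 8, 256)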

optimum/onnxruntime/modeling_decoder.py

Lines changed: 3 additions & 1 deletion
@@ -185,7 +185,7 @@ def __init__(
                 "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
             )
 
-        if self.config.model_type in {"gemma", "gpt_oss", "nemotron"}:
+        if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
             self.embed_size_per_head = self.config.head_dim
         elif self.old_gpt_bigcode_modeling:
             # (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
@@ -202,6 +202,8 @@
             "deepseek_v3",
             "cohere",
             "gemma",
+            "gemma3",
+            "gemma3_text",
             "glm",
             "granite",
             "gpt_oss",
