
Commit 9744167

feat: add support for gemma3-text
simplify the example
1 parent 08bd601 commit 9744167

5 files changed: +3220 −2 lines changed


examples/gemma3.py

Lines changed: 25 additions & 0 deletions
```python
"""Simple example: Export Gemma3 270M to ONNX and generate text.

Usage:
    uv pip install onnxruntime
    uv run examples/gemma3.py
"""

from transformers import AutoTokenizer

from optimum.onnxruntime import ORTModelForCausalLM


model_id = "google/gemma-3-270m-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = ORTModelForCausalLM.from_pretrained(model_id, export=True)

# Chat with instruction-tuned model
conversation = [{"role": "user", "content": "Hello! How are you?"}]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.eos_token_id)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(response)
```
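Not part of the commit itself, but a natural companion to the example above: the exported model can be saved to disk and reloaded, so later runs skip the slow PyTorch-to-ONNX conversion. A minimal sketch using the standard `optimum` API; the `gemma3_onnx` directory name is made up for illustration:

```python
# Hypothetical follow-up to examples/gemma3.py: persist the ONNX export.
model.save_pretrained("gemma3_onnx")      # writes the .onnx graph + config
tokenizer.save_pretrained("gemma3_onnx")

# Reload directly from the exported files; no export=True needed this time.
model = ORTModelForCausalLM.from_pretrained("gemma3_onnx")
```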

optimum/exporters/onnx/model_configs.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -43,6 +43,7 @@
     CohereModelPatcher,
     FluxTransformerModelPatcher,
     MetaCLIP2Patcher,
+    Gemma3LMModelPatcher,
     MgpstrModelPatcher,
     MoonshineModelPatcher,
     MusicgenModelPatcher,
@@ -517,6 +518,14 @@ class Gemma2OnnxConfig(TextDecoderOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
 
 
+@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS)
+@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS)
+class Gemma3OnnxConfig(GemmaOnnxConfig):
+    """ONNX config for Gemma3 text-only models."""
+    MIN_TRANSFORMERS_VERSION = version.parse("4.52.0")
+    _MODEL_PATCHER = Gemma3LMModelPatcher
+
+
 @register_tasks_manager_onnx("gpt_oss", *COMMON_TEXT_GENERATION_TASKS)
 class GPTOssOnnxConfig(GemmaOnnxConfig):
     MIN_TRANSFORMERS_VERSION = version.parse("4.55.0")
```
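The two `register_tasks_manager_onnx` decorators are what make the new config discoverable: both the `gemma3` and `gemma3_text` model types now resolve to `Gemma3OnnxConfig` for the common text-generation tasks. A minimal sketch of driving the export programmatically through optimum's usual `main_export` entry point; the output directory here is hypothetical:

```python
from optimum.exporters.onnx import main_export

# Both model_type values registered above route to Gemma3OnnxConfig.
main_export(
    model_name_or_path="google/gemma-3-270m-it",
    output="gemma3_onnx",              # hypothetical output directory
    task="text-generation-with-past",  # export with KV-cache inputs/outputs
)
```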

optimum/exporters/onnx/model_patcher.py

Lines changed: 55 additions & 1 deletion
```diff
@@ -19,7 +19,7 @@
 import sys
 import types
 import warnings
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 import transformers
@@ -1457,3 +1457,57 @@ def __exit__(self, exc_type, exc_value, traceback):
         from transformers.models.cohere.modeling_cohere import CohereRotaryEmbedding
 
         CohereRotaryEmbedding.forward = self.original_forward
+
+
+class Gemma3LMModelPatcher(DecoderModelPatcher):
+    """Patcher for Gemma3 language model to handle cache conversion for ONNX export."""
+
+    def __init__(
+        self,
+        config,
+        model: Union[PreTrainedModel, TFPreTrainedModel],
+        model_kwargs: Optional[dict[str, Any]] = None,
+    ):
+        def forward(
+            self,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            inputs_embeds,
+            use_cache=True,
+        ):
+            from transformers.cache_utils import DynamicCache
+
+            pkv = DynamicCache.from_legacy_cache(past_key_values)
+
+            past_seen_tokens = past_key_values[0][0].shape[-2] if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
+            result = self.__orig_forward(
+                input_ids=None,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                cache_position=cache_position,
+                past_key_values=pkv,
+                inputs_embeds=inputs_embeds,
+                use_cache=use_cache,
+            )
+            upd_pkv = result["past_key_values"]
+            result["past_key_values"] = upd_pkv.to_legacy_cache()
+            return result
+
+        if is_transformers_version("<", "4.53.0"):
+            model.__orig_forward = model.forward
+            model.forward = types.MethodType(forward, model)
+
+        super().__init__(config, model, model_kwargs)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+
+        if is_transformers_version("<", "4.53.0"):
+            self._model.forward = self._model.__orig_forward
```
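The core of the patcher is the round-trip between the legacy tuple-of-tuples cache format that the exported ONNX graph exposes and the `DynamicCache` object that newer `transformers` versions expect: `from_legacy_cache` on the way in, `to_legacy_cache` on the way out, with `cache_position` recomputed from the past sequence length. A standalone sketch of that conversion; the layer count and tensor shapes are illustrative, not taken from the commit:

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped
# (batch, num_kv_heads, past_seq_len, head_dim). Shapes here are made up.
legacy = tuple(
    (torch.zeros(1, 1, 8, 256), torch.zeros(1, 1, 8, 256)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)  # wrap for the model call
past_seen_tokens = legacy[0][0].shape[-2]       # 8, mirroring the patcher
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + 1)  # 1 new token

# ... the patched forward runs the model with past_key_values=cache ...

legacy_again = cache.to_legacy_cache()          # unwrap for the ONNX outputs
assert legacy_again[0][0].shape == legacy[0][0].shape
```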

optimum/onnxruntime/modeling_decoder.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -185,7 +185,7 @@ def __init__(
                 "To re-export your model, simply set `export=True` as in `from_pretrained(..., export=True, use_cache=True)`."
             )
 
-        if self.config.model_type in {"gemma", "gpt_oss", "nemotron"}:
+        if self.config.model_type in {"gemma", "gemma3", "gemma3_text", "gpt_oss", "nemotron"}:
             self.embed_size_per_head = self.config.head_dim
         elif self.old_gpt_bigcode_modeling:
             # (before v4.54) GPT BigCode fuses keys and values in one tensor, doubling the head dimension
@@ -202,6 +202,8 @@
             "deepseek_v3",
             "cohere",
             "gemma",
+            "gemma3",
+            "gemma3_text",
             "glm",
             "granite",
             "gpt_oss",
```
