Add support for gemma3-text #70
Conversation
force-pushed from 97a3e2c to 3f08a26
Happy to add more examples if they're welcome. For example, this one is hugely useful for fine-tuned models with LoRA:

```python
"""Simple example: Export Gemma3 270M with LoRA adapter to ONNX and generate text.

Usage:
    uv pip install onnxruntime peft
    uv run examples/gemma3.py
"""

import time

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.exporters.onnx import onnx_export_from_model
from optimum.onnxruntime import ORTModelForCausalLM

# Load base model and merge with LoRA adapter
base_model_id = "google/gemma-3-270m-it"  # The base model for your LoRA
adapter_id = "thewh1teagle/gemma3-heb-g2p"

base_model = AutoModelForCausalLM.from_pretrained(base_model_id)
model = PeftModel.from_pretrained(base_model, adapter_id)
model = model.merge_and_unload()  # Merge LoRA weights into base model
tokenizer = AutoTokenizer.from_pretrained(adapter_id)

# Export merged model to ONNX
print("Exporting to ONNX...")
output_dir = "gemma3_onnx"
onnx_export_from_model(
    model=model,
    output=output_dir,
    task="text-generation-with-past",
)

# Save tokenizer to the same directory
tokenizer.save_pretrained(output_dir)

# Load the exported ONNX model
ort_model = ORTModelForCausalLM.from_pretrained(output_dir)

# Chat with the instruction-tuned model
system_message = """Given the following Hebrew sentence, convert it to IPA phonemes.
Input Format: A Hebrew sentence.
Output Format: A string of IPA phonemes.
"""
user_prompt = "אז מה דעתך, האם אתה יודע לדבר עברית גם כמו שאני יודע לדבר או שאתה לא?"

conversation = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": user_prompt},
]
prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(prompt, return_tensors="pt")

# Generate with parameters similar to the working Ollama script
start_time = time.time()
outputs = ort_model.generate(
    **inputs,
    do_sample=True,  # required for temperature/top_p/top_k to actually take effect
    max_new_tokens=150,
    temperature=0.9,
    top_p=0.95,
    top_k=64,
    pad_token_id=tokenizer.eos_token_id,
    eos_token_id=tokenizer.convert_tokens_to_ids(["<end_of_turn>", "</s>"]),
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Extract only the model's response (after the last "model" turn)
if "<start_of_turn>model" in response:
    response = response.split("<start_of_turn>model")[-1].strip()

# Remove any end tokens
for end_token in ["<end_of_turn>", "</s>"]:
    response = response.replace(end_token, "")

print(response.strip())
print(f"Time taken: {time.time() - start_time:.2f} seconds")
```
Looking forward to gemma3n multimodal support
Thanks for the addition! I don't think an example script is the best way; it would be better to add the snippet to the documentation under a relevant section, or turning it into a notebook could also be very useful!
```python
CohereRotaryEmbedding.forward = self.original_forward


class Gemma3LMModelPatcher(DecoderModelPatcher):
```
Did you try exporting without this patcher? (It might not be necessary for text-only generation.)
Yup, you're right. I removed them and the tests are working.
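For reference, a patcher-free export can also be reproduced in a couple of lines via `main_export` (a minimal sketch; the output directory name is arbitrary):

```python
from optimum.exporters.onnx import main_export

# Export the text-only Gemma3 checkpoint without any custom model patcher;
# if tracing succeeds and the ONNX files appear, the patcher was not needed.
main_export(
    "google/gemma-3-270m-it",
    output="gemma3_onnx_nopatch",
    task="text-generation-with-past",
)
```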
uv.lock (outdated)
no need for the lock file 🤗
fixed
force-pushed from 3f08a26 to e663633
@IlyasMoutawwakil |
Yeah, of course, feel free. My proposition is to simply put it in the docs for better visibility.
Yes, tests need to be added (you can see how it's done in https://github.com/huggingface/optimum-onnx/pull/43/files 🤗)
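For anyone curious what that looks like, a rough sketch of the pattern from PR #43, where the architecture name and a tiny test checkpoint get registered in `tests/onnxruntime/test_decoder.py` (the variable names and the tiny-model ID below are assumptions, not quotes from the repo):

```python
# tests/onnxruntime/test_decoder.py (names below are illustrative assumptions)
SUPPORTED_ARCHITECTURES = [
    # ... existing architectures ...
    "gemma3",  # new entry for this PR
]

MODEL_NAMES = {
    # ... existing mappings ...
    "gemma3": "hf-internal-testing/tiny-random-Gemma3ForCausalLM",  # hypothetical tiny checkpoint
}
```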
@thewh1teagle Build script from Xenova:
simplify the example
force-pushed from 7c35fdf to 9824aec
@IlyasMoutawwakil |
Added tests and verified with `uv run --extra tests --extra onnxruntime pytest tests/onnxruntime/test_decoder.py -k "gemma3" -v`
@IlyasMoutawwakil

We do have docs 😥 the repo's main page links to them right under the description, and also in the README if you click on "Documentation".
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@IlyasMoutawwakil there are many unclear navigation buttons
@thewh1teagle what's not clear exactly? BTW, it's open source, so contributions are welcome: https://github.com/huggingface/doc-builder
@bot \style
Hi @thewh1teagle, does that also work with the gemma3 4b model?
@IlyasMoutawwakil my feedback was just to let you know so you can potentially improve it, not a complaint! :) I really appreciate the work you're doing on this open source library (and on HF's open source projects in general).
| @register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS) | ||
| @register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS) | ||
| class Gemma3OnnxConfig(GemmaOnnxConfig): | ||
| MIN_TRANSFORMERS_VERSION = version.parse("4.52.0") |
Any reason why not 4.50.0?
https://github.com/huggingface/transformers/blob/v4.50.0/src/transformers/models/gemma3/modeling_gemma3.py
I bumped to 4.53.0 and added a comment about why
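Something along these lines, where the comment wording is my paraphrase rather than a quote from the PR:

```python
@register_tasks_manager_onnx("gemma3", *COMMON_TEXT_GENERATION_TASKS)
@register_tasks_manager_onnx("gemma3_text", *COMMON_TEXT_GENERATION_TASKS)
class Gemma3OnnxConfig(GemmaOnnxConfig):
    # Gemma3 export is only reliable from transformers 4.53.0 onwards;
    # older releases fail during tracing (comment wording paraphrased).
    MIN_TRANSFORMERS_VERSION = version.parse("4.53.0")
```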
```diff
 @register_tasks_manager_onnx("gemma", *[*COMMON_TEXT_GENERATION_TASKS, "text-classification"])
-class GemmaOnnxConfig(LlamaOnnxConfig):
+class GemmaOnnxConfig(TextDecoderOnnxConfig):
```
I discovered that Gemma models in general don't need the position_ids argument.
@echarlaix wdyt? This also removes the need for position_ids from gpt_oss and nemotron.
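One quick way to verify the change is to inspect the exported graph's inputs with onnxruntime (a small sketch, assuming the `gemma3_onnx` directory produced by the example above):

```python
import onnxruntime as ort

session = ort.InferenceSession("gemma3_onnx/model.onnx", providers=["CPUExecutionProvider"])
print([inp.name for inp in session.get_inputs()])
# With TextDecoderOnnxConfig the inputs should be input_ids, attention_mask,
# and the past_key_values.* entries, with no position_ids input.
```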
Thanks a lot for the contribution! I made some changes to make sure the minimal transformers version passes all tests!


Added support for gemma3-text following the code in:
Also added a working example with gemma3-270m-instruct. Will update and improve as needed.
Related