Add Gemma 3 Text support #436
Conversation
Force-pushed from 1f4f847 to 13b3a33.
I generated the tiny models with:

```python
# Create tiny random Gemma 3 models for Bumblebee testing
# Run with: HF_TOKEN=hf_xxx python create_tiny_gemma3.py
import os
from huggingface_hub import login
# Login with token
token = os.environ.get("HF_TOKEN")
if token:
    login(token=token)
else:
    print("Warning: HF_TOKEN not set, using cached credentials")
from transformers import (
Gemma3TextConfig,
Gemma3TextModel,
Gemma3ForCausalLM, # No "Text" variant for CausalLM
Gemma3TextForSequenceClassification,
)
# Tiny config matching Gemma 3 text architecture
config = Gemma3TextConfig(
vocab_size=1024,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=8,
intermediate_size=64,
hidden_activation="gelu_pytorch_tanh",
max_position_embeddings=512,
rms_norm_eps=1e-6,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
sliding_window=128,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
tie_word_embeddings=True,
initializer_range=0.02,
query_pre_attn_scalar=8,
)
# For sequence classification
config_seq_cls = Gemma3TextConfig(
vocab_size=1024,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
head_dim=8,
intermediate_size=64,
hidden_activation="gelu_pytorch_tanh",
max_position_embeddings=512,
rms_norm_eps=1e-6,
rope_theta=10000.0,
attention_bias=False,
attention_dropout=0.0,
sliding_window=128,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
tie_word_embeddings=True,
initializer_range=0.02,
query_pre_attn_scalar=8,
num_labels=2,
)
models = [
(Gemma3TextModel, config, "tiny-random-Gemma3Model"),
(Gemma3ForCausalLM, config, "tiny-random-Gemma3ForCausalLM"),
(Gemma3TextForSequenceClassification, config_seq_cls, "tiny-random-Gemma3ForSequenceClassification"),
]
print("Creating tiny random Gemma 3 models...")
for model_class, model_config, name in models:
    print(f"\nCreating {name}...")
    model = model_class(model_config)
    local_path = f"./{name}"
    model.save_pretrained(local_path)
    print(f"  Saved to {local_path}")
    repo_id = f"nmaroulis/{name}"
    model.push_to_hub(repo_id)
    print(f"  Pushed to https://huggingface.co/{repo_id}")
print("\nDone!")
```
Gemma 3 architecture includes several key differences from Gemma v1:
- QK-norm (RMS normalization on query/key after projection)
- Pre/post FFN layer norms (pre_feedforward_layernorm, post_feedforward_layernorm)
- Different residual connection order (after post_attention_layernorm)
- Alternating local/global attention (sliding window)
- RMS norm with shift=1.0 formula: output * (1.0 + weight) (see the sketch after this list)

Files added:
- lib/bumblebee/text/gemma3.ex: full Gemma 3 model implementation
- test/bumblebee/text/gemma3_test.exs: unit tests
- notebooks/function_calling.livemd: Livebook with FunctionGemma examples

Files modified:
- lib/bumblebee.ex: model and tokenizer registrations
- lib/bumblebee/layers/transformer.ex: per-layer attention_window_size support
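For context, a minimal Nx sketch of the shifted RMS norm mentioned above (illustrative only, not Bumblebee's internal implementation; module and function names are made up):

```elixir
defmodule Gemma3NormSketch do
  import Nx.Defn

  # Normalize by the root mean square, then scale by (1.0 + weight)
  # rather than the plain weight used in the original Gemma formulation.
  defn shifted_rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-6)
    variance = Nx.mean(x * x, axes: [-1], keep_axes: true)
    x * Nx.rsqrt(variance + opts[:eps]) * (1.0 + weight)
  end
end
```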
Force-pushed from 13b3a33 to 1fc7aaf.
…ests
- Refactor decoder to use shared Layers.Transformer.blocks infrastructure
- Use per-layer attention_window_size function for alternating local/global attention
- Use query_norm/key_norm options for QK-normalization
- Use custom block_type function for Gemma 3's unique normalization structure
- Add assert_all_close with reference values from Python transformers
- Fix bug in Layers.Transformer.blocks where attention_window_size was duplicated when using a function for per-layer configuration
- Update params_mapping to use query_norm/key_norm naming from shared infrastructure
Thanks for the review feedback! I've addressed both comments:
Added assert_all_close assertions with reference values obtained from Python transformers:

```python
import torch
input_ids = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
```

- :base: model = AutoModel.from_pretrained("nmaroulis/tiny-random-Gemma3Model"), hidden_state[[.., 1..3, 1..3]]: [[-1.6458, 0.7249, -0.5747], [-1.9452, -0.1602, -0.2329], [-2.3408, -0.4665, -0.1177]]
- :for_causal_language_modeling: model = AutoModelForCausalLM.from_pretrained("nmaroulis/tiny-random-Gemma3ForCausalLM"), logits[[.., 1..3, 1..3]]: [[0.1472, 0.0633, 0.0922], [-0.1089, -0.0344, 0.0755], [0.0112, 0.1083, 0.1461]]
- :for_sequence_classification: model = AutoModelForSequenceClassification.from_pretrained("nmaroulis/tiny-random-Gemma3ForSequenceClassification"), logits: [[-0.0060, -0.0212]]

All tests now verify numerical equivalence with Python transformers.
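For reference, a minimal sketch of how such a check typically looks inside the project's ExUnit test module (illustrative layout only; assert_all_close is the test helper mentioned above, and the repository, inputs, and slice follow the values quoted):

```elixir
test ":base" do
  # Illustrative sketch, not the exact test from this PR.
  assert {:ok, %{model: model, params: params}} =
           Bumblebee.load_model({:hf, "nmaroulis/tiny-random-Gemma3Model"})

  inputs = %{
    "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
    "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
  }

  outputs = Axon.predict(model, params, inputs)

  assert_all_close(
    outputs.hidden_state[[.., 1..3, 1..3]],
    Nx.tensor([
      [[-1.6458, 0.7249, -0.5747], [-1.9452, -0.1602, -0.2329], [-2.3408, -0.4665, -0.1177]]
    ])
  )
end
```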
Replaced the custom block iteration with Layers.Transformer.blocks. The key change is the call to `Layers.Transformer.blocks(hidden_state, ...)`: the custom gemma3_block_impl/4 function handles Gemma 3's unique block structure while leveraging the shared attention infrastructure.
Btw. I uploaded the tiny-random models to bumblebee-testing, so you can switch the tests to use those :)
@nyo16 by the way, the gemma3 attention uses a configurable scalar here. I've just pushed a8caabd, which adds support for it, and I updated the tiny-random checkpoints I pushed accordingly.
- Rename Bumblebee.Text.Gemma3 to Bumblebee.Text.Gemma3Text to distinguish the text-only model from a future multimodal Gemma3
- Add attention_scale_base config option (from query_pre_attn_scalar)
- Compute attention scale as attention_scale_base ** -0.5 (see the snippet after this list)
- Update model mappings to use Gemma3Text* variants
- Update tests to use bumblebee-testing models with Python reference values
- Fix duplicate attention_window_size key in transformer.ex after merge
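For reference, a tiny sketch of that scale computation with the test config's query_pre_attn_scalar = 8 (illustrative; these names are not Bumblebee internals):

```elixir
# Attention scores are scaled by attention_scale_base ** -0.5
# rather than the usual head_dim ** -0.5.
attention_scale_base = 8
scale = :math.pow(attention_scale_base, -0.5)
# => approximately 0.3536
```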
Thanks so much @jonatanklosko! I've updated the branch based on your feedback.
Note: There are still some small numerical differences between Elixir and Python outputs (max ~0.15), so I used slightly higher tolerances in the tests. This might be worth investigating further if needed, but the model works correctly with the FunctionGemma examples: I tested with the code and got the correct answers. If I find more time, I will investigate further.
Co-authored-by: Jonatan Kłosko <[email protected]>
lib/bumblebee/text/gemma3_text.ex (Outdated)
```elixir
global_attention_layer_interval: [
  default: 6,
  doc: """
  the interval for global attention layers. In Gemma 3, every Nth layer uses global
  attention while others use local (sliding window) attention. A value of 6 means
  layers 5, 11, 17, 23... use global attention (5:1 local/global ratio)
  """
],
```
We never load this from the HuggingFace config, so it will cause discrepancies (could be the reason why tests failed).
In fact, the "interval" approach is deprecated, see this code.
I think we can go with option :layer_types, which is a list of either :sliding_attention or :full_attention. We have a similar config here:
bumblebee/lib/bumblebee/diffusion/controlnet.ex, lines 414 to 422 in a8caabd:

```elixir
up_block_types: {
  "up_block_types",
  list(
    mapping(%{
      "UpBlock2D" => :up_block,
      "CrossAttnUpBlock2D" => :cross_attention_up_block
    })
  )
},
```
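For illustration, a hedged sketch of what an analogous Gemma 3 option mapping could look like, following the same pattern (a fragment only; the exact option name and placement are up to the implementation):

```elixir
layer_types: {
  "layer_types",
  list(
    mapping(%{
      "sliding_attention" => :sliding_attention,
      "full_attention" => :full_attention
    })
  )
},
```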
We should also handle "sliding_window_pattern" if present. One way to do it would be to do something like this:

```elixir
# Support sliding_window_pattern for backward compatibility, see https://github.com/huggingface/transformers/blob/v5.0.0rc1/src/transformers/models/gemma3/configuration_gemma3.py#L188-L195
data =
  Map.put_new_lazy(data, "layer_types", fn ->
    pattern = data["sliding_window_pattern"] || 6
    # generate a list of Python-like layer types
  end)
```
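For illustration, the elided generation step could look something like this: a hedged sketch mirroring the transformers fallback, where every pattern-th layer uses full attention and the rest use sliding attention (assumes data is the decoded HF config map from the snippet above, with the standard "num_hidden_layers" key):

```elixir
data =
  Map.put_new_lazy(data, "layer_types", fn ->
    pattern = data["sliding_window_pattern"] || 6
    num_layers = data["num_hidden_layers"]

    # Every pattern-th layer (1-indexed) is full attention, others are sliding.
    for i <- 0..(num_layers - 1) do
      if rem(i + 1, pattern) == 0, do: "full_attention", else: "sliding_attention"
    end
  end)
```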
Btw. I pushed updated checkpoints to HF with:

```json
"layer_types": [
  "sliding_attention",
  "full_attention"
],
```
This way we can test both layer types.
You will need to update the reference values.
We should be within 4-digit precision, or in other words we should not need the higher tolerances.

FWIW, I generated the checkpoints using this config:

```python
from transformers import Gemma3TextConfig, Gemma3TextModel, Gemma3ForCausalLM, Gemma3TextForSequenceClassification
config = Gemma3TextConfig(
vocab_size=1024,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
num_key_value_heads=2,
intermediate_size=37,
hidden_act="gelu_pytorch_tanh",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=16,
is_decoder=False,
initializer_range=0.02,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
sliding_window=128,
query_pre_attn_scalar=8,
sliding_window_pattern=2
# layer_types=[
# "sliding_attention",
# "full_attention"
# ]
)
for c in [Gemma3TextModel, Gemma3ForCausalLM, Gemma3TextForSequenceClassification]:
    name = c.__name__
    c(config).save_pretrained(f"bumblebee-testing/tiny-random-{name}", repo_id=f"bumblebee-testing/tiny-random-{name}", push_to_hub=True)
```
- Add layer_types config option for per-layer attention type (sliding vs full)
- Generate layer_types from sliding_window_pattern for backward compatibility
- Update test expected values with Python reference outputs
- Add tolerances to tests (base model needs larger tolerance, see TODO)
Gemma 3 uses different RoPE bases for local vs global attention:
- Local (sliding) attention: rope_local_base_freq = 10,000
- Global (full) attention: rope_theta = 1,000,000

This fixes numerical discrepancies where positions 2+ diverged by ~0.2-0.3 from Python reference values. Now achieves 4-digit precision (atol: 1.0e-4) across all positions.

Changes:
- Add rotary_embedding_base_local spec option
- Load rope_local_base_freq from HuggingFace config
- Use per-layer rotary embedding base based on layer type
- Tighten test tolerances from 0.35 to 1.0e-4
@jonatanklosko Thanks for the updated test models and pushing for better testing!

Issue: Gemma 3 uses different RoPE bases for local vs global attention layers (10,000 for local/sliding, 1,000,000 for global/full). Bumblebee was using the global base (1M) for all layers, which caused positions 2+ to diverge significantly (~0.2-0.3).

Fix: Added the rotary_embedding_base_local config option and updated the decoder to use per-layer rotary embedding bases based on layer_types. All tests now pass with atol: 1.0e-4 (4-digit precision), as you expected. 🎉
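For reference, a small sketch of the per-layer base selection (an illustrative helper, not Bumblebee's exact internals; the spec field names follow the params mapping in this PR):

```elixir
defmodule RotaryBaseSketch do
  # Choose the rotary embedding base for a decoder layer from its attention type.
  def for_layer(:sliding_attention, spec), do: spec.rotary_embedding_base_local
  def for_layer(:full_attention, spec), do: spec.rotary_embedding_base
end

# With the values above:
spec = %{rotary_embedding_base_local: 10_000, rotary_embedding_base: 1_000_000}
RotaryBaseSketch.for_layer(:sliding_attention, spec)
# => 10000
```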
Also, here is the Python code I used to generate the test numbers:

```python
#!/usr/bin/env python3
"""Generate reference values for Bumblebee Gemma3Text tests."""
import torch
from transformers import AutoModel, AutoModelForCausalLM, AutoModelForSequenceClassification
input_ids = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
# :base test
model = AutoModel.from_pretrained("bumblebee-testing/tiny-random-Gemma3TextModel")
model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
print(":base - hidden_state[0, 1:4, 1:4]:")
print(outputs.last_hidden_state[0, 1:4, 1:4])
# :for_causal_language_modeling test
model = AutoModelForCausalLM.from_pretrained("bumblebee-testing/tiny-random-Gemma3ForCausalLM")
model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
print("\n:for_causal_language_modeling - logits[0, 1:4, 1:4]:")
print(outputs.logits[0, 1:4, 1:4])
# :for_sequence_classification test
model = AutoModelForSequenceClassification.from_pretrained(
"bumblebee-testing/tiny-random-Gemma3TextForSequenceClassification"
)
model.eval()
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
print("\n:for_sequence_classification - logits:")
print(outputs.logits)
```
```elixir
activation: {"hidden_activation", activation()},
use_attention_bias: {"attention_bias", boolean()},
rotary_embedding_base: {"rope_theta", number()},
rotary_embedding_base_local: {"rope_local_base_freq", number()},
```
rope_local_base_freq does not actually exist in the config. It looks like they now have a whole configuration struct for rope, in this case one per layer type. I will look into handling the new config shape.
Actually, rope_local_base_freq is still present in the config on Transformers v4, and v5 is still a release candidate. I pushed a small update, but I think we can merge this; I will make the rope config update separately.
Awesome! Thank you @jonatanklosko
Adds support for the Gemma 3 architecture, enabling FunctionGemma (google/functiongemma-270m-it) and other Gemma 3 models to run in Bumblebee.

Why FunctionGemma?

Gemma 3 Architecture Changes

Gemma 3 has several key differences from Gemma v1, including the RMS norm formula: Gemma v1 uses weight * normalized, while Gemma 3 uses (1 + weight) * normalized.

Files Changed

- lib/bumblebee/text/gemma3.ex - Full Gemma 3 model implementation with custom decoder supporting QK-norm and extra layer norms
- lib/bumblebee.ex - Model and tokenizer registrations for Gemma3Model, Gemma3ForCausalLM, etc.
- lib/bumblebee/layers/transformer.ex - Per-layer attention_window_size callback for alternating local/global attention
- test/bumblebee/text/gemma3_test.exs - Unit tests (require tiny-random models on HuggingFace)
- notebooks/function_calling.livemd - Comprehensive Livebook example with:
  - FunctionGemma.Schema - Build function declarations
  - FunctionGemma.Parser - Parse function call responses
  - FunctionGemma.Executor - Execute parsed calls
  - SmartHome - Mock functions demo (lights, thermostat, weather)

Example Usage
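A hedged sketch of typical Bumblebee usage with the model named in this PR (standard Bumblebee loading and generation APIs; the checkpoint may require a HuggingFace token, and the prompt is illustrative):

```elixir
repo = {:hf, "google/functiongemma-270m-it"}

{:ok, model_info} = Bumblebee.load_model(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

# Build a text-generation serving and run a single prompt.
serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Turn on the living room lights.")
```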
Test Plan