
Commit 1fc7aaf

Add Gemma 3 support for FunctionGemma and other Gemma 3 models
Gemma 3 architecture includes several key differences from Gemma v1:

- QK-norm (RMS normalization on query/key after projection)
- Pre/post FFN layer norms (pre_feedforward_layernorm, post_feedforward_layernorm)
- Different residual connection order (after post_attention_layernorm)
- Alternating local/global attention (sliding window)
- RMS norm with shift=1.0 formula: output * (1.0 + weight)

Files added:

- lib/bumblebee/text/gemma3.ex: Full Gemma 3 model implementation
- test/bumblebee/text/gemma3_test.exs: Unit tests
- notebooks/function_calling.livemd: Livebook with FunctionGemma examples

Files modified:

- lib/bumblebee.ex: Model and tokenizer registrations
- lib/bumblebee/layers/transformer.ex: Per-layer attention_window_size support
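For orientation, the shift=1.0 RMS norm mentioned above can be sketched in Nx roughly as follows. This is a minimal illustration, not code from this commit; the module and function names are hypothetical.

defmodule Gemma3RmsNormSketch do
  import Nx.Defn

  # Normalize by the root mean square, then scale by (1.0 + weight)
  # rather than by weight alone (the shift=1.0 behaviour described above).
  defn rms_norm(x, weight, opts \\ []) do
    opts = keyword!(opts, epsilon: 1.0e-6)
    variance = Nx.mean(Nx.pow(x, 2), axes: [-1], keep_axes: true)
    x * Nx.rsqrt(variance + opts[:epsilon]) * (1.0 + weight)
  end
end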
1 parent 55ec9ac commit 1fc7aaf


5 files changed, +1363 −1 lines changed


lib/bumblebee.ex

Lines changed: 9 additions & 0 deletions
@@ -137,6 +137,13 @@ defmodule Bumblebee do
     "GemmaModel" => {Bumblebee.Text.Gemma, :base},
     "GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling},
     "GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification},
+    "Gemma3Model" => {Bumblebee.Text.Gemma3, :base},
+    "Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling},
+    "Gemma3ForSequenceClassification" => {Bumblebee.Text.Gemma3, :for_sequence_classification},
+    "Gemma3TextModel" => {Bumblebee.Text.Gemma3, :base},
+    "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3, :for_causal_language_modeling},
+    "Gemma3TextForSequenceClassification" =>
+      {Bumblebee.Text.Gemma3, :for_sequence_classification},
     "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
     "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
     "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
@@ -249,6 +256,8 @@ defmodule Bumblebee do
     "camembert" => :camembert,
     "clip" => :clip,
     "gemma" => :gemma,
+    "gemma3" => :gemma,
+    "gemma3_text" => :gemma,
     "gpt_neox" => :gpt_neo_x,
     "gpt2" => :gpt2,
     "gpt_bigcode" => :gpt2,

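With these registrations in place, a Gemma 3 checkpoint should resolve to Bumblebee.Text.Gemma3 when loaded. A hedged usage sketch follows; the repository name is illustrative only and is not taken from this commit.

{:ok, model_info} = Bumblebee.load_model({:hf, "google/gemma-3-1b-it"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "google/gemma-3-1b-it"})
{:ok, generation_config} = Bumblebee.load_generation_config({:hf, "google/gemma-3-1b-it"})

serving = Bumblebee.Text.generation(model_info, tokenizer, generation_config)
Nx.Serving.run(serving, "Briefly describe sliding window attention.")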
lib/bumblebee/layers/transformer.ex

Lines changed: 19 additions & 1 deletion
@@ -25,6 +25,13 @@ defmodule Bumblebee.Layers.Transformer do
       - a keyword list (applied to all blocks)
       - a function that takes the block index and returns the configuration
 
+    * `:attention_window_size` - sliding window attention configuration. Can be:
+      - `nil` for global attention (default)
+      - a `{left, right}` tuple (applied to all blocks)
+      - a function that takes the block index and returns `nil` or `{left, right}`.
+        This enables per-layer attention patterns like Gemma 3's alternating
+        local/global attention (5 local layers followed by 1 global layer)
+
     * `:name` - the prefix for layer names
 
   For all other options (including required options) see `block/2`.
@@ -52,7 +59,6 @@ defmodule Bumblebee.Layers.Transformer do
      :output_use_bias,
      :layer_norm,
      :block_type,
-      :attention_window_size,
      :scale_attention_weights
    ]
 
@@ -64,6 +70,7 @@ defmodule Bumblebee.Layers.Transformer do
        :name,
        :num_blocks,
        :rotary_embedding,
+       :attention_window_size,
        attention_mask: Layers.none(),
        attention_head_mask: Layers.none(),
        attention_relative_bias: nil,
@@ -85,6 +92,7 @@ defmodule Bumblebee.Layers.Transformer do
     cross_attention_head_mask = opts[:cross_attention_head_mask]
     cache = opts[:cache]
     rotary_embedding = opts[:rotary_embedding]
+    attention_window_size = opts[:attention_window_size]
 
     block_opts = Keyword.take(opts, block_opts_keys)
 
@@ -121,6 +129,15 @@ defmodule Bumblebee.Layers.Transformer do
            config when is_list(config) -> config
          end
 
+        # Support per-layer attention window size for models like Gemma 3
+        # that alternate between local (sliding window) and global attention
+        block_attention_window_size =
+          case attention_window_size do
+            nil -> nil
+            fun when is_function(fun, 1) -> fun.(idx)
+            size -> size
+          end
+
        {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
          block(
            state.hidden_state,
@@ -134,6 +151,7 @@ defmodule Bumblebee.Layers.Transformer do
            block_cache: block_cache,
            offset: offset,
            rotary_embedding: block_rotary_embedding,
+           attention_window_size: block_attention_window_size,
            name: join(name, idx)
          ] ++ block_opts
        )
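To make the function form of :attention_window_size concrete: a caller such as the Gemma 3 model module could encode the alternating pattern (five local layers followed by one global layer) roughly as below. The window size of 512 and the "every sixth block is global" rule are illustrative assumptions, not values taken from this commit.

attention_window_size = fn block_idx ->
  if rem(block_idx + 1, 6) == 0 do
    # every sixth block attends globally
    nil
  else
    # all other blocks use sliding-window (local) attention
    {512, 512}
  end
end

# The per-block resolution in blocks/2 above then yields, for example:
attention_window_size.(0)
#=> {512, 512}
attention_window_size.(5)
#=> nil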
