
Commit 615cfee

nyo16 and jonatanklosko authored
Add Gemma 3 Text support (#436)
Co-authored-by: Jonatan Kłosko <[email protected]>
1 parent 2252d44 commit 615cfee

File tree

6 files changed: 1276 additions & 1 deletion


lib/bumblebee.ex

Lines changed: 6 additions & 0 deletions
@@ -137,6 +137,11 @@ defmodule Bumblebee do
     "GemmaModel" => {Bumblebee.Text.Gemma, :base},
     "GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling},
     "GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification},
+    "Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
+    "Gemma3TextModel" => {Bumblebee.Text.Gemma3Text, :base},
+    "Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
+    "Gemma3TextForSequenceClassification" =>
+      {Bumblebee.Text.Gemma3Text, :for_sequence_classification},
     "GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
     "GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
     "GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
@@ -258,6 +263,7 @@ defmodule Bumblebee do
     "camembert" => :camembert,
     "clip" => :clip,
     "gemma" => :gemma,
+    "gemma3_text" => :gemma,
     "gpt_neox" => :gpt_neo_x,
     "gpt2" => :gpt2,
     "gpt_bigcode" => :gpt2,

lib/bumblebee/layers/transformer.ex

Lines changed: 21 additions & 1 deletion
@@ -25,6 +25,13 @@ defmodule Bumblebee.Layers.Transformer do
         - a keyword list (applied to all blocks)
         - a function that takes the block index and returns the configuration
 
+    * `:attention_window_size` - sliding window attention configuration. Can be:
+        - `nil` for global attention (default)
+        - a `{left, right}` tuple (applied to all blocks)
+        - a function that takes the block index and returns `nil` or `{left, right}`.
+          This enables per-layer attention patterns like Gemma 3's alternating
+          local/global attention (5 local layers followed by 1 global layer)
+
     * `:name` - the prefix for layer names
 
   For all other options (including required options) see `block/2`.
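
To make the new option concrete before the implementation hunks below, here is a hedged sketch of a per-layer function matching the documented pattern (5 local layers followed by 1 global layer). The 512-token window and the causal `{512, 0}` split are illustrative assumptions, not values taken from this commit.

```elixir
# Sketch of a per-layer :attention_window_size value, assuming a 512-token
# causal window for the local blocks (both numbers are illustrative).
attention_window_size = fn idx ->
  if rem(idx, 6) == 5 do
    # every sixth block falls back to global attention
    nil
  else
    # other blocks use sliding-window attention: 512 tokens to the left, 0 to the right
    {512, 0}
  end
end
```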
@@ -36,6 +43,8 @@
   def blocks(hidden_state, opts) do
     validate_required_keys!(opts, [:num_blocks, :num_attention_heads, :hidden_size, :ffn])
 
+    # Note: :attention_window_size is NOT in block_opts_keys because it's handled
+    # specially (supports per-layer function) and passed explicitly to block/2
     block_opts_keys = [
       :num_attention_heads,
       :num_key_value_heads,
@@ -52,7 +61,6 @@
       :output_use_bias,
       :layer_norm,
       :block_type,
-      :attention_window_size,
       :attention_scale,
       :query_norm,
       :key_norm
@@ -66,6 +74,7 @@
         :name,
         :num_blocks,
         :rotary_embedding,
+        :attention_window_size,
         attention_mask: Layers.none(),
         attention_head_mask: Layers.none(),
         attention_relative_bias: nil,
@@ -87,6 +96,7 @@
     cross_attention_head_mask = opts[:cross_attention_head_mask]
     cache = opts[:cache]
     rotary_embedding = opts[:rotary_embedding]
+    attention_window_size = opts[:attention_window_size]
 
     block_opts = Keyword.take(opts, block_opts_keys)
 
@@ -123,6 +133,15 @@
             config when is_list(config) -> config
           end
 
+        # Support per-layer attention window size for models like Gemma 3
+        # that alternate between local (sliding window) and global attention
+        block_attention_window_size =
+          case attention_window_size do
+            nil -> nil
+            fun when is_function(fun, 1) -> fun.(idx)
+            size -> size
+          end
+
         {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
           block(
             state.hidden_state,
@@ -136,6 +155,7 @@
             block_cache: block_cache,
             offset: offset,
             rotary_embedding: block_rotary_embedding,
+            attention_window_size: block_attention_window_size,
             name: join(name, idx)
           ] ++ block_opts
         )
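
Putting it together, a decoder built on `blocks/2` can pass such a function directly and let the per-block resolution above pick the window for each layer. This is a minimal sketch, assuming illustrative layer counts, sizes, and a keyword `:ffn` configuration that are not taken from this commit; `Bumblebee.Layers.Transformer` is an internal module normally driven by the model implementations rather than called directly.

```elixir
# Minimal sketch, assuming illustrative dimensions and a keyword :ffn config;
# only :attention_window_size and the required keys validated in blocks/2
# come from this commit.
hidden_state = Axon.input("hidden_state", shape: {nil, nil, 1152})

Bumblebee.Layers.Transformer.blocks(hidden_state,
  num_blocks: 26,
  num_attention_heads: 8,
  hidden_size: 1152,
  ffn: [intermediate_size: 6912, activation: :gelu],
  attention_window_size: fn idx ->
    # 5 local (sliding window) blocks followed by 1 global block
    if rem(idx, 6) == 5, do: nil, else: {512, 0}
  end,
  name: "decoder.blocks"
)
```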
