lib/bumblebee.ex (6 additions, 0 deletions)
@@ -137,6 +137,11 @@ defmodule Bumblebee do
"GemmaModel" => {Bumblebee.Text.Gemma, :base},
"GemmaForCausalLM" => {Bumblebee.Text.Gemma, :for_causal_language_modeling},
"GemmaForSequenceClassification" => {Bumblebee.Text.Gemma, :for_sequence_classification},
"Gemma3ForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
"Gemma3TextModel" => {Bumblebee.Text.Gemma3Text, :base},
"Gemma3TextForCausalLM" => {Bumblebee.Text.Gemma3Text, :for_causal_language_modeling},
"Gemma3TextForSequenceClassification" =>
{Bumblebee.Text.Gemma3Text, :for_sequence_classification},
"GPT2ForSequenceClassification" => {Bumblebee.Text.Gpt2, :for_sequence_classification},
"GPT2ForTokenClassification" => {Bumblebee.Text.Gpt2, :for_token_classification},
"GPT2LMHeadModel" => {Bumblebee.Text.Gpt2, :for_causal_language_modeling},
@@ -252,6 +257,7 @@ defmodule Bumblebee do
"camembert" => :camembert,
"clip" => :clip,
"gemma" => :gemma,
"gemma3_text" => :gemma,
"gpt_neox" => :gpt_neo_x,
"gpt2" => :gpt2,
"gpt_bigcode" => :gpt2,
lib/bumblebee/layers/transformer.ex (21 additions, 1 deletion)
@@ -25,6 +25,13 @@ defmodule Bumblebee.Layers.Transformer do
     - a keyword list (applied to all blocks)
     - a function that takes the block index and returns the configuration
 
+  * `:attention_window_size` - sliding window attention configuration. Can be:
+    - `nil` for global attention (default)
+    - a `{left, right}` tuple (applied to all blocks)
+    - a function that takes the block index and returns `nil` or `{left, right}`.
+      This enables per-layer attention patterns like Gemma 3's alternating
+      local/global attention (5 local layers followed by 1 global layer)
+
   * `:name` - the prefix for layer names
 
   For all other options (including required options) see `block/2`.
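For illustration, the function form of `:attention_window_size` documented above makes a Gemma 3-style schedule compact. A sketch — the 512-token window is an assumed value for illustration, not taken from this diff; in practice the size would come from the model spec:

```elixir
# Blocks 0..4 use a causal sliding window, block 5 uses global
# attention (nil), and the pattern repeats every 6 blocks.
attention_window_size = fn idx ->
  if rem(idx, 6) == 5, do: nil, else: {512, 0}
end
```

Here `{512, 0}` means each position attends to itself and up to 512 positions to its left, and to nothing ahead, i.e. a causal local window.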
@@ -36,6 +43,8 @@
   def blocks(hidden_state, opts) do
     validate_required_keys!(opts, [:num_blocks, :num_attention_heads, :hidden_size, :ffn])
 
+    # Note: :attention_window_size is NOT in block_opts_keys because it's handled
+    # specially (supports per-layer function) and passed explicitly to block/2
     block_opts_keys = [
       :num_attention_heads,
       :num_key_value_heads,
@@ -52,7 +61,6 @@
       :output_use_bias,
       :layer_norm,
       :block_type,
-      :attention_window_size,
       :attention_scale,
       :query_norm,
       :key_norm
@@ -66,6 +74,7 @@
       :name,
       :num_blocks,
       :rotary_embedding,
+      :attention_window_size,
       attention_mask: Layers.none(),
       attention_head_mask: Layers.none(),
       attention_relative_bias: nil,
@@ -87,6 +96,7 @@
     cross_attention_head_mask = opts[:cross_attention_head_mask]
     cache = opts[:cache]
     rotary_embedding = opts[:rotary_embedding]
+    attention_window_size = opts[:attention_window_size]
 
     block_opts = Keyword.take(opts, block_opts_keys)
 
@@ -123,6 +133,15 @@
           config when is_list(config) -> config
         end
 
+      # Support per-layer attention window size for models like Gemma 3
+      # that alternate between local (sliding window) and global attention
+      block_attention_window_size =
+        case attention_window_size do
+          nil -> nil
+          fun when is_function(fun, 1) -> fun.(idx)
+          size -> size
+        end
+
       {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
         block(
           state.hidden_state,
@@ -136,6 +155,7 @@
           block_cache: block_cache,
           offset: offset,
           rotary_embedding: block_rotary_embedding,
+          attention_window_size: block_attention_window_size,
           name: join(name, idx)
         ] ++ block_opts
       )
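Putting it together, a model implementation can now hand `blocks/2` a function instead of a fixed tuple. A hypothetical call site loosely modeled on the Gemma 3 pattern — the `spec` fields and option values here are assumptions for illustration, not code from this PR:

```elixir
# Hypothetical: alternate 5 sliding-window blocks with 1 global block.
Layers.Transformer.blocks(hidden_state,
  num_blocks: spec.num_blocks,
  num_attention_heads: spec.num_attention_heads,
  hidden_size: spec.hidden_size,
  ffn: [intermediate_size: spec.intermediate_size, activation: spec.activation],
  attention_window_size: fn idx ->
    if rem(idx, 6) == 5, do: nil, else: {spec.attention_window_size, 0}
  end,
  name: join(name, "blocks")
)
```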