@@ -25,6 +25,13 @@ defmodule Bumblebee.Layers.Transformer do
       - a keyword list (applied to all blocks)
       - a function that takes the block index and returns the configuration
 
+    * `:attention_window_size` - sliding window attention configuration. Can be:
+      - `nil` for global attention (default)
+      - a `{left, right}` tuple (applied to all blocks)
+      - a function that takes the block index and returns `nil` or `{left, right}`.
+        This enables per-layer attention patterns like Gemma 3's alternating
+        local/global attention (5 local layers followed by 1 global layer)
+
     * `:name` - the prefix for layer names
 
   For all other options (including required options) see `block/2`.
@@ -36,6 +43,8 @@ defmodule Bumblebee.Layers.Transformer do
   def blocks(hidden_state, opts) do
     validate_required_keys!(opts, [:num_blocks, :num_attention_heads, :hidden_size, :ffn])
 
+    # Note: :attention_window_size is not in block_opts_keys, because it is handled
+    # specially (it supports a per-layer function) and is passed to block/2 explicitly
     block_opts_keys = [
       :num_attention_heads,
       :num_key_value_heads,
@@ -52,7 +61,6 @@ defmodule Bumblebee.Layers.Transformer do
       :output_use_bias,
       :layer_norm,
       :block_type,
-      :attention_window_size,
       :attention_scale,
       :query_norm,
       :key_norm
@@ -66,6 +74,7 @@ defmodule Bumblebee.Layers.Transformer do
             :name,
             :num_blocks,
             :rotary_embedding,
+            :attention_window_size,
             attention_mask: Layers.none(),
             attention_head_mask: Layers.none(),
             attention_relative_bias: nil,
@@ -87,6 +96,7 @@ defmodule Bumblebee.Layers.Transformer do
     cross_attention_head_mask = opts[:cross_attention_head_mask]
     cache = opts[:cache]
     rotary_embedding = opts[:rotary_embedding]
+    attention_window_size = opts[:attention_window_size]
 
     block_opts = Keyword.take(opts, block_opts_keys)
 
@@ -123,6 +133,15 @@ defmodule Bumblebee.Layers.Transformer do
               config when is_list(config) -> config
             end
 
+          # Support per-layer attention window size for models like Gemma 3
+          # that alternate between local (sliding window) and global attention
+          block_attention_window_size =
+            case attention_window_size do
+              nil -> nil
+              fun when is_function(fun, 1) -> fun.(idx)
+              size -> size
+            end
+
           {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
             block(
               state.hidden_state,
@@ -136,6 +155,7 @@ defmodule Bumblebee.Layers.Transformer do
                 block_cache: block_cache,
                 offset: offset,
                 rotary_embedding: block_rotary_embedding,
+                attention_window_size: block_attention_window_size,
                 name: join(name, idx)
               ] ++ block_opts
             )
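
For context beyond the diff itself: below is a minimal sketch of how a model implementation might pass the new `:attention_window_size` option to `blocks/2`, assuming a Gemma 3-like cycle of 5 local blocks followed by 1 global block. The window size, block count, FFN configuration, and input shape are illustrative placeholders, not values taken from this change.

    # Sketch only: per-layer sliding window attention in roughly the Gemma 3 style.
    # Blocks at indices 5, 11, 17, ... attend globally; all others use a local window.
    window = 512

    attention_window_size = fn block_idx ->
      if rem(block_idx, 6) == 5 do
        # every 6th block: global attention
        nil
      else
        # local attention over `window` positions on each side
        {window, window}
      end
    end

    # Hypothetical input node and options, for illustration only
    hidden_state = Axon.input("hidden_state", shape: {nil, nil, 512})

    Bumblebee.Layers.Transformer.blocks(hidden_state,
      num_blocks: 12,
      num_attention_heads: 8,
      hidden_size: 512,
      ffn: [intermediate_size: 2048, activation: :gelu],
      attention_window_size: attention_window_size,
      name: "decoder.blocks"
    )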