
Commit bae534a

Add Ministral support to Mistral model
This extends the existing Mistral model to support Ministral variants by adding:

- `attention_head_size`: explicit head_dim support (Ministral uses 128)
- `use_interleaved_attention`: even layers use global attention, odd layers use sliding window attention
- `tie_word_embeddings`: share weights between embedding and lm_head

Also adds function-based `attention_window_size` support in transformer.ex to enable per-layer attention window configuration.

All changes are backward compatible with existing Mistral models.
1 parent 8365426 commit bae534a
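
The new options are ordinary spec options, so they can be toggled on an existing checkpoint by loading a spec and reconfiguring it, the same way the new test in this commit does. A minimal sketch, assuming the tiny test checkpoint used below; the option values are illustrative and `tie_word_embeddings: true` is shown only to cover the full set:

    {:ok, spec} = Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"})

    # Illustrative values; a real Ministral checkpoint would carry its own
    # head_dim, sliding_window, and tie_word_embeddings in config.json.
    spec =
      Bumblebee.configure(spec,
        attention_window_size: 2,
        use_interleaved_attention: true,
        tie_word_embeddings: true
      )

    {:ok, %{model: model, params: params}} =
      Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"}, spec: spec)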

3 files changed: +109 -8 lines changed
lib/bumblebee/layers/transformer.ex

Lines changed: 16 additions & 1 deletion
@@ -25,6 +25,12 @@ defmodule Bumblebee.Layers.Transformer do
       - a keyword list (applied to all blocks)
       - a function that takes the block index and returns the configuration
 
+  * `:attention_window_size` - window size for sliding attention. Can be:
+      - a tuple `{left_size, right_size}` (applied to all blocks)
+      - a function that takes the block index and returns the configuration
+        (useful for interleaved attention patterns)
+      - `nil` for global attention
+
   * `:name` - the prefix for layer names
 
   For all other options (including required options) see `block/2`.
@@ -52,7 +58,6 @@ defmodule Bumblebee.Layers.Transformer do
       :output_use_bias,
       :layer_norm,
       :block_type,
-      :attention_window_size,
       :scale_attention_weights
     ]
 
@@ -64,6 +69,7 @@ defmodule Bumblebee.Layers.Transformer do
       :name,
      :num_blocks,
      :rotary_embedding,
+     :attention_window_size,
      attention_mask: Layers.none(),
      attention_head_mask: Layers.none(),
      attention_relative_bias: nil,
@@ -85,6 +91,7 @@ defmodule Bumblebee.Layers.Transformer do
     cross_attention_head_mask = opts[:cross_attention_head_mask]
     cache = opts[:cache]
     rotary_embedding = opts[:rotary_embedding]
+    attention_window_size = opts[:attention_window_size]
 
     block_opts = Keyword.take(opts, block_opts_keys)
 
@@ -121,6 +128,13 @@ defmodule Bumblebee.Layers.Transformer do
             config when is_list(config) -> config
           end
 
+        block_attention_window_size =
+          case attention_window_size do
+            nil -> nil
+            fun when is_function(fun, 1) -> fun.(idx)
+            config -> config
+          end
+
         {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
           block(
             state.hidden_state,
@@ -134,6 +148,7 @@ defmodule Bumblebee.Layers.Transformer do
             block_cache: block_cache,
             offset: offset,
             rotary_embedding: block_rotary_embedding,
+            attention_window_size: block_attention_window_size,
             name: join(name, idx)
           ] ++ block_opts
         )
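
As documented above, `:attention_window_size` now also accepts a one-arity function of the block index, which is what makes interleaved patterns possible. A small sketch of how such a function resolves per block, assuming an example 4096-token window:

    # A per-block window function, as accepted by the new option. Even block
    # indices get global attention (nil), odd ones a {left, right} sliding window.
    window_fun = fn idx ->
      if rem(idx, 2) == 0, do: nil, else: {4096, 4096}
    end

    # Inside blocks/2 this is resolved once per block index, e.g.:
    Enum.map(0..3, window_fun)
    #=> [nil, {4096, 4096}, nil, {4096, 4096}]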

lib/bumblebee/text/mistral.ex

Lines changed: 59 additions & 7 deletions
@@ -47,6 +47,28 @@ defmodule Bumblebee.Text.Mistral do
         default: 4096,
         doc: "window size for both sides of the sliding attention window"
       ],
+      attention_head_size: [
+        default: nil,
+        doc: """
+        the projection size for key, value, and query states per attention head.
+        When `nil`, defaults to `hidden_size / num_attention_heads`. Ministral
+        models use an explicit head_dim (typically 128) that differs from this default
+        """
+      ],
+      use_interleaved_attention: [
+        default: false,
+        doc: """
+        whether to use interleaved attention pattern. When enabled, even layers
+        use global attention and odd layers use sliding window attention
+        """
+      ],
+      tie_word_embeddings: [
+        default: false,
+        doc: """
+        whether to tie the word embeddings with the language modeling head weights.
+        When true, the lm_head uses the same weights as the token embedding layer
+        """
+      ],
       activation: [
         default: :silu,
         doc: "the activation function"
@@ -165,7 +187,8 @@ defmodule Bumblebee.Text.Mistral do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_head_size: spec.attention_head_size
     )
   end
 
@@ -315,6 +338,32 @@ defmodule Bumblebee.Text.Mistral do
       ) do
    name = opts[:name]
 
+    # Build attention_window_size configuration
+    # When interleaved attention is enabled, even layers use global attention
+    # and odd layers use sliding window attention
+    attention_window_size =
+      cond do
+        # If no sliding window is configured, use global attention for all layers
+        spec.attention_window_size == nil ->
+          nil
+
+        # Interleaved attention: even layers use global, odd layers use sliding window
+        spec.use_interleaved_attention ->
+          fn layer_idx ->
+            if rem(layer_idx, 2) == 0 do
+              # Even layers: global attention (no window)
+              nil
+            else
+              # Odd layers: sliding window attention
+              {spec.attention_window_size, spec.attention_window_size}
+            end
+          end
+
+        # Non-interleaved: apply sliding window to all layers
+        true ->
+          {spec.attention_window_size, spec.attention_window_size}
+      end
+
     Layers.Transformer.blocks(hidden_state,
       attention_mask: attention_mask,
       attention_head_mask: attention_head_mask,
@@ -323,6 +372,7 @@ defmodule Bumblebee.Text.Mistral do
       num_attention_heads: spec.num_attention_heads,
       num_key_value_heads: spec.num_key_value_heads,
       hidden_size: spec.hidden_size,
+      attention_head_size: spec.attention_head_size,
       kernel_initializer: kernel_initializer(spec),
       layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon),
       ffn:
@@ -332,8 +382,7 @@ defmodule Bumblebee.Text.Mistral do
         ),
       block_type: :norm_first,
       causal: true,
-      attention_window_size:
-        spec.attention_window_size && {spec.attention_window_size, spec.attention_window_size},
+      attention_window_size: attention_window_size,
       rotary_embedding: [
         position_ids: position_ids,
         max_positions: spec.max_positions,
@@ -367,7 +416,6 @@ defmodule Bumblebee.Text.Mistral do
   defp language_modeling_head(hidden_state, spec, opts) do
     name = opts[:name]
 
-    # TODO: Tie lm-head to word embedding as a spec option
     Layers.dense_transposed(hidden_state, spec.vocab_size,
       kernel_initializer: kernel_initializer(spec),
       name: join(name, "output")
@@ -391,19 +439,22 @@ defmodule Bumblebee.Text.Mistral do
        num_attention_heads: {"num_attention_heads", number()},
        num_key_value_heads: {"num_key_value_heads", number()},
        attention_window_size: {"sliding_window", optional(number())},
+       attention_head_size: {"head_dim", optional(number())},
+       use_interleaved_attention: {"use_interleaved_attention", optional(boolean())},
        intermediate_size: {"intermediate_size", number()},
        activation: {"hidden_act", activation()},
        rotary_embedding_base: {"rope_theta", number()},
        initializer_scale: {"initializer_range", number()},
-       layer_norm_epsilon: {"rms_norm_eps", number()}
+       layer_norm_epsilon: {"rms_norm_eps", number()},
+       tie_word_embeddings: {"tie_word_embeddings", boolean()}
      ) ++ Shared.common_options_from_transformers(data, spec)
 
     @for.config(spec, opts)
   end
 end
 
 defimpl Bumblebee.HuggingFace.Transformers.Model do
-  def params_mapping(_spec) do
+  def params_mapping(spec) do
     %{
       "embedder.token_embedding" => "model.embed_tokens",
       "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj",
@@ -416,7 +467,8 @@ defmodule Bumblebee.Text.Mistral do
       "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj",
       "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm",
       "output_norm" => "model.norm",
-      "language_modeling_head.output" => "lm_head",
+      "language_modeling_head.output" =>
+        if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"),
       "sequence_classification_head.output" => "score"
     }
   end
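
With `params_mapping/1` now receiving the spec, the source tensor for the LM head is chosen at parameter-loading time. A tiny illustration of how the conditional above resolves; the maps below are stand-ins for a real spec struct:

    resolve_lm_head = fn spec ->
      if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head")
    end

    resolve_lm_head.(%{tie_word_embeddings: true})
    #=> "model.embed_tokens"
    resolve_lm_head.(%{tie_word_embeddings: false})
    #=> "lm_head"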

test/bumblebee/text/mistral_test.exs

Lines changed: 34 additions & 0 deletions
@@ -58,6 +58,40 @@ defmodule Bumblebee.Text.MistralTest do
            )
   end
 
+  test ":base with interleaved attention" do
+    assert {:ok, spec} =
+             Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"})
+
+    # Enable interleaved attention: even layers use global, odd layers use sliding window
+    spec = Bumblebee.configure(spec, attention_window_size: 2, use_interleaved_attention: true)
+
+    assert {:ok, %{model: model, params: params, spec: spec}} =
+             Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"},
+               spec: spec
+             )
+
+    assert %Bumblebee.Text.Mistral{architecture: :base, use_interleaved_attention: true} = spec
+
+    inputs = %{
+      "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]),
+      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+    }
+
+    outputs = Axon.predict(model, params, inputs)
+
+    assert Nx.shape(outputs.hidden_state) == {1, 10, 32}
+
+    # With interleaved attention, even layers (0, 2, 4...) use global attention
+    # and odd layers (1, 3, 5...) use sliding window attention
+    # The output should be different from both pure global and pure sliding window
+    assert_all_close(
+      outputs.hidden_state[[.., 1..3, 1..3]],
+      Nx.tensor([
+        [[0.9450, -1.3945, 0.7331], [-2.1118, -1.3091, -0.7834], [-1.4057, -1.2495, 0.8730]]
+      ])
+    )
+  end
+
   test ":for_sequence_classification" do
     assert {:ok, %{model: model, params: params, spec: spec}} =
              Bumblebee.load_model(
