@@ -47,6 +47,28 @@ defmodule Bumblebee.Text.Mistral do
         default: 4096,
         doc: "window size for both sides of the sliding attention window"
       ],
+      attention_head_size: [
+        default: nil,
+        doc: """
+        the projection size for key, value, and query states per attention head.
+        When `nil`, defaults to `hidden_size / num_attention_heads`. Ministral
+        models use an explicit `head_dim` (typically 128) that differs from this default
+        """
+      ],
+      use_interleaved_attention: [
+        default: false,
+        doc: """
+        whether to use an interleaved attention pattern. When enabled, even layers
+        use global attention and odd layers use sliding window attention
+        """
+      ],
+      tie_word_embeddings: [
+        default: false,
+        doc: """
+        whether to tie the word embeddings with the language modeling head weights.
+        When `true`, the `lm_head` uses the same weights as the token embedding layer
+        """
+      ],
       activation: [
         default: :silu,
         doc: "the activation function"
@@ -165,7 +187,8 @@ defmodule Bumblebee.Text.Mistral do
     Layers.Decoder.init_cache(batch_size, max_length,
       hidden_size: spec.hidden_size,
       decoder_num_attention_heads: spec.num_attention_heads,
-      decoder_num_blocks: spec.num_blocks
+      decoder_num_blocks: spec.num_blocks,
+      attention_head_size: spec.attention_head_size
     )
   end
 
@@ -315,6 +338,32 @@ defmodule Bumblebee.Text.Mistral do
        ) do
     name = opts[:name]
 
+    # Build the attention_window_size configuration.
+    # When interleaved attention is enabled, even layers use global attention
+    # and odd layers use sliding window attention.
+    attention_window_size =
+      cond do
+        # If no sliding window is configured, use global attention for all layers
+        spec.attention_window_size == nil ->
+          nil
+
+        # Interleaved attention: even layers use global, odd layers use the sliding window
+        spec.use_interleaved_attention ->
+          fn layer_idx ->
+            if rem(layer_idx, 2) == 0 do
+              # Even layers: global attention (no window)
+              nil
+            else
+              # Odd layers: sliding window attention
+              {spec.attention_window_size, spec.attention_window_size}
+            end
+          end
+
+        # Non-interleaved: apply the sliding window to all layers
+        true ->
+          {spec.attention_window_size, spec.attention_window_size}
+      end
+
     Layers.Transformer.blocks(hidden_state,
       attention_mask: attention_mask,
       attention_head_mask: attention_head_mask,
@@ -323,6 +372,7 @@ defmodule Bumblebee.Text.Mistral do
       num_attention_heads: spec.num_attention_heads,
       num_key_value_heads: spec.num_key_value_heads,
       hidden_size: spec.hidden_size,
+      attention_head_size: spec.attention_head_size,
       kernel_initializer: kernel_initializer(spec),
       layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon),
       ffn:
@@ -332,8 +382,7 @@ defmodule Bumblebee.Text.Mistral do
         ),
       block_type: :norm_first,
       causal: true,
-      attention_window_size:
-        spec.attention_window_size && {spec.attention_window_size, spec.attention_window_size},
+      attention_window_size: attention_window_size,
       rotary_embedding: [
         position_ids: position_ids,
         max_positions: spec.max_positions,
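For illustration (not part of the diff), here is what the interleaved branch of the `cond` above produces per layer index. The window size `4096` is an assumed value; the diff passes either `nil`, a `{left, right}` tuple, or this per-layer function as `:attention_window_size`.

```elixir
# Assumed spec values: attention_window_size: 4096, use_interleaved_attention: true.
window_fun = fn layer_idx ->
  if rem(layer_idx, 2) == 0, do: nil, else: {4096, 4096}
end

Enum.map(0..3, window_fun)
#=> [nil, {4096, 4096}, nil, {4096, 4096}]
# Even layers attend globally, odd layers use a symmetric 4096-token window.
```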
@@ -367,7 +416,6 @@ defmodule Bumblebee.Text.Mistral do
   defp language_modeling_head(hidden_state, spec, opts) do
     name = opts[:name]
 
-    # TODO: Tie lm-head to word embedding as a spec option
     Layers.dense_transposed(hidden_state, spec.vocab_size,
       kernel_initializer: kernel_initializer(spec),
       name: join(name, "output")
@@ -391,19 +439,22 @@ defmodule Bumblebee.Text.Mistral do
           num_attention_heads: {"num_attention_heads", number()},
           num_key_value_heads: {"num_key_value_heads", number()},
           attention_window_size: {"sliding_window", optional(number())},
+          attention_head_size: {"head_dim", optional(number())},
+          use_interleaved_attention: {"use_interleaved_attention", optional(boolean())},
           intermediate_size: {"intermediate_size", number()},
           activation: {"hidden_act", activation()},
           rotary_embedding_base: {"rope_theta", number()},
           initializer_scale: {"initializer_range", number()},
-          layer_norm_epsilon: {"rms_norm_eps", number()}
+          layer_norm_epsilon: {"rms_norm_eps", number()},
+          tie_word_embeddings: {"tie_word_embeddings", boolean()}
         ) ++ Shared.common_options_from_transformers(data, spec)
 
       @for.config(spec, opts)
     end
   end
 
   defimpl Bumblebee.HuggingFace.Transformers.Model do
-    def params_mapping(_spec) do
+    def params_mapping(spec) do
       %{
         "embedder.token_embedding" => "model.embed_tokens",
         "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj",
@@ -416,7 +467,8 @@ defmodule Bumblebee.Text.Mistral do
416467 "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj" ,
417468 "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm" ,
418469 "output_norm" => "model.norm" ,
419- "language_modeling_head.output" => "lm_head" ,
470+ "language_modeling_head.output" =>
471+ if ( spec . tie_word_embeddings , do: "model.embed_tokens" , else: "lm_head" ) ,
420472 "sequence_classification_head.output" => "score"
421473 }
422474 end
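A hedged usage sketch (not from the diff; the repository id and option values are placeholders) showing how the new spec options could be set when loading a model with these changes:

```elixir
# Hypothetical example: load a Mistral-family spec and override the new options.
{:ok, spec} =
  Bumblebee.load_spec({:hf, "mistralai/Mistral-7B-v0.1"},
    architecture: :for_causal_language_modeling
  )

spec =
  Bumblebee.configure(spec,
    attention_head_size: 128,
    use_interleaved_attention: false,
    tie_word_embeddings: true
  )

# With tie_word_embeddings: true, params_mapping/1 above points
# "language_modeling_head.output" at "model.embed_tokens" instead of "lm_head".
```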