11use std:: collections:: HashMap ;
22
33use crate :: flash_attn:: flash_attn_varlen;
4- use crate :: layers:: { apply_rotary , get_cos_sin, get_inv_freqs, LayerNorm , Linear } ;
4+ use crate :: layers:: { get_cos_sin, get_inv_freqs, LayerNormNoBias , Linear } ;
55use crate :: models:: modernbert:: {
66 ClassificationHead , ModernBertClassificationHead , ModernBertConfig , ModernBertEmbeddings ,
77 ModernBertMLP ,
88} ;
99use crate :: models:: Model ;
1010use candle:: { DType , Device , IndexOp , Result , Tensor } ;
1111use candle_nn:: VarBuilder ;
12+ use candle_rotary:: apply_rotary_inplace;
1213use text_embeddings_backend_core:: { Batch , ModelType , Pool } ;
1314
1415struct ModernBertAttention {
@@ -79,35 +80,34 @@ impl ModernBertAttention {
7980 new_qkv_shape. pop ( ) ;
8081 new_qkv_shape. push ( self . num_attention_heads * 3 ) ;
8182 new_qkv_shape. push ( self . attention_head_size ) ;
82- let qkv = qkv. reshape ( new_qkv_shape. as_slice ( ) ) ?. transpose ( 1 , 2 ) ? ;
83+ let qkv = qkv. reshape ( new_qkv_shape. as_slice ( ) ) ?;
8384
84- let qkv = qkv. chunk ( 3 , 1 ) ? ;
85- let query_layer = & qkv[ 0 ] . contiguous ( ) ?;
86- let key_layer = & qkv[ 1 ] . contiguous ( ) ?;
87- let value_layer = & qkv[ 2 ] ;
85+ // Split qkv tensor
86+ let q = qkv. narrow ( 1 , 0 , self . num_attention_heads ) ?;
87+ let k = qkv. narrow ( 1 , self . num_attention_heads , self . num_attention_heads ) ?;
88+ let v = qkv. narrow ( 1 , self . num_attention_heads * 2 , self . num_attention_heads ) ? ;
8889
89- let query_layer = apply_rotary ( query_layer, cos, sin, self . attention_head_size ) ?;
90- let key_layer = apply_rotary ( key_layer, cos, sin, self . attention_head_size ) ?;
90+ apply_rotary_inplace ( & q, & k, & cos, & sin, true ) ?;
9191
92- let attention_size = if self . use_local_attention {
92+ let window_size = if self . use_local_attention {
9393 Some ( self . local_attention )
9494 } else {
9595 None
9696 } ;
9797
9898 let attention = flash_attn_varlen (
99- & query_layer ,
100- & key_layer ,
101- & value_layer ,
99+ & q ,
100+ & k ,
101+ & v ,
102102 None ,
103103 cu_seqlens,
104104 cu_seqlens,
105105 max_s,
106106 max_s,
107107 self . softmax_scale ,
108108 false ,
109- attention_size ,
110- attention_size ,
109+ window_size ,
110+ window_size ,
111111 ) ?;
112112 let attention = attention. flatten_from ( candle:: D :: Minus2 ) ?;
113113
@@ -118,9 +118,9 @@ impl ModernBertAttention {
118118}
119119
120120struct ModernBertEncoderLayer {
121- attn_norm : Option < LayerNorm > ,
121+ attn_norm : Option < LayerNormNoBias > ,
122122 attn : ModernBertAttention ,
123- mlp_norm : LayerNorm ,
123+ mlp_norm : LayerNormNoBias ,
124124 mlp : ModernBertMLP ,
125125
126126 span : tracing:: Span ,
@@ -129,7 +129,7 @@ struct ModernBertEncoderLayer {
129129impl ModernBertEncoderLayer {
130130 pub fn load ( vb : VarBuilder , index : usize , config : & ModernBertConfig ) -> Result < Self > {
131131 let attn_norm = if index != 0 {
132- Some ( LayerNorm :: load (
132+ Some ( LayerNormNoBias :: load (
133133 vb. pp ( "attn_norm" ) ,
134134 config. hidden_size ,
135135 config. norm_eps as f32 ,
@@ -140,7 +140,7 @@ impl ModernBertEncoderLayer {
140140
141141 let attn = ModernBertAttention :: load ( vb. pp ( "attn" ) , index, config) ?;
142142
143- let mlp_norm = LayerNorm :: load (
143+ let mlp_norm = LayerNormNoBias :: load (
144144 vb. pp ( "mlp_norm" ) ,
145145 config. hidden_size ,
146146 config. norm_eps as f32 ,
@@ -236,11 +236,10 @@ impl ModernBertEncoder {
236236pub struct FlashModernBertModel {
237237 embeddings : ModernBertEmbeddings ,
238238 encoder : ModernBertEncoder ,
239- final_norm : LayerNorm ,
239+ final_norm : LayerNormNoBias ,
240240 pool : Pool ,
241241 classifier : Option < Box < dyn ClassificationHead + Send > > ,
242242
243- rotary_dim : usize ,
244243 rotary_cache : HashMap < bool , ( Tensor , Tensor ) > ,
245244
246245 device : Device ,
@@ -277,13 +276,22 @@ impl FlashModernBertModel {
277276 }
278277 } ;
279278
280- let embeddings = ModernBertEmbeddings :: load ( vb. pp ( "model.embeddings" ) , config) ?;
281- let encoder = ModernBertEncoder :: load ( vb. pp ( "model.layers" ) , config) ?;
282- let final_norm = LayerNorm :: load (
279+ let embeddings = ModernBertEmbeddings :: load ( vb. pp ( "model.embeddings" ) , config)
280+ . or_else ( |_| ModernBertEmbeddings :: load ( vb. pp ( "embeddings" ) , config) ) ?;
281+ let encoder = ModernBertEncoder :: load ( vb. pp ( "model.layers" ) , config)
282+ . or_else ( |_| ModernBertEncoder :: load ( vb. pp ( "layers" ) , config) ) ?;
283+ let final_norm = LayerNormNoBias :: load (
283284 vb. pp ( "model.final_norm" ) ,
284285 config. hidden_size ,
285286 config. norm_eps as f32 ,
286- ) ?;
287+ )
288+ . or_else ( |_| {
289+ LayerNormNoBias :: load (
290+ vb. pp ( "final_norm" ) ,
291+ config. hidden_size ,
292+ config. norm_eps as f32 ,
293+ )
294+ } ) ?;
287295
288296 let rotary_dim = config. hidden_size / config. num_attention_heads ;
289297 let mut rotary_cache: HashMap < bool , ( Tensor , Tensor ) > = HashMap :: new ( ) ;
@@ -295,15 +303,11 @@ impl FlashModernBertModel {
295303 config. global_rope_theta
296304 } ;
297305
298- let max_position_embeddings = if use_local_attention {
299- config. max_position_embeddings
300- } else {
301- config. local_attention
302- } ;
306+ let max_position_embeddings = config. max_position_embeddings ;
303307
304308 let inv_freqs = get_inv_freqs ( rotary_dim, rope_theta as f32 , vb. device ( ) , None ) ?;
305309
306- let ( cos, sin) = get_cos_sin ( max_position_embeddings, & inv_freqs, vb. dtype ( ) , true ) ?;
310+ let ( cos, sin) = get_cos_sin ( max_position_embeddings, & inv_freqs, vb. dtype ( ) , false ) ?;
307311
308312 rotary_cache. insert ( use_local_attention, ( cos, sin) ) ;
309313 }
@@ -314,7 +318,6 @@ impl FlashModernBertModel {
314318 final_norm,
315319 pool,
316320 classifier,
317- rotary_dim,
318321 rotary_cache,
319322 device : vb. device ( ) . clone ( ) ,
320323 span : tracing:: span!( tracing:: Level :: TRACE , "model" ) ,
@@ -343,9 +346,6 @@ impl FlashModernBertModel {
343346 let cos = cos. index_select ( & position_ids, 0 ) ?;
344347 let sin = sin. index_select ( & position_ids, 0 ) ?;
345348
346- let cos = cos. reshape ( ( batch_size, 1 , max_length, self . rotary_dim ) ) ?;
347- let sin = sin. reshape ( ( batch_size, 1 , max_length, self . rotary_dim ) ) ?;
348-
349349 rotary_cache. insert ( use_local_attention, ( cos, sin) ) ;
350350 }
351351
0 commit comments