@@ -241,16 +241,12 @@ def forward(
         dt_states = self.dt_proj(
             value_states.transpose(1, 2).reshape(value_states.shape[0], value_states.shape[-2], -1)
         )
-        dt_states = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-        attn_bias = dt_states[:, :, None, :].expand(
-            -1, -1, hidden_states.shape[1], -1
-        ).to(hidden_states.dtype)  # [batch_size, num_heads, query_len, key_len]
+        attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)

         attention_interface: Callable = eager_attention_forward
         if flash_dynamic_mask_attention_forward is not None:
             attention_interface = flash_dynamic_mask_attention_forward

-        attention_mask = attention_mask.expand(-1, attn_bias.shape[1], -1, -1) if attention_mask is not None else None  # attention_mask: batch, num_kv_heads, query_len, key_len
         attn_output, attn_weights = attention_interface(
             self,
             query_states,
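For context, a minimal sketch (not the repository code) of why the explicit `expand` can be dropped: a `[batch, num_heads, key_len]` bias broadcasts against `[batch, num_heads, query_len, key_len]` attention scores once a query dimension is unsqueezed, so materializing the expanded tensor is unnecessary. The shapes below are placeholder values.

```python
# Sketch only: shows that the un-expanded bias produces the same result
# as the previously materialized [batch, num_heads, query_len, key_len] tensor.
import torch

batch, num_heads, query_len, key_len = 2, 4, 8, 8
scores = torch.randn(batch, num_heads, query_len, key_len)
attn_bias = torch.randn(batch, num_heads, key_len)  # new, un-expanded shape

# Broadcast over the query dimension instead of materializing it.
biased = scores + attn_bias[:, :, None, :]

# Equivalent to the old explicit expand.
expanded = attn_bias[:, :, None, :].expand(-1, -1, query_len, -1)
assert torch.equal(biased, scores + expanded)
```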
@@ -414,7 +410,7 @@ def _init_weights(self, module):
         super()._init_weights(module)
         if isinstance(module, DogeAttention):
             if hasattr(module, "A"):
-                module.A.data.zero_()
+                module.A.data.normal_(mean=0.0, std=self.config.initializer_range)
         elif isinstance(module, DogeCDMoE):
             if hasattr(module, "router_gate"):
                 module.router_gate.weight.data.zero_()
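A hedged sketch of what the initializer change does to the bias term `exp(A * softplus(dt_states))`: with `A` zero-initialized the term is exactly 1 for every head at the start of training, while a small normal init gives each head a slightly different starting scale. The `0.02` below stands in for `config.initializer_range` and is only an assumed value.

```python
# Sketch only: compares the initial bias term under zero vs. normal init of A.
import torch
import torch.nn.functional as F

num_heads, key_len = 4, 8
dt = torch.randn(num_heads, key_len)

A_zero = torch.zeros(num_heads, 1)
A_normal = torch.empty(num_heads, 1).normal_(mean=0.0, std=0.02)  # 0.02 assumed for initializer_range

print(torch.exp(A_zero * F.softplus(dt)))    # all ones: every head starts with a flat bias
print(torch.exp(A_normal * F.softplus(dt)))  # small per-head variation around 1
```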