@@ -303,8 +303,10 @@ def construct(
         is_cross_attention = key_value_states is not None

         query_states = self.q(hidden_states)
-        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)

+        is_updated = False
+        curr_past_key_value = None
         if past_key_value is not None:
             is_updated = past_key_value.is_updated.get(self.layer_idx)
             if is_cross_attention:
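The transpose-to-swapaxes change running through this diff reflects a semantic difference between the frameworks: MindSpore's Tensor.transpose follows NumPy and expects the full axis permutation, so the PyTorch two-axis idiom transpose(1, 2) is written swapaxes(1, 2) instead. A minimal sketch of the equivalence, with illustrative shapes only:

    import mindspore as ms
    from mindspore import ops

    # (batch, seq, heads, head_dim); values are placeholders
    x = ops.ones((2, 5, 8, 16), ms.float32)

    y = x.swapaxes(1, 2)         # two-axis swap, like PyTorch's transpose(1, 2)
    z = x.transpose(0, 2, 1, 3)  # MindSpore/NumPy transpose takes the full permutation
    assert y.shape == z.shape == (2, 8, 5, 16)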
@@ -321,8 +323,8 @@ def construct(
         else:
             key_states = self.k(current_states)
             value_states = self.v(current_states)
-            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)
+            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)

         if past_key_value is not None:
             # save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -335,7 +337,7 @@ def construct(
                 past_key_value.is_updated[self.layer_idx] = True

         # compute scores, equivalent of mint.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
-        scores = mint.matmul(query_states, key_states.transpose(3, 2))
+        scores = mint.matmul(query_states, key_states.swapaxes(3, 2))

         if position_bias is None:
             key_length = key_states.shape[-2]
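As the comment above notes, this swapped-axes matmul is the bnqk score computation; a quick shape check with made-up dimensions:

    import mindspore as ms
    from mindspore import mint, ops

    b, n, q_len, k_len, d = 2, 8, 5, 7, 16
    q = ops.ones((b, n, q_len, d), ms.float32)
    k = ops.ones((b, n, k_len, d), ms.float32)

    # (b, n, q_len, d) @ (b, n, d, k_len) -> (b, n, q_len, k_len)
    scores = mint.matmul(q, k.swapaxes(3, 2))
    assert scores.shape == (b, n, q_len, k_len)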
@@ -372,7 +374,7 @@ def construct(

         attn_output = mint.matmul(attn_weights, value_states)

-        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.swapaxes(1, 2).contiguous()
         attn_output = attn_output.view(batch_size, -1, self.inner_dim)
         attn_output = self.o(attn_output)

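These three lines are the standard head merge before the output projection; a sketch with placeholder dimensions, assuming inner_dim = n_heads * head_dim as in T5-style attention:

    import mindspore as ms
    from mindspore import ops

    b, n_heads, q_len, head_dim = 2, 8, 5, 16
    inner_dim = n_heads * head_dim

    attn = ops.ones((b, n_heads, q_len, head_dim), ms.float32)
    # (b, n, q, d) -> (b, q, n, d) -> (b, q, n*d)
    merged = attn.swapaxes(1, 2).contiguous().view(b, -1, inner_dim)
    assert merged.shape == (b, q_len, inner_dim)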
@@ -483,7 +485,7 @@ def construct(
         past_key_value=None,
         use_cache=False,
         output_attentions=False,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
         cache_position=None,
     ):
         self_attention_outputs = self.layer[0](
@@ -676,7 +678,7 @@ def construct(
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
         cache_position=None,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -786,6 +788,7 @@ def construct(
         all_cross_attentions = () if (output_attentions and self.is_decoder) else None
         position_bias = None
         encoder_decoder_position_bias = None
+        next_decoder_cache = None

         hidden_states = self.dropout(inputs_embeds)

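The new next_decoder_cache = None here, like is_updated = False and curr_past_key_value = None in the attention hunk above, makes sure each name is bound on every control-flow path rather than only inside a conditional; that avoids an UnboundLocalError in Python and plays safer with MindSpore graph mode, which is stricter about names defined only on some branches. A toy illustration of the hazard:

    def lookup(cache=None):
        found = False            # pre-initialize so 'found' is bound on every path
        if cache is not None:
            found = "k" in cache
        return found             # without the default, lookup() would raise UnboundLocalError

    assert lookup() is False
    assert lookup({"k": 1}) is True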
@@ -1069,7 +1072,7 @@ def construct(
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
         cache_position: Optional[ms.Tensor] = None,
     ) -> Union[Tuple[ms.Tensor], Seq2SeqModelOutput]:
         r"""
@@ -1099,6 +1102,7 @@ def construct(
         >>> last_hidden_states = outputs[0]
         ```"""
         use_cache = use_cache if use_cache is not None else self.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
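Changing the return_dict default from False to None is what makes the resolution line above meaningful: None marks "caller expressed no preference", so the config can decide, while an explicit False from the caller is still honored. With a hard-coded False default, self.config.use_return_dict could never take effect. A minimal sketch of the sentinel pattern, with a hypothetical helper name:

    def resolve_return_dict(return_dict, config_default):
        # None -> fall back to the config; explicit False still wins
        return return_dict if return_dict is not None else config_default

    assert resolve_return_dict(None, True) is True
    assert resolve_return_dict(False, True) is False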
@@ -1383,7 +1387,7 @@ def construct(
         inputs_embeds: Optional[Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[ms.Tensor], BaseModelOutput]:
         r"""
         Returns:
@@ -1403,6 +1407,7 @@ def construct(
         >>> outputs = model(input_ids=Tensor(input_ids))
         >>> last_hidden_states = outputs[0]
         ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         encoder_outputs = self.encoder(
             input_ids=input_ids,