Commit 90e8a09

support graph mode for t5 (#1183)
* support graph mode for t5
* fix bugs
* `return_dict` change to default value to align with transformers
1 parent c102906 commit 90e8a09
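
A quick orientation to the diff below (my summary, not part of the commit): the attention code now reshapes with `Tensor.swapaxes`, which exchanges exactly two axes, in place of the two-argument `.transpose(...)` calls, and `return_dict` defaults move from `False` to `None`. A minimal standalone sketch of the head-splitting reshape, with made-up shapes:

# Minimal sketch (shapes are made up) of the head-splitting reshape used in the diff:
# view to (batch, seq, n_heads, head_dim), then swap axes 1 and 2 so the layout
# becomes (batch, n_heads, seq, head_dim) for the attention score matmuls.
import mindspore as ms
from mindspore import ops

batch_size, seq_len, n_heads, head_dim = 2, 5, 4, 8

hidden_states = ops.ones((batch_size, seq_len, n_heads * head_dim), ms.float32)
states = hidden_states.view(batch_size, -1, n_heads, head_dim).swapaxes(1, 2)
print(states.shape)  # (2, 4, 5, 8)

# Scores follow the same pattern: swap the last two axes of the keys.
scores = ops.matmul(states, states.swapaxes(3, 2))
print(scores.shape)  # (2, 4, 5, 5)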

File tree

1 file changed: +14 -9 lines changed


mindone/transformers/models/t5/modeling_t5.py

Lines changed: 14 additions & 9 deletions
@@ -303,8 +303,10 @@ def construct(
         is_cross_attention = key_value_states is not None
 
         query_states = self.q(hidden_states)
-        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+        query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)
 
+        is_updated = False
+        curr_past_key_value = None
         if past_key_value is not None:
             is_updated = past_key_value.is_updated.get(self.layer_idx)
             if is_cross_attention:
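
The two added initializations above deserve a note (my gloss, hedged, not from the commit message): MindSpore graph mode generally requires that a name read after a conditional be bound on every control-flow path, so values previously assigned only inside `if past_key_value is not None:` now get defaults up front; the later `next_decoder_cache = None` addition follows the same pattern. A hypothetical toy cell illustrating it:

# Hypothetical toy cell (not from modeling_t5.py) showing the pattern of
# pre-initializing branch-dependent names before compiling in graph mode.
import mindspore as ms
from mindspore import Tensor, nn

class ToyCell(nn.Cell):
    def construct(self, x, use_cache: bool):
        cache = None        # bound on every path, like curr_past_key_value in the commit
        is_updated = False  # like is_updated in the commit
        if use_cache:
            cache = x * 2
            is_updated = True
        # safe to read both names here even when the branch was not taken
        return cache if is_updated else x

ms.set_context(mode=ms.GRAPH_MODE)
net = ToyCell()
print(net(Tensor([1.0, 2.0], ms.float32), True))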
@@ -321,8 +323,8 @@ def construct(
         else:
             key_states = self.k(current_states)
             value_states = self.v(current_states)
-            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+            key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)
+            value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).swapaxes(1, 2)
 
             if past_key_value is not None:
                 # save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -335,7 +337,7 @@ def construct(
                     past_key_value.is_updated[self.layer_idx] = True
 
         # compute scores, equivalent of mint.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
-        scores = mint.matmul(query_states, key_states.transpose(3, 2))
+        scores = mint.matmul(query_states, key_states.swapaxes(3, 2))
 
         if position_bias is None:
             key_length = key_states.shape[-2]
@@ -372,7 +374,7 @@ def construct(
 
         attn_output = mint.matmul(attn_weights, value_states)
 
-        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.swapaxes(1, 2).contiguous()
         attn_output = attn_output.view(batch_size, -1, self.inner_dim)
         attn_output = self.o(attn_output)
 
@@ -483,7 +485,7 @@ def construct(
         past_key_value=None,
         use_cache=False,
         output_attentions=False,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
         cache_position=None,
     ):
         self_attention_outputs = self.layer[0](
@@ -676,7 +678,7 @@ def construct(
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
         cache_position=None,
     ):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -786,6 +788,7 @@ def construct(
         all_cross_attentions = () if (output_attentions and self.is_decoder) else None
         position_bias = None
         encoder_decoder_position_bias = None
+        next_decoder_cache = None
 
         hidden_states = self.dropout(inputs_embeds)
 
@@ -1069,7 +1072,7 @@ def construct(
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
         cache_position: Optional[ms.Tensor] = None,
     ) -> Union[Tuple[ms.Tensor], Seq2SeqModelOutput]:
         r"""
@@ -1099,6 +1102,7 @@ def construct(
         >>> last_hidden_states = outputs[0]
         ```"""
         use_cache = use_cache if use_cache is not None else self.use_cache
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # Encode if needed (training, first prediction pass)
         if encoder_outputs is None:
@@ -1383,7 +1387,7 @@ def construct(
         inputs_embeds: Optional[Tensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = False,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[ms.Tensor], BaseModelOutput]:
         r"""
         Returns:
@@ -1403,6 +1407,7 @@ def construct(
         >>> outputs = model(input_ids=Tensor(input_ids))
         >>> last_hidden_states = outputs[0]
         ```"""
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         encoder_outputs = self.encoder(
             input_ids=input_ids,
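
A short note on the `return_dict` change (my gloss, hedged): moving the default from `False` to `None` lets the usual transformers fallback `return_dict if return_dict is not None else self.config.use_return_dict` take effect, so the model honours `config.use_return_dict` unless the caller passes an explicit `True`/`False`. A plain-Python sketch of the idiom:

# Plain-Python sketch of the fallback idiom added in this commit; the default
# config value shown here is an assumption, not read from a real T5 config.
def resolve_return_dict(return_dict, config_use_return_dict=True):
    return return_dict if return_dict is not None else config_use_return_dict

print(resolve_return_dict(None))   # True  -> falls back to config.use_return_dict
print(resolve_return_dict(False))  # False -> an explicit caller choice still wins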
