 from paddle.nn.layer.transformer import _convert_attention_mask
 
 from .. import PretrainedModel, register_base_model
-from ..model_outputs import (BaseModelOutput, SequenceClassifierOutput,
-                             TokenClassifierOutput,
+from ..model_outputs import (BaseModelOutputWithPastAndCrossAttentions,
+                             SequenceClassifierOutput, TokenClassifierOutput,
                              QuestionAnsweringModelOutput,
                              MultipleChoiceModelOutput, MaskedLMOutput,
                              tuple_output)
@@ -153,9 +153,12 @@ def forward(self,
                              src_mask=src_mask,
                              output_attentions=output_attentions)
             else:
+                cache_wrapper = cache[i] if isinstance(
+                    cache[i], nn.MultiHeadAttention.Cache
+                ) else nn.MultiHeadAttention.Cache(*cache[i])
                 output, new_cache = mod(output,
                                         src_mask=src_mask,
-                                        cache=cache[i],
+                                        cache=cache_wrapper,
                                         output_attentions=output_attentions)
                 new_caches.append(new_cache)
             if output_attentions:
@@ -174,14 +177,13 @@ def forward(self,
         if not return_dict:
             if output_attentions or output_hidden_states:
                 output = (output, all_attentions, all_hidden_states)
-
             return output if cache is None else (output, new_caches)
 
-        return BaseModelOutput(
+        return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=output,
             hidden_states=all_hidden_states,
             attentions=all_attentions,
-        )
+            past_key_values=new_caches)
 
 
 class ElectraEmbeddings(nn.Layer):
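Note on the encoder change above: the incoming `cache` may arrive either as `paddle.nn.MultiHeadAttention.Cache` objects or as plain `(key, value)` tuples (the `past_key_values` format documented later), and the per-layer caches are now returned in the new `past_key_values` field. A minimal sketch of that normalization step; the `as_cache` helper and the tensor shapes are illustrative only, not part of the diff:

    # Sketch of the per-layer cache normalization done inside the encoder loop.
    # paddle.nn.MultiHeadAttention.Cache is a namedtuple with fields (k, v), so a
    # plain (k, v) tuple can be wrapped without copying the tensors.
    import paddle
    import paddle.nn as nn

    def as_cache(layer_past):
        """Return layer_past as an nn.MultiHeadAttention.Cache (illustrative helper)."""
        if isinstance(layer_past, nn.MultiHeadAttention.Cache):
            return layer_past
        return nn.MultiHeadAttention.Cache(*layer_past)

    # Hypothetical shapes: [batch_size, num_heads, past_length, head_dim].
    k = paddle.zeros([2, 4, 6, 64])
    v = paddle.zeros([2, 4, 6, 64])
    cache = as_cache((k, v))
    assert cache.k is k and cache.v is v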
@@ -199,11 +201,17 @@ def __init__(self, vocab_size, embedding_size, hidden_dropout_prob,
         self.layer_norm = nn.LayerNorm(embedding_size, epsilon=layer_norm_eps)
         self.dropout = nn.Dropout(hidden_dropout_prob)
 
-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
+    def forward(self,
+                input_ids,
+                token_type_ids=None,
+                position_ids=None,
+                past_key_values_length=None):
         if position_ids is None:
             ones = paddle.ones_like(input_ids, dtype="int64")
             seq_length = paddle.cumsum(ones, axis=-1)
             position_ids = seq_length - ones
+            if past_key_values_length is not None:
+                position_ids += past_key_values_length
             position_ids.stop_gradient = True
         position_ids = position_ids.astype("int64")
 
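When `past_key_values` is supplied, only the new tokens are fed through `ElectraEmbeddings`, so the generated position ids must be shifted by the cached length, which is what the hunk above adds. A small self-contained illustration with made-up values:

    # Position ids for 3 new tokens when 5 positions are already cached.
    import paddle

    input_ids = paddle.to_tensor([[101, 2023, 102]])  # made-up token ids
    past_key_values_length = 5

    ones = paddle.ones_like(input_ids, dtype="int64")
    position_ids = paddle.cumsum(ones, axis=-1) - ones  # [[0, 1, 2]]
    if past_key_values_length is not None:
        position_ids += past_key_values_length          # [[5, 6, 7]]
    print(position_ids.numpy())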
@@ -550,6 +558,8 @@ def forward(self,
                 token_type_ids=None,
                 position_ids=None,
                 attention_mask=None,
+                past_key_values=None,
+                use_cache=None,
                 output_attentions=False,
                 output_hidden_states=False,
                 return_dict=False):
@@ -585,6 +595,17 @@ def forward(self,
                 When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
                 It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
                 Defaults to `None`, which means nothing needed to be prevented attention to.
+            past_key_values (tuple(tuple(Tensor)), optional):
+                Precomputed key and value hidden states of the attention blocks of each layer. This can be used to speed up
+                auto-regressive decoding for generation tasks, or to support use cases such as Prefix-Tuning where vectors are
+                prepended to each attention layer. The length of the tuple equals the number of layers, and each inner tuple
+                holds 2 tensors of shape `(batch_size, num_heads, past_key_values_length, embed_size_per_head)`.
+                If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that
+                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+                `input_ids` of shape `(batch_size, sequence_length)`.
+            use_cache (bool, optional):
+                If set to `True`, `past_key_values` key value states are returned.
+                Defaults to `None`.
             output_hidden_states (bool, optional):
                 Whether to return the hidden states of all layers.
                 Defaults to `False`.
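For reference, a hedged usage sketch of the two new arguments, matching the shapes documented above. The checkpoint name, the ELECTRA-small geometry (12 layers, 4 heads, 64 dims per head) and the random prefix tensors are illustrative assumptions; a real Prefix-Tuning setup would learn the prefix, and the updated cache is expected to come back as `past_key_values` on the output when `return_dict=True`, per the encoder return change above.

    # Illustrative only: prepend a random "prefix" cache to every layer and run
    # the model on the new tokens. Assumes the usual ELECTRA-small geometry
    # (12 layers, 4 heads, 64 dims per head).
    import paddle
    from paddlenlp.transformers import ElectraModel, ElectraTokenizer

    tokenizer = ElectraTokenizer.from_pretrained("electra-small")
    model = ElectraModel.from_pretrained("electra-small")
    model.eval()

    prefix_len = 8
    past_key_values = tuple(
        (paddle.randn([1, 4, prefix_len, 64]),
         paddle.randn([1, 4, prefix_len, 64])) for _ in range(12))

    input_ids = paddle.to_tensor([tokenizer("Electra rocks!")["input_ids"]])
    with paddle.no_grad():
        outputs = model(input_ids,
                        past_key_values=past_key_values,
                        use_cache=True,
                        return_dict=True)
    print(outputs.last_hidden_state.shape)  # [1, seq_len, hidden_size]
    print(len(outputs.past_key_values))     # one cache entry per layer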
@@ -613,26 +634,40 @@ def forward(self,
                 output = model(**inputs)
 
         '''
+        past_key_values_length = None
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
 
         if attention_mask is None:
             attention_mask = paddle.unsqueeze(
                 (input_ids == self.pad_token_id).astype(
                     paddle.get_default_dtype()) * -1e4,
                 axis=[1, 2])
+            if past_key_values is not None:
+                batch_size = past_key_values[0][0].shape[0]
+                past_mask = paddle.zeros(
+                    [batch_size, 1, 1, past_key_values_length],
+                    dtype=attention_mask.dtype)
+                attention_mask = paddle.concat([past_mask, attention_mask],
+                                               axis=-1)
         else:
             if attention_mask.ndim == 2:
                 attention_mask = attention_mask.unsqueeze(axis=[1, 2])
 
-        embedding_output = self.embeddings(input_ids=input_ids,
-                                           position_ids=position_ids,
-                                           token_type_ids=token_type_ids)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            past_key_values_length=past_key_values_length)
 
         if hasattr(self, "embeddings_project"):
             embedding_output = self.embeddings_project(embedding_output)
 
+        self.encoder._use_cache = use_cache  # To be consistent with HF
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask,
+            cache=past_key_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict)
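The mask handling in the final hunk prepends an all-zero block for the cached positions, so the padding mask built from the (shorter) current `input_ids` still lets every query attend to the past. A standalone sketch with made-up ids and a hypothetical `pad_token_id`:

    # Extend a padding-style additive mask to cover past (cached) positions.
    import paddle

    input_ids = paddle.to_tensor([[2054, 2003, 102]])  # 3 new, made-up tokens
    pad_token_id = 0                                    # hypothetical pad id
    past_key_values_length = 5

    attention_mask = paddle.unsqueeze(
        (input_ids == pad_token_id).astype(paddle.get_default_dtype()) * -1e4,
        axis=[1, 2])                                    # shape [1, 1, 1, 3]
    past_mask = paddle.zeros([1, 1, 1, past_key_values_length],
                             dtype=attention_mask.dtype)
    attention_mask = paddle.concat([past_mask, attention_mask], axis=-1)
    print(attention_mask.shape)                         # [1, 1, 1, 8]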