Commit e544a04

sijunhe, wj-Mcat, and guoshengCS authored
Support past_key_values argument for Electra (#3411)
* unit test pass; fix yapf
* change docstring

Co-authored-by: 骑马小猫 <[email protected]>
Co-authored-by: Guo Sheng <[email protected]>
1 parent ddb59bf commit e544a04

File tree: 2 files changed, +104 -10 lines


paddlenlp/transformers/electra/modeling.py

Lines changed: 45 additions & 10 deletions
@@ -23,8 +23,8 @@
 from paddle.nn.layer.transformer import _convert_attention_mask

 from .. import PretrainedModel, register_base_model
-from ..model_outputs import (BaseModelOutput, SequenceClassifierOutput,
-                             TokenClassifierOutput,
+from ..model_outputs import (BaseModelOutputWithPastAndCrossAttentions,
+                             SequenceClassifierOutput, TokenClassifierOutput,
                              QuestionAnsweringModelOutput,
                              MultipleChoiceModelOutput, MaskedLMOutput,
                              tuple_output)
@@ -153,9 +153,12 @@ def forward(self,
                                     src_mask=src_mask,
                                     output_attentions=output_attentions)
             else:
+                cache_wrapper = cache[i] if isinstance(
+                    cache[i], nn.MultiHeadAttention.Cache
+                ) else nn.MultiHeadAttention.Cache(*cache[i])
                 output, new_cache = mod(output,
                                         src_mask=src_mask,
-                                        cache=cache[i],
+                                        cache=cache_wrapper,
                                         output_attentions=output_attentions)
                 new_caches.append(new_cache)
             if output_attentions:
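
The three added lines above normalize whatever the caller supplied in `cache[i]`: a ready-made `nn.MultiHeadAttention.Cache` passes through unchanged, while a plain `(key, value)` tuple is wrapped into one. A minimal standalone sketch of that pattern (not code from this diff; the shapes are arbitrary example values):

```python
import paddle
import paddle.nn as nn

# One element of past_key_values as a plain (key, value) pair, shaped
# [batch_size, num_heads, past_key_values_length, embed_size_per_head].
key = paddle.randn([2, 4, 8, 64])
value = paddle.randn([2, 4, 8, 64])
raw = (key, value)

# Same normalization as in the encoder loop above.
cache = (raw if isinstance(raw, nn.MultiHeadAttention.Cache)
         else nn.MultiHeadAttention.Cache(*raw))

print(type(cache).__name__, cache.k.shape, cache.v.shape)  # Cache [2, 4, 8, 64] [2, 4, 8, 64]
```

`nn.MultiHeadAttention.Cache` is a namedtuple with fields `k` and `v`, which is why unpacking the raw tuple with `*` is enough.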
@@ -174,14 +177,13 @@ def forward(self,
         if not return_dict:
             if output_attentions or output_hidden_states:
                 output = (output, all_attentions, all_hidden_states)
-
             return output if cache is None else (output, new_caches)

-        return BaseModelOutput(
+        return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=output,
             hidden_states=all_hidden_states,
             attentions=all_attentions,
-        )
+            past_key_values=new_caches)


 class ElectraEmbeddings(nn.Layer):
@@ -199,11 +201,17 @@ def __init__(self, vocab_size, embedding_size, hidden_dropout_prob,
         self.layer_norm = nn.LayerNorm(embedding_size, epsilon=layer_norm_eps)
         self.dropout = nn.Dropout(hidden_dropout_prob)

-    def forward(self, input_ids, token_type_ids=None, position_ids=None):
+    def forward(self,
+                input_ids,
+                token_type_ids=None,
+                position_ids=None,
+                past_key_values_length=None):
         if position_ids is None:
             ones = paddle.ones_like(input_ids, dtype="int64")
             seq_length = paddle.cumsum(ones, axis=-1)
             position_ids = seq_length - ones
+            if past_key_values_length is not None:
+                position_ids += past_key_values_length
         position_ids.stop_gradient = True
         position_ids = position_ids.astype("int64")

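
The two lines added to `ElectraEmbeddings.forward` shift the generated position ids by the length of the cached prefix, so new tokens continue the position sequence instead of restarting at 0. A tiny standalone illustration of the same arithmetic (not code from this diff; the token ids and lengths are made up):

```python
import paddle

input_ids = paddle.to_tensor([[101, 2054, 2003, 102]])  # 4 new tokens, any ids
past_key_values_length = 3                              # length of the cached prefix

ones = paddle.ones_like(input_ids, dtype="int64")
position_ids = paddle.cumsum(ones, axis=-1) - ones      # [[0, 1, 2, 3]]
if past_key_values_length is not None:
    position_ids += past_key_values_length              # [[3, 4, 5, 6]]

print(position_ids.numpy())  # [[3 4 5 6]]
```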
@@ -550,6 +558,8 @@ def forward(self,
                 token_type_ids=None,
                 position_ids=None,
                 attention_mask=None,
+                past_key_values=None,
+                use_cache=None,
                 output_attentions=False,
                 output_hidden_states=False,
                 return_dict=False):
@@ -585,6 +595,17 @@ def forward(self,
                 When the data type is float, the `masked` tokens have `-INF` values and the others have `0` values.
                 It is a tensor with shape broadcasted to `[batch_size, num_attention_heads, sequence_length, sequence_length]`.
                 Defaults to `None`, which means nothing needed to be prevented attention to.
+            past_key_values (tuple(tuple(Tensor)), optional):
+                Precomputed key and value hidden states of the attention blocks of each layer. These can be used to speed up
+                auto-regressive decoding for generation tasks, or to support use cases such as Prefix-Tuning, where vectors
+                are prepended to each attention layer. The length of the tuple equals the number of layers, and each inner
+                tuple holds 2 tensors of shape `(batch_size, num_heads, past_key_values_length, embed_size_per_head)`.
+                If `past_key_values` is used, the user can optionally input only the last `input_ids` (those that
+                don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
+                `input_ids` of shape `(batch_size, sequence_length)`.
+            use_cache (`bool`, optional):
+                If set to `True`, the `past_key_values` key value states are returned.
+                Defaults to `None`.
             output_hidden_states (bool, optional):
                 Whether to return the hidden states of all layers.
                 Defaults to `False`.
@@ -613,26 +634,40 @@ def forward(self,
                 output = model(**inputs)

         '''
+        past_key_values_length = None
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]

         if attention_mask is None:
             attention_mask = paddle.unsqueeze(
                 (input_ids == self.pad_token_id).astype(
                     paddle.get_default_dtype()) * -1e4,
                 axis=[1, 2])
+            if past_key_values is not None:
+                batch_size = past_key_values[0][0].shape[0]
+                past_mask = paddle.zeros(
+                    [batch_size, 1, 1, past_key_values_length],
+                    dtype=attention_mask.dtype)
+                attention_mask = paddle.concat([past_mask, attention_mask],
+                                               axis=-1)
         else:
             if attention_mask.ndim == 2:
                 attention_mask = attention_mask.unsqueeze(axis=[1, 2])

-        embedding_output = self.embeddings(input_ids=input_ids,
-                                           position_ids=position_ids,
-                                           token_type_ids=token_type_ids)
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            past_key_values_length=past_key_values_length)

         if hasattr(self, "embeddings_project"):
             embedding_output = self.embeddings_project(embedding_output)

+        self.encoder._use_cache = use_cache  # To be consistent with HF
         encoder_outputs = self.encoder(
             embedding_output,
             attention_mask,
+            cache=past_key_values,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict)
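
Taken together, the new `past_key_values` and `use_cache` arguments let a caller prepend precomputed per-layer key/value states (for example in a Prefix-Tuning setup) and get the resulting caches back. Below is a hedged usage sketch that mirrors the shapes in the docstring and the unit test further down; the checkpoint name `electra-small`, the vocab size, and the hard-coded geometry (12 layers, 4 heads, head size 64) are assumptions to adjust for the model actually used:

```python
import paddle
from paddlenlp.transformers import ElectraModel

model = ElectraModel.from_pretrained("electra-small")  # assumed checkpoint name
model.eval()

num_layers, num_heads, head_size = 12, 4, 64           # assumed electra-small geometry
batch_size, prefix_len, seq_len = 2, 8, 16

# One (key, value) pair per layer, each of shape
# [batch_size, num_heads, past_key_values_length, embed_size_per_head].
past_key_values = tuple(
    (paddle.randn([batch_size, num_heads, prefix_len, head_size]),
     paddle.randn([batch_size, num_heads, prefix_len, head_size]))
    for _ in range(num_layers))

input_ids = paddle.randint(low=1, high=30522, shape=[batch_size, seq_len])  # assumed vocab size
# Fully-visible mask covering the prefix plus the real tokens; if it were omitted,
# the model would build one itself and prepend zeros for the past positions.
attention_mask = paddle.ones([batch_size, prefix_len + seq_len])

with paddle.no_grad():
    out = model(input_ids,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True)

print(out.last_hidden_state.shape)  # [batch_size, seq_len, hidden_size]
# out.past_key_values carries the per-layer caches returned by the updated encoder.
```

The unit test below exercises the same call pattern and only checks that the output keeps its shape while its values differ from a cache-free forward pass.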

tests/transformers/electra/test_modeling.py

Lines changed: 59 additions & 0 deletions
@@ -133,6 +133,60 @@ def create_and_check_electra_model(
             result[0].shape,
             [self.batch_size, self.seq_length, self.hidden_size])

+    def create_and_check_electra_model_cache(self, config, input_ids,
+                                             token_type_ids, input_mask,
+                                             sequence_labels, token_labels,
+                                             choice_labels):
+        model = ElectraModel(**config)
+        model.eval()
+
+        input_ids = ids_tensor((self.batch_size, self.seq_length),
+                               self.vocab_size)
+        input_token_types = ids_tensor([self.batch_size, self.seq_length],
+                                       self.type_vocab_size)
+
+        # create tensors for past_key_values of shape [batch_size, num_heads, seq_length, head_size]
+        embed_size_per_head = self.hidden_size // self.num_attention_heads
+        key_tensor = floats_tensor((self.batch_size, self.num_attention_heads,
+                                    self.seq_length, embed_size_per_head))
+        values_tensor = floats_tensor(
+            (self.batch_size, self.num_attention_heads, self.seq_length,
+             embed_size_per_head))
+        past_key_values = ((
+            key_tensor,
+            values_tensor,
+        ), ) * self.num_hidden_layers
+
+        # create fully-visible attention mask for input_ids only and input_ids + past
+        attention_mask = paddle.ones([self.batch_size, self.seq_length])
+        attention_mask_with_past = paddle.ones(
+            [self.batch_size, self.seq_length * 2])
+
+        outputs_with_cache = model(input_ids,
+                                   token_type_ids=input_token_types,
+                                   attention_mask=attention_mask_with_past,
+                                   past_key_values=past_key_values,
+                                   return_dict=self.parent.return_dict)
+        outputs_without_cache = model(input_ids,
+                                      token_type_ids=input_token_types,
+                                      attention_mask=attention_mask,
+                                      return_dict=self.parent.return_dict)
+
+        # last_hidden_state should have the same shape but different values when given past_key_values
+        if self.parent.return_dict:
+            self.parent.assertEqual(
+                outputs_with_cache.last_hidden_state.shape,
+                outputs_without_cache.last_hidden_state.shape)
+            self.parent.assertFalse(
+                paddle.allclose(outputs_with_cache.last_hidden_state,
+                                outputs_without_cache.last_hidden_state))
+        else:
+            outputs_with_cache, _ = outputs_with_cache
+            self.parent.assertEqual(outputs_with_cache.shape,
+                                    outputs_without_cache.shape)
+            self.parent.assertFalse(
+                paddle.allclose(outputs_with_cache, outputs_without_cache))
+
     def create_and_check_electra_for_masked_lm(
         self,
         config,
@@ -356,6 +410,11 @@ def test_electra_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_electra_model(*config_and_inputs)

+    def test_electra_model_cache(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_electra_model_cache(
+            *config_and_inputs)
+
     def test_for_masked_lm(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_electra_for_masked_lm(
