
Commit aec3119

Authored by tianxin
add attn_mask input for encoder-decoder (#1431) (#1438)
1 parent f079f91 commit aec3119

2 files changed, +16 -15 lines

examples/semantic_indexing/faster_predict.py

Lines changed: 2 additions & 3 deletions
@@ -98,8 +98,8 @@ def get_pooled_embedding(self,
                              input_ids,
                              token_type_ids=None,
                              position_ids=None):
-        src_mask = (input_ids != self.bos_id
-                    ).astype(self.ptm.encoder.layers[0].norm1.bias.dtype)
+        src_mask = input_ids == self.bos_id
+        src_mask = paddle.cast(src_mask, "float32")
         # [bs, 1, 1, max_len]
         src_mask = paddle.unsqueeze(src_mask, axis=[1, 2])
         src_mask.stop_gradient = True
@@ -116,7 +116,6 @@ def get_pooled_embedding(self,
 
         if self.use_fp16:
             embedding_output = paddle.cast(embedding_output, 'float16')
-            src_mask = paddle.cast(src_mask, 'float16')
 
         sequence_output = self.ptm.encoder(embedding_output, src_mask)
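The caller-side mask construction after this change reduces to the following self-contained sketch. It is a minimal illustration, not code from the repository: the toy input_ids and the bos_id value are assumptions. Positions equal to self.bos_id receive a non-zero value and are therefore dropped by the patched encoder below, while all other positions stay 0.0 and are attended to.

import paddle

bos_id = 0  # assumed id marking positions to mask out (illustrative only)
input_ids = paddle.to_tensor([[5, 7, 9, 0, 0],
                              [3, 4, 0, 0, 0]])  # [batch_size, max_len], toy data

# 0.0 = attend, non-zero = mask out: the convention the patched encoder expects
src_mask = paddle.cast(input_ids == bos_id, "float32")

# reshape to [batch_size, 1, 1, max_len] so it broadcasts over heads and queries
src_mask = paddle.unsqueeze(src_mask, axis=[1, 2])
src_mask.stop_gradient = True
print(src_mask.shape)  # [2, 1, 1, 5]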

paddlenlp/ops/faster_transformer/transformer/encoder.py

Lines changed: 14 additions & 12 deletions
@@ -201,19 +201,13 @@ def encoder_forward(self, src, src_mask=None, cache=None):
         src (Tensor):
             The input of Transformer encoder. It is a tensor
             with shape `[batch_size, sequence_length, d_model]`. The data
-            type should be float32 or float64.
+            type should be float32 or float16.
         src_mask (Tensor, optional):
             A tensor used in multi-head attention to prevents attention to
             some unwanted positions, usually the paddings or the subsequent
             positions. It is a tensor with shape `[batch_size, 1, 1, sequence_length]`.
-            When the data type is bool, the unwanted positions have `False`
-            values and the others have `True` values. When the data type is
-            int, the unwanted positions have 0 values and the others have 1
-            values. When the data type is float, the unwanted positions have
-            `-INF` values and the others have 0 values. It can be None when
-            nothing wanted or needed to be prevented attention to. Defaults
-            to None.
-
+            The data type must be float, the unwanted positions have `-INF` values or other non-zeros
+            and the wanted positions must be 0.0.
     Returns:
         output (Tensor|tuple):
             It is a tensor that has the same shape and data type as `src`,
@@ -225,9 +219,17 @@ def encoder_forward(self, src, src_mask=None, cache=None):
             `paddle.nn.MultiHeadAttention.forward` for more details.
     """
 
-    max_seq_len = src.shape[1]
-    # broadcast
-    src_mask = paddle.concat(x=[src_mask] * max_seq_len, axis=2)
+    if src_mask.dtype == paddle.float16:
+        src_mask = paddle.cast(src_mask, "float32")
+
+    src_mask = src_mask == 0.0
+    src_mask = paddle.cast(src_mask, src.dtype)
+
+    # transpose_src_mask: [batch_size, 1, sequence_length, 1]
+    transpose_src_mask = paddle.transpose(src_mask, perm=[0, 1, 3, 2])
+
+    # src_mask: [batch_size, 1, sequence_length, sequence_length]
+    src_mask = src_mask * transpose_src_mask
     output = src
     for i, layer in enumerate(self.layers):
         output = layer(output, src_mask)
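To see what the new preprocessing in encoder_forward does to that mask, here is a standalone sketch with assumed toy values (variable names are illustrative and the float16-to-float32 cast from the diff is omitted). The [batch_size, 1, 1, seq_len] mask with 0.0 for wanted and non-zero for unwanted positions is turned into a multiplicative [batch_size, 1, seq_len, seq_len] mask; unlike the removed paddle.concat broadcast, the outer product also zeroes the rows of masked query positions.

import paddle

# [batch_size=1, 1, 1, seq_len=5]; 0.0 = attend, non-zero (e.g. 1.0 or -INF) = mask out
src_mask = paddle.to_tensor([[[[0.0, 0.0, 0.0, 1.0, 1.0]]]])
dtype = "float32"  # stands in for src.dtype

keep = paddle.cast(src_mask == 0.0, dtype)          # 1.0 = keep, 0.0 = drop
keep_t = paddle.transpose(keep, perm=[0, 1, 3, 2])  # [batch_size, 1, seq_len, 1]
attn_mask = keep * keep_t                           # [batch_size, 1, seq_len, seq_len]

print(attn_mask.shape)  # [1, 1, 5, 5]
# Column j is zeroed for every masked key j, and row i is zeroed for every
# masked query i; the old concat-based broadcast only covered the key axis.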
