@@ -201,19 +201,13 @@ def encoder_forward(self, src, src_mask=None, cache=None):
         src (Tensor):
             The input of Transformer encoder. It is a tensor
             with shape `[batch_size, sequence_length, d_model]`. The data
-            type should be float32 or float64.
+            type should be float32 or float16.
         src_mask (Tensor, optional):
             A tensor used in multi-head attention to prevent attention to
             some unwanted positions, usually the paddings or the subsequent
             positions. It is a tensor with shape `[batch_size, 1, 1, sequence_length]`.
-            When the data type is bool, the unwanted positions have `False`
-            values and the others have `True` values. When the data type is
-            int, the unwanted positions have 0 values and the others have 1
-            values. When the data type is float, the unwanted positions have
-            `-INF` values and the others have 0 values. It can be None when
-            nothing wanted or needed to be prevented attention to. Defaults
-            to None.
-
+            The data type must be float; the unwanted positions have `-INF`
+            (or other non-zero) values and the wanted positions must be 0.0.
     Returns:
         output (Tensor|tuple):
             It is a tensor that has the same shape and data type as `src`,
@@ -225,9 +219,17 @@ def encoder_forward(self, src, src_mask=None, cache=None):
         `paddle.nn.MultiHeadAttention.forward` for more details.
     """

-    max_seq_len = src.shape[1]
-    # broadcast
-    src_mask = paddle.concat(x=[src_mask] * max_seq_len, axis=2)
+    if src_mask.dtype == paddle.float16:
+        src_mask = paddle.cast(src_mask, "float32")
+
+    src_mask = src_mask == 0.0
+    src_mask = paddle.cast(src_mask, src.dtype)
+
+    # transpose_src_mask: [batch_size, 1, sequence_length, 1]
+    transpose_src_mask = paddle.transpose(src_mask, perm=[0, 1, 3, 2])
+
+    # src_mask: [batch_size, 1, sequence_length, sequence_length]
+    src_mask = src_mask * transpose_src_mask
     output = src
     for i, layer in enumerate(self.layers):
         output = layer(output, src_mask)
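To see end to end what the new lines compute, here is a minimal, standalone sketch (not part of the patch): it builds a toy additive mask — the batch size, sequence length, and padding layout are made-up illustration values — and applies the same conversion, turning the `-INF`/0.0 mask described in the docstring into the multiplicative `[batch_size, 1, sequence_length, sequence_length]` mask the encoder layers consume.

```python
import paddle

# Hypothetical illustration data: batch of 2, sequence length 4, where the
# last position of every sequence is padding that must not be attended to.
batch_size, seq_len = 2, 4
keep = paddle.zeros([batch_size, 1, 1, seq_len - 1], dtype="float32")
pad = paddle.full([batch_size, 1, 1, 1], float("-inf"), dtype="float32")
src_mask = paddle.concat([keep, pad], axis=-1)  # [batch_size, 1, 1, seq_len]

# float16 masks are promoted to float32 before the equality test.
if src_mask.dtype == paddle.float16:
    src_mask = paddle.cast(src_mask, "float32")

# Additive mask -> multiplicative mask: wanted positions (exactly 0.0) become
# 1.0; unwanted positions (-INF or any other non-zero value) become 0.0.
# The patch casts to src.dtype here; float32 stands in for it in this sketch.
src_mask = paddle.cast(src_mask == 0.0, "float32")

# Outer product over the sequence axis expands the per-position mask into a
# full attention mask.
transpose_src_mask = paddle.transpose(src_mask, perm=[0, 1, 3, 2])
src_mask = src_mask * transpose_src_mask

print(src_mask.shape)   # [2, 1, 4, 4]
print(src_mask[0, 0])   # last row and last column are all zeros
```

The `== 0.0` test is also why the docstring now insists on float masks with exactly 0.0 at wanted positions: any non-zero value, not only `-INF`, ends up treated as masked.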