
Commit 32d01fa

Author: gongenlei
[BUGFIX] BART and mBART support 2D attention mask from tokenizer (#1637)
* fix: bart and mbart support 2d attention mask
1 parent 82b1cc4 commit 32d01fa

File tree: 2 files changed, +12 −2 lines changed


paddlenlp/transformers/bart/modeling.py

Lines changed: 6 additions & 1 deletion
@@ -214,7 +214,12 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
             attention_mask = paddle.cast(
                 input_ids == self.pad_token_id,
                 dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4
-            attention_mask.stop_gradient = True
+        # For 2D attention_mask from tokenizer
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
+            attention_mask = (1.0 - attention_mask) * -1e4
+        attention_mask.stop_gradient = True
 
         encoder_output = self.encoder(encoder_input, src_mask=attention_mask)
         return encoder_output
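
What the new elif branch does: the tokenizer's 2D mask of shape [batch, seq_len] (1 for real tokens, 0 for padding) is broadcast to [batch, 1, 1, seq_len] and converted into an additive mask (0 for visible positions, -1e4 for padded ones) before it reaches the encoder. A minimal standalone sketch of that transform on made-up toy values:

import paddle

# 2D mask as a tokenizer would return it: 1 = real token, 0 = padding.
attention_mask = paddle.to_tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

# Broadcast to [batch, 1, 1, seq_len] and turn into an additive mask:
# 0.0 where attention is allowed, -1e4 where the position is padding.
attention_mask = paddle.unsqueeze(
    attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4

print(attention_mask.shape)  # [2, 1, 1, 4]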

paddlenlp/transformers/mbart/modeling.py

Lines changed: 6 additions & 1 deletion
@@ -286,7 +286,12 @@ def forward(self, input_ids=None, attention_mask=None, **kwargs):
             attention_mask = paddle.cast(
                 input_ids == self.pad_token_id,
                 dtype=paddle.get_default_dtype()).unsqueeze([1, 2]) * -1e4
-            attention_mask.stop_gradient = True
+        # For 2D attention_mask from tokenizer
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
+            attention_mask = (1.0 - attention_mask) * -1e4
+        attention_mask.stop_gradient = True
 
         encoder_output = self.encoder(encoder_input, src_mask=attention_mask)
         return encoder_output
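
With this change, the 2D attention_mask that the tokenizer returns can be passed to BART and mBART models as-is. A rough usage sketch (the checkpoint name "bart-base" and the return_attention_mask flag are illustrative assumptions, not part of this commit):

import paddle
from paddlenlp.transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("bart-base")
model = BartModel.from_pretrained("bart-base")

encoded = tokenizer("PaddleNLP now accepts 2D masks.",
                    return_attention_mask=True)
input_ids = paddle.to_tensor([encoded["input_ids"]])
attention_mask = paddle.to_tensor([encoded["attention_mask"]])  # [batch, seq_len]

# The encoder converts the 2D padding mask internally; no manual
# unsqueeze/broadcast is needed any more.
outputs = model(input_ids=input_ids, attention_mask=attention_mask)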
