
Commit 7b8dd49

turkeymz and yingyibiao authored
[Roberta] support 2D attention_mask from tokenizer (#1676)
* [Roberta] support 2D attention_mask from tokenizer
* format code style for yapf
* format code style for yapf

Co-authored-by: yingyibiao <[email protected]>
1 parent 1e14c2a commit 7b8dd49

File tree

1 file changed: +8 -2 lines changed

paddlenlp/transformers/roberta/modeling.py

Lines changed: 8 additions & 2 deletions
@@ -354,6 +354,12 @@ def forward(self,
                 (input_ids == self.pad_token_id
                  ).astype(self.pooler.dense.weight.dtype) * -1e4,
                 axis=[1, 2])
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
+            attention_mask = (1.0 - attention_mask) * -1e4
+            attention_mask.stop_gradient = True
+
         embedding_output = self.embeddings(
             input_ids=input_ids,
             position_ids=position_ids,
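
For context, here is a minimal standalone sketch of what the new elif branch computes; the tensor values are illustrative, not from the commit. The tokenizer's 2D mask of ones and zeros becomes an additive bias of shape [batch_size, 1, 1, seq_len] that broadcasts over the [batch_size, num_heads, seq_len, seq_len] attention scores, pushing padded positions toward -1e4 before the softmax.

import paddle

# Hypothetical 2D padding mask as a tokenizer would return it:
# 1 marks a real token, 0 marks padding.
mask_2d = paddle.to_tensor([[1, 1, 1, 0, 0]])

# Same transformation as the added branch: unsqueeze to
# [batch_size, 1, 1, seq_len], then flip 0/1 into an additive bias.
mask = paddle.unsqueeze(mask_2d, axis=[1, 2]).astype(paddle.get_default_dtype())
mask = (1.0 - mask) * -1e4  # 0.0 for real tokens, -1e4 for padding
mask.stop_gradient = True

print(mask.shape)  # [1, 1, 1, 5]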
@@ -679,7 +685,7 @@ def forward(self,
 
                 tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                 model = RobertaForMaskedLM.from_pretrained('roberta-wwm-ext')
-
+
                 inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                 inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
 
@@ -777,7 +783,7 @@ def forward(self,
 
                 tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
                 model = RobertaForCausalLM.from_pretrained('roberta-wwm-ext')
-
+
                 inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!")
                 inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()}
 
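
Taken together, a sketch of the call pattern this commit enables, assuming the 'roberta-wwm-ext' weights are available and that return_attention_mask is accepted by the tokenizer call (a standard PaddleNLP tokenizer argument, not something added here):

import paddle
from paddlenlp.transformers import RobertaModel, RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained('roberta-wwm-ext')
model = RobertaModel.from_pretrained('roberta-wwm-ext')

# Ask the tokenizer for its 2D 0/1 attention_mask alongside input_ids.
inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!",
                   return_attention_mask=True)
inputs = {k: paddle.to_tensor([v]) for (k, v) in inputs.items()}

# With this commit, the 2D mask takes the new elif branch instead of
# being treated as an already-broadcast 4D mask.
sequence_output, pooled_output = model(**inputs)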
