
Commit 8993c32

support 2D attention_mask from tokenizer (#1634)
1 parent 078758e

File tree

3 files changed: +21 -0 lines changed

paddlenlp/transformers/ernie/modeling.py

Lines changed: 7 additions & 0 deletions
@@ -371,6 +371,13 @@ def forward(self,
                 (input_ids == self.pad_token_id
                  ).astype(self.pooler.dense.weight.dtype) * -1e4,
                 axis=[1, 2])
+        # For 2D attention_mask from tokenizer
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
+            attention_mask = (1.0 - attention_mask) * -1e4
+            attention_mask.stop_gradient = True
+
         embedding_output = self.embeddings(
             input_ids=input_ids,
             position_ids=position_ids,

paddlenlp/transformers/ernie_ctm/modeling.py

Lines changed: 6 additions & 0 deletions
@@ -398,6 +398,12 @@ def forward(self,
                 (input_ids == self.pad_token_id
                  ).astype(self.pooler.dense.weight.dtype) * -1e4,
                 axis=[1, 2])
+        # For 2D attention_mask from tokenizer
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
+            attention_mask = (1.0 - attention_mask) * -1e4
+            attention_mask.stop_gradient = True

         embedding_output = self.embeddings(
             input_ids=input_ids,

paddlenlp/transformers/ernie_gram/modeling.py

Lines changed: 8 additions & 0 deletions
@@ -288,6 +288,14 @@ def forward(self,
                 (input_ids == self.pad_token_id
                  ).astype(self.pooler.dense.weight.dtype) * -1e4,
                 axis=[1, 2])
+        # For 2D attention_mask from tokenizer
+        elif attention_mask.ndim == 2:
+            attention_mask = paddle.unsqueeze(
+                attention_mask,
+                axis=[1, 2]).astype(self.pooler.dense.weight.dtype)
+            attention_mask = (1.0 - attention_mask) * -1e4
+            attention_mask.stop_gradient = True
+
         embedding_output = self.embeddings(
             input_ids=input_ids,
             position_ids=position_ids,
