File tree Expand file tree Collapse file tree 3 files changed +21
-0
lines changed Expand file tree Collapse file tree 3 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -371,6 +371,13 @@ def forward(self,
371
371
(input_ids == self .pad_token_id
372
372
).astype (self .pooler .dense .weight .dtype ) * - 1e4 ,
373
373
axis = [1 , 2 ])
374
+ # For 2D attention_mask from tokenizer
375
+ elif attention_mask .ndim == 2 :
376
+ attention_mask = paddle .unsqueeze (
377
+ attention_mask , axis = [1 , 2 ]).astype (paddle .get_default_dtype ())
378
+ attention_mask = (1.0 - attention_mask ) * - 1e4
379
+ attention_mask .stop_gradient = True
380
+
374
381
embedding_output = self .embeddings (
375
382
input_ids = input_ids ,
376
383
position_ids = position_ids ,
Original file line number Diff line number Diff line change @@ -398,6 +398,12 @@ def forward(self,
398
398
(input_ids == self .pad_token_id
399
399
).astype (self .pooler .dense .weight .dtype ) * - 1e4 ,
400
400
axis = [1 , 2 ])
401
+ # For 2D attention_mask from tokenizer
402
+ elif attention_mask .ndim == 2 :
403
+ attention_mask = paddle .unsqueeze (
404
+ attention_mask , axis = [1 , 2 ]).astype (paddle .get_default_dtype ())
405
+ attention_mask = (1.0 - attention_mask ) * - 1e4
406
+ attention_mask .stop_gradient = True
401
407
402
408
embedding_output = self .embeddings (
403
409
input_ids = input_ids ,
Original file line number Diff line number Diff line change @@ -288,6 +288,14 @@ def forward(self,
288
288
(input_ids == self .pad_token_id
289
289
).astype (self .pooler .dense .weight .dtype ) * - 1e4 ,
290
290
axis = [1 , 2 ])
291
+ # For 2D attention_mask from tokenizer
292
+ elif attention_mask .ndim == 2 :
293
+ attention_mask = paddle .unsqueeze (
294
+ attention_mask ,
295
+ axis = [1 , 2 ]).astype (self .pooler .dense .weight .dtype )
296
+ attention_mask = (1.0 - attention_mask ) * - 1e4
297
+ attention_mask .stop_gradient = True
298
+
291
299
embedding_output = self .embeddings (
292
300
input_ids = input_ids ,
293
301
position_ids = position_ids ,
You can’t perform that action at this time.
0 commit comments