
Commit 74f9fae

zzz2010 and chfhf authored
funnel transformer (#1419)
* first version
* add funnel to init.py and run_squad.py
* add attention_mask
* rename FunnelTokenizerFast to FunnelTokenizer
* revised based on the PR comments
* further clean up function description section
* further clean up
* pre-commit check
* pre-commit check
* pre-commit check

Co-authored-by: chfhf <[email protected]>
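
The commit wires the new model/tokenizer pair into MODEL_CLASSES under the "funnel" key, so run_squad.py can select Funnel the same way it selects BERT or ERNIE. Below is a minimal stand-alone sketch of using the pair directly; the checkpoint name "funnel-transformer/small" is an assumed placeholder rather than something this commit defines, and the call pattern mirrors the one used in run_squad.py.

    # Minimal sketch: load the Funnel QA model and tokenizer added by this commit.
    # "funnel-transformer/small" is an assumed example checkpoint name.
    import paddle
    from paddlenlp.transformers import FunnelForQuestionAnswering, FunnelTokenizer

    tokenizer = FunnelTokenizer.from_pretrained("funnel-transformer/small")
    model = FunnelForQuestionAnswering.from_pretrained("funnel-transformer/small")

    question = "Who proposed the Funnel-Transformer?"
    context = "The Funnel-Transformer was proposed by Dai et al."
    inputs = tokenizer(question, context)

    input_ids = paddle.to_tensor([inputs["input_ids"]])
    token_type_ids = paddle.to_tensor([inputs["token_type_ids"]])
    # A single unpadded example, so the attention mask is simply all ones.
    attention_mask = paddle.to_tensor([[1] * len(inputs["input_ids"])])

    # As in the updated evaluate() loop, the model returns start and end logits.
    start_logits, end_logits = model(
        input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask)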

File tree

5 files changed: +3325 −7 lines changed


examples/machine_reading_comprehension/SQuAD/run_squad.py

Lines changed: 10 additions & 7 deletions
@@ -29,14 +29,15 @@
 import paddlenlp as ppnlp
 
 from paddlenlp.data import Pad, Stack, Tuple, Dict
-from paddlenlp.transformers import BertForQuestionAnswering, BertTokenizer, ErnieForQuestionAnswering, ErnieTokenizer
+from paddlenlp.transformers import BertForQuestionAnswering, BertTokenizer, ErnieForQuestionAnswering, ErnieTokenizer, FunnelForQuestionAnswering, FunnelTokenizer
 from paddlenlp.transformers import LinearDecayWithWarmup
 from paddlenlp.metrics.squad import squad_evaluate, compute_prediction
 from paddlenlp.datasets import load_dataset
 
 MODEL_CLASSES = {
     "bert": (BertForQuestionAnswering, BertTokenizer),
-    "ernie": (ErnieForQuestionAnswering, ErnieTokenizer)
+    "ernie": (ErnieForQuestionAnswering, ErnieTokenizer),
+    'funnel': (FunnelForQuestionAnswering, FunnelTokenizer)
 }
 
 
@@ -162,9 +163,9 @@ def evaluate(model, data_loader, args):
     tic_eval = time.time()
 
     for batch in data_loader:
-        input_ids, token_type_ids = batch
+        input_ids, token_type_ids, attention_mask = batch
         start_logits_tensor, end_logits_tensor = model(input_ids,
-                                                       token_type_ids)
+                                                       token_type_ids=token_type_ids, attention_mask=attention_mask)
 
         for idx in range(start_logits_tensor.shape[0]):
             if len(all_start_logits) % 1000 == 0 and len(all_start_logits):
@@ -251,6 +252,7 @@ def run(args):
     train_batchify_fn = lambda samples, fn=Dict({
         "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
         "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
+        'attention_mask': Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
         "start_positions": Stack(dtype="int64"),
         "end_positions": Stack(dtype="int64")
     }): fn(samples)
@@ -288,10 +290,10 @@ def run(args):
     for epoch in range(num_train_epochs):
         for step, batch in enumerate(train_data_loader):
             global_step += 1
-            input_ids, token_type_ids, start_positions, end_positions = batch
+            input_ids, token_type_ids, attention_mask, start_positions, end_positions = batch
 
             logits = model(
-                input_ids=input_ids, token_type_ids=token_type_ids)
+                input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
             loss = criterion(logits, (start_positions, end_positions))
 
             if global_step % args.logging_steps == 0:
@@ -329,7 +331,8 @@ def run(args):
 
     dev_batchify_fn = lambda samples, fn=Dict({
         "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id),
-        "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
+        "token_type_ids": Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
+        "attention_mask": Pad(axis=0, pad_val=tokenizer.pad_token_type_id)
     }): fn(samples)
 
     dev_data_loader = DataLoader(
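
The batchify changes above are what let attention_mask flow through the DataLoader and be unpacked as an extra field in the training and evaluation loops. A small self-contained sketch of how Dict, Pad, and Stack from paddlenlp.data assemble such a batch is shown below; the token ids are made up, and the two pad constants stand in for the tokenizer attributes used in the script.

    # Sketch of what the extended train_batchify_fn produces. Pad(axis=0, pad_val=...)
    # right-pads every sequence field to the longest example in the batch; Dict applies
    # one batchify function per key and returns the results as a tuple in key order,
    # which is why the training loop can unpack attention_mask as the third element.
    from paddlenlp.data import Dict, Pad, Stack

    pad_id = 0        # stands in for tokenizer.pad_token_id
    pad_type_id = 0   # stands in for tokenizer.pad_token_type_id

    batchify_fn = Dict({
        "input_ids": Pad(axis=0, pad_val=pad_id),
        "token_type_ids": Pad(axis=0, pad_val=pad_type_id),
        "attention_mask": Pad(axis=0, pad_val=pad_type_id),
        "start_positions": Stack(dtype="int64"),
        "end_positions": Stack(dtype="int64"),
    })

    samples = [
        {"input_ids": [101, 2054, 102], "token_type_ids": [0, 0, 0],
         "attention_mask": [1, 1, 1], "start_positions": 1, "end_positions": 2},
        {"input_ids": [101, 2054, 2003, 2023, 102], "token_type_ids": [0, 0, 0, 0, 1],
         "attention_mask": [1, 1, 1, 1, 1], "start_positions": 2, "end_positions": 4},
    ]

    input_ids, token_type_ids, attention_mask, start_positions, end_positions = batchify_fn(samples)
    # input_ids and attention_mask are now (2, 5) arrays; the shorter sample is padded with 0s.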

paddlenlp/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -88,3 +88,5 @@
 from .mobilebert.tokenizer import *
 from .chinesebert.modeling import *
 from .chinesebert.tokenizer import *
+from .funnel.modeling import *
+from .funnel.tokenizer import *

paddlenlp/transformers/funnel/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+from .modeling import *
+from .tokenizer import *
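
Together, the two __init__.py changes re-export the new modules, so the classes added by this commit resolve from the top-level paddlenlp.transformers namespace as well as from their defining module. A tiny sketch of what that buys downstream code:

    # Both import paths below refer to the same class object after this commit.
    from paddlenlp.transformers import FunnelForQuestionAnswering
    from paddlenlp.transformers.funnel.modeling import FunnelForQuestionAnswering as _FunnelQA

    assert FunnelForQuestionAnswering is _FunnelQA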
