Skip to content

Commit 9ef9a8d

Browse files
authored
fix the tipc for the ernie-tiny (#4207)
1 parent b19244d commit 9ef9a8d

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

tests/test_tipc/benchmark/modules/ernie_tiny.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,26 @@
1414

1515
import os
1616
import sys
17+
1718
import paddle.nn as nn
19+
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
1820

19-
from paddlenlp.utils.log import logger
20-
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
2121
from paddlenlp.data import DataCollatorWithPadding
22+
from paddlenlp.datasets import load_dataset
2223
from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer
24+
from paddlenlp.utils.log import logger
2325

2426
from .model_base import BenchmarkBase
2527

26-
from paddlenlp.datasets import load_dataset
27-
2828
sys.path.append(
2929
os.path.abspath(
3030
os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, os.pardir, "model_zoo", "ernie-3.0")
3131
)
3232
)
3333

34-
from run_seq_cls import convert_example
35-
from functools import partial
34+
from functools import partial # noqa: E402
35+
36+
from utils import seq_convert_example # noqa: E402
3637

3738

3839
class ErnieTinyBenchmark(BenchmarkBase):
@@ -59,7 +60,7 @@ def create_data_loader(self, args, **kwargs):
5960
tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
6061
train_ds, dev_ds = load_dataset("clue", args.task_name, splits=("train", "dev"))
6162
trans_func = partial(
62-
convert_example, label_list=train_ds.label_list, tokenizer=tokenizer, max_seq_length=args.max_seq_length
63+
seq_convert_example, label_list=train_ds.label_list, tokenizer=tokenizer, max_seq_len=args.max_seq_length
6364
)
6465

6566
train_ds = train_ds.map(trans_func, lazy=True)

0 commit comments

Comments
 (0)