Skip to content

Commit 36a8f50

Browse files
authored
Simplify replacement of pretrained models (#1788)
1 parent 43ce071 commit 36a8f50

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

examples/text_classification/pretrained_models/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,13 +95,13 @@ $ python -m paddle.distributed.launch --gpus "0" train.py --device gpu --save_di
9595
```python
9696
# 使用ernie预训练模型
9797
# ernie-1.0
98-
model = ppnlp.transformers.ErnieForSequenceClassification.from_pretrained('ernie-1.0',num_classes=2))
99-
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
98+
model = AutoModelForSequenceClassification.from_pretrained('ernie-1.0',num_classes=2))
99+
tokenizer = AutoTokenizer.from_pretrained('ernie-1.0')
100100

101101
# 使用bert预训练模型
102102
# bert-base-chinese
103-
model = ppnlp.transformers.BertForSequenceClassification.from_pretrained('bert-base-chinese', num_class=2)
104-
tokenizer = ppnlp.transformers.BertTokenizer.from_pretrained('bert-base-chinese')
103+
model = AutoModelForSequenceClassification.from_pretrained('bert-base-chinese', num_class=2)
104+
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
105105
```
106106
更多预训练模型,参考[transformers](../../../docs/model_zoo/transformers.rst)
107107

examples/text_classification/pretrained_models/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
import numpy as np
2323
import paddle
2424
import paddle.nn.functional as F
25-
import paddlenlp as ppnlp
2625
from paddlenlp.data import Stack, Tuple, Pad
2726
from paddlenlp.datasets import load_dataset
27+
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
2828
from paddlenlp.transformers import LinearDecayWithWarmup
2929

3030
from utils import convert_example
@@ -120,9 +120,9 @@ def do_train():
120120
train_ds, dev_ds, test_ds = load_dataset(
121121
args.dataset, splits=["train", "dev", "test"])
122122

123-
model = ppnlp.transformers.ErnieForSequenceClassification.from_pretrained(
123+
model = AutoModelForSequenceClassification.from_pretrained(
124124
'ernie-1.0', num_classes=len(train_ds.label_list))
125-
tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')
125+
tokenizer = AutoTokenizer.from_pretrained('ernie-1.0')
126126

127127
trans_func = partial(
128128
convert_example,

0 commit comments

Comments
 (0)