
Commit 70bc0b0

[few-shot] Update README (#3104)
* [few-shot] add input_spec in train.py
* [few-shot] add early-stopping
* [few-shot] update readme
* [prompt] fix optimization problem in soft verbalizer
1 parent 7852629 commit 70bc0b0

File tree

9 files changed: 563 additions & 260 deletions


applications/text_classification/hierarchical/few-shot/README.md

Lines changed: 159 additions & 73 deletions
Large diffs are not rendered by default.

applications/text_classification/hierarchical/few-shot/train.py

Lines changed: 17 additions & 2 deletions
```diff
@@ -18,9 +18,10 @@
 
 import paddle
 import paddle.nn.functional as F
+from paddle.static import InputSpec
 from paddlenlp.utils.log import logger
 from paddlenlp.transformers import AutoTokenizer, AutoModelForMaskedLM
-from paddlenlp.trainer import PdArgumentParser
+from paddlenlp.trainer import PdArgumentParser, EarlyStoppingCallback
 from paddlenlp.prompt import (
     AutoTemplate,
     SoftVerbalizer,
@@ -106,13 +107,20 @@ def compute_metrics(eval_preds):
             "macro_f1_score": macro_f1_score
         }
 
+    # Define the early-stopping callback.
+    callbacks = [
+        EarlyStoppingCallback(early_stopping_patience=4,
+                              early_stopping_threshold=0.)
+    ]
+
     # Initialize the trainer.
     trainer = PromptTrainer(model=prompt_model,
                             tokenizer=tokenizer,
                             args=training_args,
                             criterion=criterion,
                             train_dataset=train_ds,
                             eval_dataset=dev_ds,
+                            callbacks=callbacks,
                             compute_metrics=compute_metrics)
 
     # Training.
@@ -131,8 +139,15 @@ def compute_metrics(eval_preds):
 
     # Export static model.
     if training_args.do_export:
+        input_spec = [
+            InputSpec(shape=[None, None], dtype="int64"),  # input_ids
+            InputSpec(shape=[None, None], dtype="int64"),  # mask_ids
+            InputSpec(shape=[None, None], dtype="int64"),  # soft_token_ids
+        ]
         export_path = os.path.join(training_args.output_dir, 'export')
-        trainer.export_model(export_path, export_type=model_args.export_type)
+        trainer.export_model(export_path,
+                             input_spec=input_spec,
+                             export_type=model_args.export_type)
 
 
 if __name__ == '__main__':
```
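The `[None, None]` specs leave both batch size and sequence length dynamic, so the exported static graph accepts padded batches of any shape. Below is a minimal, self-contained sketch of the same dynamic-to-static export on a toy layer rather than the prompt model; the `BowToy` class, its sizes, and the `export/toy` path are invented for illustration, whereas the real script delegates this to `trainer.export_model`.

```python
import paddle
from paddle.static import InputSpec


class BowToy(paddle.nn.Layer):
    """Toy classifier taking the same kind of int64 token-id input."""

    def __init__(self, vocab_size=1000, num_classes=2):
        super().__init__()
        self.emb = paddle.nn.Embedding(vocab_size, 64)
        self.fc = paddle.nn.Linear(64, num_classes)

    def forward(self, input_ids):
        # input_ids: int64 tensor of shape [batch_size, seq_len]
        return self.fc(self.emb(input_ids).mean(axis=1))


model = BowToy()
# shape=[None, None] keeps batch size and sequence length dynamic,
# mirroring the specs passed to trainer.export_model above.
specs = [InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
static_model = paddle.jit.to_static(model, input_spec=specs)
paddle.jit.save(static_model, "export/toy")   # writes toy.pdmodel / toy.pdiparams

loaded = paddle.jit.load("export/toy")        # reload the exported static graph
logits = loaded(paddle.randint(0, 1000, [4, 16], dtype="int64"))
print(logits.shape)                           # [4, 2]
```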

applications/text_classification/multi_class/few-shot/README.md

Lines changed: 143 additions & 64 deletions
Large diffs are not rendered by default.

applications/text_classification/multi_class/few-shot/train.py

Lines changed: 17 additions & 2 deletions
```diff
@@ -16,10 +16,11 @@
 import os
 
 import paddle
+from paddle.static import InputSpec
 from paddle.metric import Accuracy
 from paddlenlp.utils.log import logger
 from paddlenlp.transformers import AutoTokenizer, AutoModelForMaskedLM
-from paddlenlp.trainer import PdArgumentParser
+from paddlenlp.trainer import PdArgumentParser, EarlyStoppingCallback
 from paddlenlp.prompt import (
     AutoTemplate,
     SoftVerbalizer,
@@ -100,13 +101,20 @@ def compute_metrics(eval_preds):
         acc = metric.accumulate()
         return {'accuracy': acc}
 
+    # Define the early-stopping callback.
+    callbacks = [
+        EarlyStoppingCallback(early_stopping_patience=4,
+                              early_stopping_threshold=0.)
+    ]
+
     # Initialize the trainer.
     trainer = PromptTrainer(model=prompt_model,
                             tokenizer=tokenizer,
                             args=training_args,
                             criterion=criterion,
                             train_dataset=train_ds,
                             eval_dataset=dev_ds,
+                            callbacks=callbacks,
                             compute_metrics=compute_metrics)
 
     # Training.
@@ -125,8 +133,15 @@ def compute_metrics(eval_preds):
 
     # Export static model.
     if training_args.do_export:
+        input_spec = [
+            InputSpec(shape=[None, None], dtype="int64"),  # input_ids
+            InputSpec(shape=[None, None], dtype="int64"),  # mask_ids
+            InputSpec(shape=[None, None], dtype="int64"),  # soft_token_ids
+        ]
         export_path = os.path.join(training_args.output_dir, 'export')
-        trainer.export_model(export_path, export_type=model_args.export_type)
+        trainer.export_model(export_path,
+                             input_spec=input_spec,
+                             export_type=model_args.export_type)
 
 
 if __name__ == '__main__':
```
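For readers unfamiliar with the callback's two arguments: `early_stopping_patience=4` ends training after four consecutive evaluations without improvement, and `early_stopping_threshold=0.` means any increase at all counts as an improvement (inside the trainer this typically keys off the metric named by `metric_for_best_model`). A framework-free sketch of that rule, with invented metric values:

```python
def should_stop(eval_history, patience=4, threshold=0.0):
    """Return True once the metric has failed to improve by more than
    `threshold` for `patience` consecutive evaluations (higher is better)."""
    best = float("-inf")
    bad_rounds = 0
    for value in eval_history:
        if value > best + threshold:   # counts as an improvement
            best = value
            bad_rounds = 0
        else:
            bad_rounds += 1
            if bad_rounds >= patience:
                return True
    return False


# Invented evaluation scores: accuracy plateaus after the third round.
print(should_stop([0.71, 0.74, 0.75, 0.75, 0.75, 0.75, 0.75]))  # True
print(should_stop([0.71, 0.74, 0.75, 0.76, 0.77]))              # False
```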
