
Commit 886834e

[autonlp]add prompt candidates for text classification (#4867)
* add prompt candidates
* fix
1 parent 33be7e5 commit 886834e

2 files changed: +46 -2 lines changed


paddlenlp/experimental/autonlp/text_classification.py

Lines changed: 45 additions & 1 deletion
@@ -134,6 +134,7 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
         chinese_models = hp.choice(
             "models",
             [
+                "ernie-1.0-large-zh-cw",  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
                 "ernie-3.0-xbase-zh",  # 20-layer, 1024-hidden, 16-heads, 296M parameters.
                 "ernie-3.0-tiny-base-v2-zh",  # 12-layer, 768-hidden, 12-heads, 118M parameters.
                 "ernie-3.0-tiny-medium-v2-zh",  # 6-layer, 768-hidden, 12-heads, 75M parameters.
@@ -155,6 +156,21 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "ernie-2.0-large-en",  # 24-layer, 1024-hidden, 16-heads, 336M parameters. Trained on lower-cased English text.
             ],
         )
+        english_prompt_models = hp.choice(
+            "models",
+            [
+                # add deberta-v3 when we have it
+                "roberta-large",  # 24-layer, 1024-hidden, 16-heads, 334M parameters. Case-sensitive
+                "roberta-base",  # 12-layer, 768-hidden, 12-heads, 110M parameters. Case-sensitive
+            ],
+        )
+        chinese_prompt_models = hp.choice(
+            "models",
+            [
+                "ernie-1.0-large-zh-cw",  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
+                "ernie-1.0-base-zh-cw",  # 12-layer, 768-hidden, 12-heads, 118M parameters.
+            ],
+        )
         return [
             # fast learning: high LR, small early stop patience
             {
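The candidate model pools above are declared with `hp.choice`, which turns a fixed list of checkpoint names into one discrete dimension of the hyperparameter search space, so each tuning trial samples exactly one backbone from the pool. A minimal sketch of that mechanism, assuming `hp` is hyperopt's search-space module (the `hp.choice(label, options)` call signature used here matches it):

```python
# Minimal sketch, assuming `hp` is hyperopt's search-space module.
# The model names are copied from the chinese_prompt_models pool above.
from hyperopt import hp
from hyperopt.pyll import stochastic

chinese_prompt_models = hp.choice(
    "models",
    [
        "ernie-1.0-large-zh-cw",  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
        "ernie-1.0-base-zh-cw",  # 12-layer, 768-hidden, 12-heads, 118M parameters.
    ],
)

# Each draw yields one concrete model name, which is how a trial picks its backbone.
print(stochastic.sample(chinese_prompt_models))
```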
@@ -202,7 +218,33 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "TrainingArguments.model_name_or_path": english_models,
                 "TrainingArguments.learning_rate": 5e-6,
             },
-            # Note: prompt tuning candidates not included for now due to lack of inference capability
+            # prompt tuning candidates
+            {
+                "preset": "prompt",
+                "language": "Chinese",
+                "trainer_type": "PromptTrainer",
+                "template.prompt": "{'mask'}{'soft'}“{'text': '" + self.text_column + "'}”",
+                "EarlyStoppingCallback.early_stopping_patience": 5,
+                "PromptTuningArguments.per_device_train_batch_size": train_batch_size,
+                "PromptTuningArguments.per_device_eval_batch_size": train_batch_size * 2,
+                "PromptTuningArguments.num_train_epochs": 100,
+                "PromptTuningArguments.model_name_or_path": chinese_prompt_models,
+                "PromptTuningArguments.learning_rate": 1e-5,
+                "PromptTuningArguments.ppt_learning_rate": 1e-4,
+            },
+            {
+                "preset": "prompt",
+                "language": "English",
+                "trainer_type": "PromptTrainer",
+                "template.prompt": "{'mask'}{'soft'}“{'text': '" + self.text_column + "'}”",
+                "EarlyStoppingCallback.early_stopping_patience": 5,
+                "PromptTuningArguments.per_device_train_batch_size": train_batch_size,
+                "PromptTuningArguments.per_device_eval_batch_size": train_batch_size * 2,
+                "PromptTuningArguments.num_train_epochs": 100,
+                "PromptTuningArguments.model_name_or_path": english_prompt_models,
+                "PromptTuningArguments.learning_rate": 1e-5,
+                "PromptTuningArguments.ppt_learning_rate": 1e-4,
+            },
         ]

     def _data_checks_and_inference(self):
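Each prompt candidate builds its `template.prompt` by splicing the dataset's text column name into a soft-prompt template. A hypothetical illustration of the resulting string, assuming a text column named "sentence" (the `{'mask'}` slot is where the label word is predicted, `{'soft'}` inserts learnable soft-prompt tokens, and `{'text': ...}` is filled with the example's text):

```python
# Hypothetical illustration only: "sentence" stands in for self.text_column.
text_column = "sentence"
prompt = "{'mask'}{'soft'}“{'text': '" + text_column + "'}”"
print(prompt)  # {'mask'}{'soft'}“{'text': 'sentence'}”
```

Each candidate also carries two learning rates: `PromptTuningArguments.learning_rate` for the backbone and `ppt_learning_rate`, which, as the name suggests, applies to the prompt-specific parameters.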
@@ -247,6 +289,8 @@ def _data_checks_and_inference(self):
                 raise ValueError(
                     f"Label {label} is not found in the user-provided id2label argument: {self.id2label}"
                 )
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir)
         id2label_path = os.path.join(self.output_dir, "id2label.json")
         with open(id2label_path, "w", encoding="utf-8") as f:
             json.dump(self.id2label, f, ensure_ascii=False)
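The new guard creates the output directory before `id2label.json` is written, so the export no longer fails on a missing path. An equivalent one-call form, sketched with a placeholder path:

```python
import os

# Equivalent to the added check-then-create guard: exist_ok=True makes the call
# a no-op when the directory already exists. "path/to/output_dir" is a placeholder.
os.makedirs("path/to/output_dir", exist_ok=True)
```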

paddlenlp/prompt/prompt_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -144,7 +144,7 @@ def _save(self, output_dir: Optional[str] = None, state_dict: Dict[str, Any] = N
     def load_state_dict_from_checkpoint(self, resume_from_checkpoint: os.PathLike = None):
         if resume_from_checkpoint is not None:
             self.template = AutoTemplate.load_from(
-                resume_from_checkpoint, self.tokenizer, self.args.max_seq_length, self._get_model()
+                resume_from_checkpoint, self.tokenizer, self.args.max_seq_length, self._get_model().plm
             )
         super(PromptTrainer, self).load_state_dict_from_checkpoint(resume_from_checkpoint)
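The one-line fix above changes what `AutoTemplate.load_from` receives when resuming from a checkpoint: the underlying pretrained language model, which the prompt model wrapper exposes as `.plm`, rather than the wrapper itself. A toy sketch of the distinction, with made-up stand-in classes (not PaddleNLP APIs):

```python
# Toy sketch only: FakePLM / FakePromptModel are illustrative stand-ins for the
# real PaddleNLP objects. The point is that the template needs the wrapped
# language model, which the prompt wrapper exposes as `.plm`.
class FakePLM:
    """Stands in for the pretrained language model backbone."""


class FakePromptModel:
    """Stands in for the prompt-tuning wrapper around that backbone."""

    def __init__(self, plm):
        self.plm = plm


model = FakePromptModel(FakePLM())
backbone = model.plm  # what load_from should receive after this change
print(type(backbone).__name__)  # FakePLM
```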
