@@ -134,6 +134,7 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
         chinese_models = hp.choice(
             "models",
             [
+                "ernie-1.0-large-zh-cw",  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
                 "ernie-3.0-xbase-zh",  # 20-layer, 1024-hidden, 16-heads, 296M parameters.
                 "ernie-3.0-tiny-base-v2-zh",  # 12-layer, 768-hidden, 12-heads, 118M parameters.
                 "ernie-3.0-tiny-medium-v2-zh",  # 6-layer, 768-hidden, 12-heads, 75M parameters.
@@ -155,6 +156,21 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "ernie-2.0-large-en",  # 24-layer, 1024-hidden, 16-heads, 336M parameters. Trained on lower-cased English text.
             ],
         )
+        english_prompt_models = hp.choice(
+            "models",
+            [
+                # add deberta-v3 when we have it
+                "roberta-large",  # 24-layer, 1024-hidden, 16-heads, 334M parameters. Case-sensitive
+                "roberta-base",  # 12-layer, 768-hidden, 12-heads, 110M parameters. Case-sensitive
+            ],
+        )
+        chinese_prompt_models = hp.choice(
+            "models",
+            [
+                "ernie-1.0-large-zh-cw",  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
+                "ernie-1.0-base-zh-cw",  # 12-layer, 768-hidden, 12-heads, 118M parameters.
+            ],
+        )
         return [
             # fast learning: high LR, small early stop patience
             {
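For context on how these candidate lists are consumed: `hp.choice` defines a categorical search dimension from which the tuner samples a single checkpoint name per trial. Below is a minimal sketch of sampling one of the new prompt-model spaces, assuming `hp` comes from hyperopt (the import is outside these hunks):

```python
# Minimal sketch, assuming `hp` is hyperopt's search-space API.
from hyperopt import hp
from hyperopt.pyll.stochastic import sample

chinese_prompt_models = hp.choice(
    "models",
    [
        "ernie-1.0-large-zh-cw",  # 24-layer, 1024-hidden, 16-heads, 272M parameters.
        "ernie-1.0-base-zh-cw",  # 12-layer, 768-hidden, 12-heads, 118M parameters.
    ],
)

# Draw one candidate from the categorical space, e.g. "ernie-1.0-base-zh-cw".
print(sample(chinese_prompt_models))
```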
@@ -202,7 +218,33 @@ def _model_candidates(self) -> List[Dict[str, Any]]:
                 "TrainingArguments.model_name_or_path": english_models,
                 "TrainingArguments.learning_rate": 5e-6,
             },
-            # Note: prompt tuning candidates not included for now due to lack of inference capability
+            # prompt tuning candidates
+            {
+                "preset": "prompt",
+                "language": "Chinese",
+                "trainer_type": "PromptTrainer",
+                "template.prompt": "{'mask'}{'soft'}“{'text': '" + self.text_column + "'}”",
+                "EarlyStoppingCallback.early_stopping_patience": 5,
+                "PromptTuningArguments.per_device_train_batch_size": train_batch_size,
+                "PromptTuningArguments.per_device_eval_batch_size": train_batch_size * 2,
+                "PromptTuningArguments.num_train_epochs": 100,
+                "PromptTuningArguments.model_name_or_path": chinese_prompt_models,
+                "PromptTuningArguments.learning_rate": 1e-5,
+                "PromptTuningArguments.ppt_learning_rate": 1e-4,
+            },
+            {
+                "preset": "prompt",
+                "language": "English",
+                "trainer_type": "PromptTrainer",
+                "template.prompt": "{'mask'}{'soft'}“{'text': '" + self.text_column + "'}”",
+                "EarlyStoppingCallback.early_stopping_patience": 5,
+                "PromptTuningArguments.per_device_train_batch_size": train_batch_size,
+                "PromptTuningArguments.per_device_eval_batch_size": train_batch_size * 2,
+                "PromptTuningArguments.num_train_epochs": 100,
+                "PromptTuningArguments.model_name_or_path": english_prompt_models,
+                "PromptTuningArguments.learning_rate": 1e-5,
+                "PromptTuningArguments.ppt_learning_rate": 1e-4,
+            },
         ]

     def _data_checks_and_inference(self):
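The `template.prompt` entries above are built by plain string concatenation around `self.text_column`. As an illustration only, with a hypothetical column named `sentence`, the resulting prompt-template string looks like this:

```python
# Illustration only: the "template.prompt" value for a hypothetical text column "sentence".
text_column = "sentence"
prompt = "{'mask'}{'soft'}“{'text': '" + text_column + "'}”"
print(prompt)  # -> {'mask'}{'soft'}“{'text': 'sentence'}”
```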
@@ -247,6 +289,8 @@ def _data_checks_and_inference(self):
                 raise ValueError(
                     f"Label {label} is not found in the user-provided id2label argument: {self.id2label}"
                 )
+        if not os.path.exists(self.output_dir):
+            os.makedirs(self.output_dir)
         id2label_path = os.path.join(self.output_dir, "id2label.json")
         with open(id2label_path, "w", encoding="utf-8") as f:
             json.dump(self.id2label, f, ensure_ascii=False)
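A side note on the directory-creation lines added above: the check-then-create pattern works, but `os.makedirs` also accepts `exist_ok=True`, which collapses it into a single race-free call. A minimal sketch with a hypothetical path:

```python
import os

# Equivalent to `if not os.path.exists(p): os.makedirs(p)`, without the separate check.
os.makedirs("./autonlp_output", exist_ok=True)  # path is hypothetical
```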