@@ -126,6 +126,15 @@ def test_multiclass(self, custom_model_candidate, hp_overrides):
126
126
result_hp_key = f"config/candidates/{ hp_key } "
127
127
self .assertEqual (results_df [result_hp_key ][0 ], hp_value )
128
128
129
+ # test save
130
+ self .assertTrue (os .path .exists (os .path .join (auto_trainer .output_dir , "id2label.json" )))
131
+ save_path = os .path .join (auto_trainer ._get_model_result ().log_dir , auto_trainer .save_path )
132
+ self .assertTrue (os .path .exists (os .path .join (save_path , "model_state.pdparams" )))
133
+ self .assertTrue (os .path .exists (os .path .join (save_path , "tokenizer_config.json" )))
134
+ if custom_model_candidate ["trainer_type" ] == "PromptTrainer" :
135
+ self .assertTrue (os .path .exists (os .path .join (save_path , "template_config.json" )))
136
+ self .assertTrue (os .path .exists (os .path .join (save_path , "verbalizer_config.json" )))
137
+
129
138
# test export
130
139
temp_export_path = os .path .join (temp_dir_path , "test_export" )
131
140
auto_trainer .export (export_path = temp_export_path )
@@ -159,6 +168,9 @@ def test_multiclass(self, custom_model_candidate, hp_overrides):
159
168
for prediction in test_result ["predictions" ]:
160
169
self .assertIn (prediction ["label" ], auto_trainer .label2id )
161
170
171
+ # test training_path
172
+ self .assertFalse (os .path .exists (os .path .join (auto_trainer .training_path )))
173
+
162
174
@parameterized .expand (
163
175
[
164
176
(finetune_model_candidate , {"TrainingArguments.max_steps" : 2 }),
@@ -205,6 +217,15 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
205
217
result_hp_key = f"config/candidates/{ hp_key } "
206
218
self .assertEqual (results_df [result_hp_key ][0 ], hp_value )
207
219
220
+ # test save
221
+ self .assertTrue (os .path .exists (os .path .join (auto_trainer .output_dir , "id2label.json" )))
222
+ save_path = os .path .join (auto_trainer ._get_model_result ().log_dir , auto_trainer .save_path )
223
+ self .assertTrue (os .path .exists (os .path .join (save_path , "model_state.pdparams" )))
224
+ self .assertTrue (os .path .exists (os .path .join (save_path , "tokenizer_config.json" )))
225
+ if custom_model_candidate ["trainer_type" ] == "PromptTrainer" :
226
+ self .assertTrue (os .path .exists (os .path .join (save_path , "template_config.json" )))
227
+ self .assertTrue (os .path .exists (os .path .join (save_path , "verbalizer_config.json" )))
228
+
208
229
# test export
209
230
temp_export_path = os .path .join (temp_dir_path , "test_export" )
210
231
auto_trainer .export (export_path = temp_export_path )
@@ -240,6 +261,9 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
240
261
self .assertIn (prediction ["label" ], auto_trainer .label2id )
241
262
self .assertGreater (prediction ["score" ], taskflow .task_instance .multilabel_threshold )
242
263
264
+ # test training_path
265
+ self .assertFalse (os .path .exists (os .path .join (auto_trainer .training_path )))
266
+
243
267
def test_untrained_auto_trainer (self ):
244
268
with TemporaryDirectory () as temp_dir :
245
269
train_ds = copy .deepcopy (self .multi_class_train_ds )
0 commit comments