Skip to content

Commit 33be7e5

Browse files
authored
add test (#4865)
1 parent 2e487ed commit 33be7e5

File tree

2 files changed

+31 -0 lines changed

paddlenlp/experimental/autonlp/text_classification.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -364,6 +364,9 @@ def evaluate(self, trial_id=None, eval_dataset=None):
364 364
)
365 365

366 366
eval_metrics = trainer.evaluate()
367+
if os.path.exists(self.training_path):
368+
logger.info(f"Removing {self.training_path} to conserve disk space")
369+
shutil.rmtree(self.training_path)
367 370
return eval_metrics
368 371

369 372
def _compute_metrics(self, eval_preds: EvalPrediction) -> Dict[str, float]:
@@ -506,4 +509,8 @@ def export(self, export_path, trial_id=None):
506 509
# save id2label
507 510
shutil.copyfile(os.path.join(self.output_dir, "id2label.json"), os.path.join(export_path, "id2label.json"))
508 511

512+
if os.path.exists(self.training_path):
513+
logger.info("Removing training checkpoints to conserve disk space")
514+
shutil.rmtree(self.training_path)
515+
509 516
logger.info(f"Exported {trial_id} to {export_path}")

tests/experimental/autonlp/test_text_classification.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -126,6 +126,15 @@ def test_multiclass(self, custom_model_candidate, hp_overrides):
126 126
result_hp_key = f"config/candidates/{hp_key}"
127 127
self.assertEqual(results_df[result_hp_key][0], hp_value)
128 128

129+
# test save
130+
self.assertTrue(os.path.exists(os.path.join(auto_trainer.output_dir, "id2label.json")))
131+
save_path = os.path.join(auto_trainer._get_model_result().log_dir, auto_trainer.save_path)
132+
self.assertTrue(os.path.exists(os.path.join(save_path, "model_state.pdparams")))
133+
self.assertTrue(os.path.exists(os.path.join(save_path, "tokenizer_config.json")))
134+
if custom_model_candidate["trainer_type"] == "PromptTrainer":
135+
self.assertTrue(os.path.exists(os.path.join(save_path, "template_config.json")))
136+
self.assertTrue(os.path.exists(os.path.join(save_path, "verbalizer_config.json")))
137+
129 138
# test export
130 139
temp_export_path = os.path.join(temp_dir_path, "test_export")
131 140
auto_trainer.export(export_path=temp_export_path)
@@ -159,6 +168,9 @@ def test_multiclass(self, custom_model_candidate, hp_overrides):
159 168
for prediction in test_result["predictions"]:
160 169
self.assertIn(prediction["label"], auto_trainer.label2id)
161 170

171+
# test training_path
172+
self.assertFalse(os.path.exists(os.path.join(auto_trainer.training_path)))
173+
162 174
@parameterized.expand(
163 175
[
164 176
(finetune_model_candidate, {"TrainingArguments.max_steps": 2}),
@@ -205,6 +217,15 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
205 217
result_hp_key = f"config/candidates/{hp_key}"
206 218
self.assertEqual(results_df[result_hp_key][0], hp_value)
207 219

220+
# test save
221+
self.assertTrue(os.path.exists(os.path.join(auto_trainer.output_dir, "id2label.json")))
222+
save_path = os.path.join(auto_trainer._get_model_result().log_dir, auto_trainer.save_path)
223+
self.assertTrue(os.path.exists(os.path.join(save_path, "model_state.pdparams")))
224+
self.assertTrue(os.path.exists(os.path.join(save_path, "tokenizer_config.json")))
225+
if custom_model_candidate["trainer_type"] == "PromptTrainer":
226+
self.assertTrue(os.path.exists(os.path.join(save_path, "template_config.json")))
227+
self.assertTrue(os.path.exists(os.path.join(save_path, "verbalizer_config.json")))
228+
208 229
# test export
209 230
temp_export_path = os.path.join(temp_dir_path, "test_export")
210 231
auto_trainer.export(export_path=temp_export_path)
@@ -240,6 +261,9 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
240 261
self.assertIn(prediction["label"], auto_trainer.label2id)
241 262
self.assertGreater(prediction["score"], taskflow.task_instance.multilabel_threshold)
242 263

264+
# test training_path
265+
self.assertFalse(os.path.exists(os.path.join(auto_trainer.training_path)))
266+
243 267
def test_untrained_auto_trainer(self):
244 268
with TemporaryDirectory() as temp_dir:
245 269
train_ds = copy.deepcopy(self.multi_class_train_ds)

0 commit comments

Comments
 (0)