Skip to content

Commit c375fba

Browse files
authored
Merge pull request #3363 from flairNLP/fix_transformer_smaller_training_vocab_with_best_model
fix trainer final evaluation with transformer_smaller_training_vocab (tstv) & best-model
2 parents ceea719 + 010c0fb commit c375fba

File tree

1 file changed

+28
-28
lines changed

1 file changed

+28
-28
lines changed

flair/trainers/trainer.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -767,39 +767,39 @@ def train_custom(
767767
# TensorboardLogger -> closes writer
768768
self.dispatch("_training_finally")
769769

770-
# test best model if test data is present
771-
if self.corpus.test and not train_with_test:
772-
log_line(log)
770+
# test best model if test data is present
771+
if self.corpus.test and not train_with_test:
772+
log_line(log)
773773

774-
self.model.eval()
774+
self.model.eval()
775775

776-
if (base_path / "best-model.pt").exists():
777-
log.info("Loading model from best epoch ...")
778-
self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
779-
else:
780-
log.info("Testing using last state of model ...")
781-
782-
test_results = self.model.evaluate(
783-
self.corpus.test,
784-
gold_label_type=self.model.label_type,
785-
mini_batch_size=eval_batch_size,
786-
out_path=base_path / "test.tsv",
787-
embedding_storage_mode="none",
788-
main_evaluation_metric=main_evaluation_metric,
789-
gold_label_dictionary=gold_label_dictionary_for_eval,
790-
exclude_labels=exclude_labels,
791-
return_loss=False,
792-
)
776+
if (base_path / "best-model.pt").exists():
777+
log.info("Loading model from best epoch ...")
778+
self.model.load_state_dict(self.model.load(base_path / "best-model.pt").state_dict())
779+
else:
780+
log.info("Testing using last state of model ...")
781+
782+
test_results = self.model.evaluate(
783+
self.corpus.test,
784+
gold_label_type=self.model.label_type,
785+
mini_batch_size=eval_batch_size,
786+
out_path=base_path / "test.tsv",
787+
embedding_storage_mode="none",
788+
main_evaluation_metric=main_evaluation_metric,
789+
gold_label_dictionary=gold_label_dictionary_for_eval,
790+
exclude_labels=exclude_labels,
791+
return_loss=False,
792+
)
793793

794-
log.info(test_results.detailed_results)
795-
log_line(log)
794+
log.info(test_results.detailed_results)
795+
log_line(log)
796796

797-
# get and return the final test score of best model
798-
self.return_values["test_score"] = test_results.main_score
797+
# get and return the final test score of best model
798+
self.return_values["test_score"] = test_results.main_score
799799

800-
else:
801-
self.return_values["test_score"] = 0
802-
log.info("Test data not provided setting final score to 0")
800+
else:
801+
self.return_values["test_score"] = 0
802+
log.info("Test data not provided setting final score to 0")
803803

804804
# MetricHistoryPlugin -> stores the loss history in return_values
805805
self.dispatch("after_training")

0 commit comments

Comments (0)