@@ -767,39 +767,39 @@ def train_custom(
767767 # TensorboardLogger -> closes writer
768768 self .dispatch ("_training_finally" )
769769
770- # test best model if test data is present
771- if self .corpus .test and not train_with_test :
772- log_line (log )
770+ # test best model if test data is present
771+ if self .corpus .test and not train_with_test :
772+ log_line (log )
773773
774- self .model .eval ()
774+ self .model .eval ()
775775
776- if (base_path / "best-model.pt" ).exists ():
777- log .info ("Loading model from best epoch ..." )
778- self .model .load_state_dict (self .model .load (base_path / "best-model.pt" ).state_dict ())
779- else :
780- log .info ("Testing using last state of model ..." )
781-
782- test_results = self .model .evaluate (
783- self .corpus .test ,
784- gold_label_type = self .model .label_type ,
785- mini_batch_size = eval_batch_size ,
786- out_path = base_path / "test.tsv" ,
787- embedding_storage_mode = "none" ,
788- main_evaluation_metric = main_evaluation_metric ,
789- gold_label_dictionary = gold_label_dictionary_for_eval ,
790- exclude_labels = exclude_labels ,
791- return_loss = False ,
792- )
776+ if (base_path / "best-model.pt" ).exists ():
777+ log .info ("Loading model from best epoch ..." )
778+ self .model .load_state_dict (self .model .load (base_path / "best-model.pt" ).state_dict ())
779+ else :
780+ log .info ("Testing using last state of model ..." )
781+
782+ test_results = self .model .evaluate (
783+ self .corpus .test ,
784+ gold_label_type = self .model .label_type ,
785+ mini_batch_size = eval_batch_size ,
786+ out_path = base_path / "test.tsv" ,
787+ embedding_storage_mode = "none" ,
788+ main_evaluation_metric = main_evaluation_metric ,
789+ gold_label_dictionary = gold_label_dictionary_for_eval ,
790+ exclude_labels = exclude_labels ,
791+ return_loss = False ,
792+ )
793793
794- log .info (test_results .detailed_results )
795- log_line (log )
794+ log .info (test_results .detailed_results )
795+ log_line (log )
796796
797- # get and return the final test score of best model
798- self .return_values ["test_score" ] = test_results .main_score
797+ # get and return the final test score of best model
798+ self .return_values ["test_score" ] = test_results .main_score
799799
800- else :
801- self .return_values ["test_score" ] = 0
802- log .info ("Test data not provided setting final score to 0" )
800+ else :
801+ self .return_values ["test_score" ] = 0
802+ log .info ("Test data not provided setting final score to 0" )
803803
804804 # MetricHistoryPlugin -> stores the loss history in return_values
805805 self .dispatch ("after_training" )
0 commit comments