11import os
22import warnings
33import torch
4+ from torch import nn
45from pathlib import Path
56
67import matplotlib .pyplot as plt
@@ -174,10 +175,12 @@ def __init__(
174175 """Data dictionary containing file paths"""
175176 self .stop_requested = False
176177 """Whether the worker should stop or not"""
178+ self .start_time = ""
177179
178180 self .loss_dict = {
179181 "Dice loss" : DiceLoss (sigmoid = True ),
180182 "Focal loss" : FocalLoss (),
183+ # "BCELoss":nn.BCELoss(),
181184 "Dice-Focal loss" : DiceFocalLoss (sigmoid = True , lambda_dice = 0.5 ),
182185 "Generalized Dice loss" : GeneralizedDiceLoss (sigmoid = True ),
183186 "DiceCELoss" : DiceCELoss (sigmoid = True , lambda_ce = 0.5 ),
@@ -226,6 +229,7 @@ def __init__(
226229 self .lbl_val_interv_choice = QLabel ("Validation interval : " , self )
227230
228231 self .learning_rate_dict = {
232+ "1e-2" : 1e-2 ,
229233 "1e-3" : 1e-3 ,
230234 "1e-4" : 1e-4 ,
231235 "1e-5" : 1e-5 ,
@@ -238,6 +242,7 @@ def __init__(
238242 ) = ui .make_combobox (
239243 self .learning_rate_dict .keys (), label = "Learning rate"
240244 )
245+ self .learning_rate_choice .setCurrentIndex (1 )
241246
242247 self .augment_choice = ui .make_checkbox ("Augment data" )
243248
@@ -676,7 +681,7 @@ def start(self):
676681 Returns: Returns empty immediately if the file paths are not set correctly.
677682
678683 """
679-
684+ self . start_time = utils . get_time_filepath ()
680685 if self .stop_requested :
681686 self .log .print_and_log ("Worker is already stopping !" )
682687 return
@@ -720,8 +725,11 @@ def start(self):
720725
721726 self .results_path = (
722727 self .results_path
723- + f"/{ model_dict ['name' ]} _results_{ utils . get_date_time () } "
728+ + f"/{ model_dict ['name' ]} _results_{ self . start_time } "
724729 )
730+ os .makedirs (
731+ self .results_path , exist_ok = False
732+ ) # avoid overwrite where possible
725733
726734 if self .use_transfer_choice .isChecked ():
727735 if self .custom_weights_choice .isChecked ():
@@ -731,10 +739,6 @@ def start(self):
731739 else :
732740 weights_path = None
733741
734- os .makedirs (
735- self .results_path , exist_ok = False
736- ) # avoid overwrite where possible
737-
738742 self .log .print_and_log (
739743 f"Notice : Saving results to : { self .results_path } "
740744 )
@@ -782,13 +786,11 @@ def start(self):
782786
783787 def on_start (self ):
784788 """Catches started signal from worker"""
785- if self .plot_dock is not None :
786- self ._viewer .window .remove_dock_widget (self .plot_dock )
787- self .plot_dock = None
788789
790+ self .remove_docked_widgets ()
789791 self .display_status_report ()
790792
791- self .log .print_and_log (f"Worker started at { utils . get_time () } " )
793+ self .log .print_and_log (f"Worker started at { self . start_time } " )
792794 self .log .print_and_log ("\n Worker is running..." )
793795
794796 def on_finish (self ):
@@ -801,7 +803,7 @@ def on_finish(self):
801803 self .canvas .figure .savefig (
802804 (
803805 self .results_path
804- + f"/final_metric_plots_{ utils .get_date_time ()} .png"
806+ + f"/final_metric_plots_{ utils .get_time_filepath ()} .png"
805807 ),
806808 format = "png" ,
807809 )
@@ -817,11 +819,14 @@ def on_finish(self):
817819
818820 self .worker = None
819821 self .empty_cuda_cache ()
822+
823+ self .results_path = ""
820824 # self.clean_cache() # trying to fix memory leak
821825
822826 def on_error (self ):
823827 """Catches errored signal from worker"""
824828 self .log .print_and_log (f"WORKER ERRORED at { utils .get_time ()} " )
829+ self .worker = None
825830 self .empty_cuda_cache ()
826831 # self.clean_cache()
827832
@@ -840,7 +845,7 @@ def on_yield(data, widget):
840845 data ["weights" ],
841846 os .path .join (
842847 widget .results_path ,
843- f"latest_weights_aborted_training_{ utils .get_date_time ()} .pth" ,
848+ f"latest_weights_aborted_training_{ utils .get_time ()} .pth" ,
844849 ),
845850 )
846851 widget .stop_requested = False
0 commit comments