3232from napari_cellseg3d .model_workers import TrainingWorker
3333
3434NUMBER_TABS = 3
35+ DEFAULT_PATCH_SIZE = 60
3536
3637
3738class Trainer (ModelFramework ):
@@ -116,6 +117,7 @@ def __init__(
116117 self .data_path = ""
117118 self .label_path = ""
118119 self .results_path = ""
120+ self .results_path_folder = ""
119121 ######################
120122 ######################
121123 ######################
@@ -252,7 +254,9 @@ def __init__(
252254 ]
253255 """Close buttons list for each tab"""
254256
255- self .patch_size_widgets = ui .make_n_spinboxes (3 , 10 , 1023 , 120 )
257+ self .patch_size_widgets = ui .make_n_spinboxes (
258+ 3 , 10 , 1024 , DEFAULT_PATCH_SIZE
259+ )
256260
257261 self .patch_size_lbl = [
258262 QLabel (f"Size of patch in { axis } :" ) for axis in "xyz"
@@ -651,23 +655,24 @@ def build(self):
651655 ui .make_scrollable (
652656 contained_layout = data_tab_layout ,
653657 containing_widget = data_tab ,
654- min_wh = [100 , 150 ],
658+ min_wh = [200 , 300 ],
655659 ) # , max_wh=[200,1000])
656660 self .addTab (data_tab , "Data" )
657661
658662 ui .make_scrollable (
659663 contained_layout = augment_tab_l ,
660664 containing_widget = augment_tab_w ,
661- min_wh = [100 , 200 ],
665+ min_wh = [200 , 300 ],
662666 )
663667 self .addTab (augment_tab_w , "Augmentation" )
664668
665669 ui .make_scrollable (
666670 contained_layout = train_tab_layout ,
667671 containing_widget = train_tab ,
668- min_wh = [100 , 200 ],
672+ min_wh = [200 , 300 ],
669673 )
670674 self .addTab (train_tab , "Training" )
675+ self .setMinimumSize (220 , 200 )
671676
672677 def show_dialog_lab (self ):
673678 """Shows the dialog to load label files in a path, loads them (see :doc:model_framework) and changes the widget
@@ -760,12 +765,12 @@ def start(self):
760765 "class" : self .get_model (self .model_choice .currentText ()),
761766 "name" : self .model_choice .currentText (),
762767 }
763- self .results_path = (
768+ self .results_path_folder = (
764769 self .results_path
765770 + f"/{ model_dict ['name' ]} _results_{ utils .get_date_time ()} "
766771 )
767772 os .makedirs (
768- self .results_path , exist_ok = False
773+ self .results_path_folder , exist_ok = False
769774 ) # avoid overwrite where possible
770775
771776 if self .use_transfer_choice .isChecked ():
@@ -777,7 +782,7 @@ def start(self):
777782 weights_path = None
778783
779784 self .log .print_and_log (
780- f"Notice : Saving results to : { self .results_path } "
785+ f"Notice : Saving results to : { self .results_path_folder } "
781786 )
782787
783788 self .worker = TrainingWorker (
@@ -790,7 +795,7 @@ def start(self):
790795 learning_rate = self .learning_rate ,
791796 val_interval = self .val_interval ,
792797 batch_size = self .batch_size ,
793- results_path = self .results_path ,
798+ results_path = self .results_path_folder ,
794799 sampling = self .patch_choice .isChecked (),
795800 num_samples = self .num_samples ,
796801 sample_size = self .patch_size ,
@@ -812,11 +817,13 @@ def start(self):
812817 self .worker .errored .connect (self .on_error )
813818
814819 if self .worker .is_running :
820+ self .log .print_and_log ("*" * 20 )
815821 self .log .print_and_log (
816- f"Stop requested at { utils .get_time ()} . \n Waiting for next validation step..."
822+ f"Stop requested at { utils .get_time ()} . \n Waiting for next yielding step..."
817823 )
818824 self .stop_requested = True
819- self .btn_start .setText ("Stopping... Please wait for next saving" )
825+ self .btn_start .setText ("Stopping... Please wait" )
826+ self .log .print_and_log ("*" * 20 )
820827 self .worker .quit ()
821828 else :
822829 self .worker .start ()
@@ -833,32 +840,35 @@ def on_start(self):
833840
834841 def on_finish (self ):
835842 """Catches finished signal from worker"""
843+ self .log .print_and_log ("*" * 20 )
836844 self .log .print_and_log (f"\n Worker finished at { utils .get_time ()} " )
837845
838- self .log .print_and_log (f"Saving last loss plot at { self .results_path } " )
846+ self .log .print_and_log (f"Saving in { self .results_path_folder } " )
847+ self .log .print_and_log (f"Saving last loss plot" )
839848
840849 if self .canvas is not None :
841850 self .canvas .figure .savefig (
842851 (
843- self .results_path
852+ self .results_path_folder
844853 + f"/final_metric_plots_{ utils .get_time_filepath ()} .png"
845854 ),
846855 format = "png" ,
847856 )
848857
849- self .log .print_and_log ("Auto-saving log" )
850- self .save_log ()
858+ self .log .print_and_log ("Saving log" )
859+ self .save_log (spec_path = self . results_path_folder )
851860
852861 self .log .print_and_log ("Done" )
853862 self .log .print_and_log ("*" * 10 )
854863
855864 self .btn_start .setText ("Start" )
856865 [btn .setVisible (True ) for btn in self .close_buttons ]
857866
867+ del self .worker
858868 self .worker = None
869+ self .results_path_folder = ""
859870 self .empty_cuda_cache ()
860871
861- self .results_path = ""
862872 # self.clean_cache() # trying to fix memory leak
863873
864874 def on_error (self ):
@@ -880,13 +890,17 @@ def on_yield(data, widget):
880890 widget .update_loss_plot (data ["losses" ], data ["val_metrics" ])
881891
882892 if widget .stop_requested :
893+ widget .log .print_and_log (
894+ "Saving weights from aborted training in results folder"
895+ )
883896 torch .save (
884897 data ["weights" ],
885898 os .path .join (
886- widget .results_path ,
899+ widget .results_path_folder ,
887900 f"latest_weights_aborted_training_{ utils .get_time_filepath ()} .pth" ,
888901 ),
889902 )
903+ widget .log .print_and_log ("Saving complete" )
890904 widget .stop_requested = False
891905
892906 # def clean_cache(self):
@@ -942,7 +956,7 @@ def plot_loss(self, loss, dice_metric):
942956 )
943957 self .canvas .draw_idle ()
944958
945- plot_path = self .results_path + "/Loss_plots"
959+ plot_path = self .results_path_folder + "/Loss_plots"
946960 os .makedirs (plot_path , exist_ok = True )
947961 if self .canvas is not None :
948962 self .canvas .figure .savefig (
0 commit comments