@@ -273,6 +273,15 @@ def __init__(
273273 "Transfer weights" , self .toggle_transfer_param
274274 )
275275
276+ self .use_deterministic_choice = ui .make_checkbox (
277+ "Deterministic training" , func = self .toggle_deterministic_param
278+ )
279+ self .box_seed = ui .make_n_spinboxes (max = 10000000 , default = 23498 )
280+ self .lbl_seed = ui .make_label ("Seed" , self )
281+ self .container_seed = ui .combine_blocks (
282+ self .box_seed , self .lbl_seed , horizontal = False
283+ )
284+
276285 self .progress = QProgressBar ()
277286 self .progress .setVisible (False )
278287 """Dock widget containing the progress bar"""
@@ -304,6 +313,12 @@ def toggle_transfer_param(self):
304313 else :
305314 self .custom_weights_choice .setVisible (False )
306315
316+ def toggle_deterministic_param (self ):
317+ if self .use_deterministic_choice .isChecked ():
318+ self .container_seed .setVisible (True )
319+ else :
320+ self .container_seed .setVisible (False )
321+
307322 def check_ready (self ):
308323 """
309324 Checks that the paths to the images and labels are correctly set
@@ -595,6 +610,22 @@ def build(self):
595610 train_tab_layout .addWidget (train_param_group_w )
596611 # end of training params group
597612 ##################
613+ ui .add_blank (self , train_tab_layout )
614+ ##################
615+ # deterministic choice group
616+ seed_w , seed_l = ui .make_group (
617+ "Deterministic training" , R = 1 , B = 5 , T = 11
618+ )
619+
620+ seed_l .addWidget (self .use_deterministic_choice , alignment = ui .LEFT_AL )
621+ seed_l .addWidget (self .container_seed , alignment = ui .LEFT_AL )
622+ self .container_seed .setVisible (False )
623+
624+ seed_w .setLayout (seed_l )
625+ train_tab_layout .addWidget (seed_w )
626+
627+ # end of deterministic choice group
628+ ##################
598629 # buttons
599630
600631 ui .add_blank (self , train_tab_layout )
@@ -710,7 +741,15 @@ def start(self):
710741 self .data = self .create_train_dataset_dict ()
711742 self .max_epochs = self .epoch_choice .value ()
712743
713- self .learning_rate = self .learning_rate_dict [self .learning_rate_choice .currentText ()]
744+ self .learning_rate = self .learning_rate_dict [
745+ self .learning_rate_choice .currentText ()
746+ ]
747+
748+ seed_dict = {
749+ "use deterministic" : self .use_deterministic_choice .isChecked (),
750+ "seed" : self .box_seed .value (),
751+ }
752+
714753
715754 self .patch_size = []
716755 [
@@ -722,10 +761,9 @@ def start(self):
722761 "class" : self .get_model (self .model_choice .currentText ()),
723762 "name" : self .model_choice .currentText (),
724763 }
725-
726764 self .results_path = (
727- self .results_path
728- + f"/{ model_dict ['name' ]} _results_{ self . start_time } "
765+ self .results_path
766+ + f"/{ model_dict ['name' ]} _results_{ utils . get_date_time () } "
729767 )
730768 os .makedirs (
731769 self .results_path , exist_ok = False
@@ -758,6 +796,7 @@ def start(self):
758796 num_samples = self .num_samples ,
759797 sample_size = self .patch_size ,
760798 do_augmentation = self .augment_choice .isChecked (),
799+ deterministic = seed_dict ,
761800 )
762801
763802 [btn .setVisible (False ) for btn in self .close_buttons ]
@@ -826,7 +865,7 @@ def on_finish(self):
826865 def on_error (self ):
827866 """Catches errored signal from worker"""
828867 self .log .print_and_log (f"WORKER ERRORED at { utils .get_time ()} " )
829- self .worker = None
868+ self .worker = None
830869 self .empty_cuda_cache ()
831870 # self.clean_cache()
832871
0 commit comments