Skip to content

Commit dc778ee

Browse files
committed
Added deterministic training option
1 parent 1ab8575 commit dc778ee

File tree

2 files changed

+66
-6
lines changed

2 files changed

+66
-6
lines changed

src/napari_cellseg3d/model_workers.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from monai.transforms import RandSpatialCropSamplesd
3030
from monai.transforms import SpatialPadd
3131
from monai.transforms import Zoom
32+
from monai.utils import set_determinism
33+
34+
# threads
3235
from napari.qt.threading import GeneratorWorker
3336
from napari.qt.threading import WorkerBaseSignals
3437

@@ -388,6 +391,7 @@ def __init__(
388391
num_samples,
389392
sample_size,
390393
do_augmentation,
394+
deterministic,
391395
):
392396
"""Initializes a worker for training with the arguments needed by the :py:func:`~train` function.
393397
@@ -420,6 +424,8 @@ def __init__(
420424
421425
* do_augmentation : whether to perform data augmentation or not
422426
427+
* deterministic : dict with two keys: "use deterministic" (bool, whether to enable deterministic training) and "seed" (the seed used for the RNG)
428+
423429
Note: See :py:func:`~train`
424430
"""
425431

@@ -443,6 +449,7 @@ def __init__(
443449
self.sample_size = sample_size
444450

445451
self.do_augment = do_augmentation
452+
self.seed_dict = deterministic
446453

447454
def log(self, text):
448455
"""Sends a signal that ``text`` should be logged
@@ -454,7 +461,12 @@ def log(self, text):
454461

455462
def log_parameters(self):
456463

457-
self.log("\nParameters summary :")
464+
self.log("\nParameters summary :\n")
465+
466+
if self.seed_dict["use deterministic"]:
467+
self.log(f"Deterministic training is enabled")
468+
self.log(f"Seed is {self.seed_dict['seed']}")
469+
458470
self.log(f"Training for {self.max_epochs} epochs")
459471
self.log(f"Loss function is : {str(self.loss_function)}")
460472
self.log(f"Validation is performed every {self.val_interval} epochs")
@@ -474,6 +486,8 @@ def log_parameters(self):
474486
if self.weights_path is not None:
475487
self.log(f"Using weights from : {self.weights_path}")
476488

489+
self.log("\n")
490+
477491
def train(self):
478492
"""Trains the Pytorch model for the given number of epochs, with the selected model and data,
479493
using the chosen batch size, validation interval, loss function, and number of samples.
@@ -508,13 +522,20 @@ def train(self):
508522
* sample_size : the size of the patches to extract when sampling
509523
510524
* do_augmentation : whether to perform data augmentation or not
525+
526+
* deterministic : dict with two keys: "use deterministic" (bool, whether to enable deterministic training) and "seed" (the seed used for the RNG)
511527
"""
512528

513529
#########################
514530
# error_log = open(results_path +"/error_log.log" % multiprocessing.current_process().name, 'x')
515531
# faulthandler.enable(file=error_log, all_threads=True)
516532
#########################
517533

534+
if self.seed_dict["use deterministic"]:
535+
set_determinism(
536+
seed=self.seed_dict["seed"]
537+
) # use_deterministic_algorithms=True is deliberately not passed here: it causes a CUDA error
538+
518539
sys = platform.system()
519540
print(sys)
520541
if sys == "Darwin": # required for macOS ?

src/napari_cellseg3d/plugin_model_training.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,15 @@ def __init__(
273273
"Transfer weights", self.toggle_transfer_param
274274
)
275275

276+
self.use_deterministic_choice = ui.make_checkbox(
277+
"Deterministic training", func=self.toggle_deterministic_param
278+
)
279+
self.box_seed = ui.make_n_spinboxes(max=10000000, default=23498)
280+
self.lbl_seed = ui.make_label("Seed", self)
281+
self.container_seed = ui.combine_blocks(
282+
self.box_seed, self.lbl_seed, horizontal=False
283+
)
284+
276285
self.progress = QProgressBar()
277286
self.progress.setVisible(False)
278287
"""Dock widget containing the progress bar"""
@@ -304,6 +313,12 @@ def toggle_transfer_param(self):
304313
else:
305314
self.custom_weights_choice.setVisible(False)
306315

316+
def toggle_deterministic_param(self):
317+
if self.use_deterministic_choice.isChecked():
318+
self.container_seed.setVisible(True)
319+
else:
320+
self.container_seed.setVisible(False)
321+
307322
def check_ready(self):
308323
"""
309324
Checks that the paths to the images and labels are correctly set
@@ -595,6 +610,22 @@ def build(self):
595610
train_tab_layout.addWidget(train_param_group_w)
596611
# end of training params group
597612
##################
613+
ui.add_blank(self, train_tab_layout)
614+
##################
615+
# deterministic choice group
616+
seed_w, seed_l = ui.make_group(
617+
"Deterministic training", R=1, B=5, T=11
618+
)
619+
620+
seed_l.addWidget(self.use_deterministic_choice, alignment=ui.LEFT_AL)
621+
seed_l.addWidget(self.container_seed, alignment=ui.LEFT_AL)
622+
self.container_seed.setVisible(False)
623+
624+
seed_w.setLayout(seed_l)
625+
train_tab_layout.addWidget(seed_w)
626+
627+
# end of deterministic choice group
628+
##################
598629
# buttons
599630

600631
ui.add_blank(self, train_tab_layout)
@@ -710,7 +741,15 @@ def start(self):
710741
self.data = self.create_train_dataset_dict()
711742
self.max_epochs = self.epoch_choice.value()
712743

713-
self.learning_rate = self.learning_rate_dict[self.learning_rate_choice.currentText()]
744+
self.learning_rate = self.learning_rate_dict[
745+
self.learning_rate_choice.currentText()
746+
]
747+
748+
seed_dict = {
749+
"use deterministic": self.use_deterministic_choice.isChecked(),
750+
"seed": self.box_seed.value(),
751+
}
752+
714753

715754
self.patch_size = []
716755
[
@@ -722,10 +761,9 @@ def start(self):
722761
"class": self.get_model(self.model_choice.currentText()),
723762
"name": self.model_choice.currentText(),
724763
}
725-
726764
self.results_path = (
727-
self.results_path
728-
+ f"/{model_dict['name']}_results_{self.start_time}"
765+
self.results_path
766+
+ f"/{model_dict['name']}_results_{utils.get_date_time()}"
729767
)
730768
os.makedirs(
731769
self.results_path, exist_ok=False
@@ -758,6 +796,7 @@ def start(self):
758796
num_samples=self.num_samples,
759797
sample_size=self.patch_size,
760798
do_augmentation=self.augment_choice.isChecked(),
799+
deterministic=seed_dict,
761800
)
762801

763802
[btn.setVisible(False) for btn in self.close_buttons]
@@ -826,7 +865,7 @@ def on_finish(self):
826865
def on_error(self):
827866
"""Catches errored signal from worker"""
828867
self.log.print_and_log(f"WORKER ERRORED at {utils.get_time()}")
829-
self.worker=None
868+
self.worker = None
830869
self.empty_cuda_cache()
831870
# self.clean_cache()
832871

0 commit comments

Comments
 (0)