77from monai .data import CacheDataset
88from monai .data import DataLoader
99from monai .data import Dataset
10+ from monai .data import PatchDataset
1011from monai .data import decollate_batch
1112from monai .data import pad_list_data_collate
12- from monai .data import PatchDataset
1313from monai .inferers import sliding_window_inference
1414from monai .metrics import DiceMetric
1515from monai .transforms import AsDiscrete
2929from monai .transforms import Zoom
3030from napari .qt .threading import GeneratorWorker
3131from napari .qt .threading import WorkerBaseSignals
32-
3332# Qt
3433from qtpy .QtCore import Signal
3534from tifffile import imwrite
@@ -131,6 +130,14 @@ def log(self, text):
131130 """
132131 self .log_signal .emit (text )
133132
133+ def log_parameters (self ):
134+
135+ self .log (f"Model is : { self .model_dict ['name' ]} " )
136+ if self .transforms ["thresh" ][0 ]:
137+ self .log (
138+ f"Thresholding is enabled at { self .transforms ['thresh' ][1 ]} "
139+ )
140+
134141 def inference (self ):
135142 """
136143
@@ -183,6 +190,8 @@ def inference(self):
183190 # dropout_prob=0.3,
184191 )
185192
193+ self .log_parameters ()
194+
186195 model .to (self .device )
187196
188197 print ("FILEPATHS PRINT" )
@@ -389,6 +398,24 @@ def log(self, text):
389398 """
390399 self .log_signal .emit (text )
391400
401+ def log_parameters (self ):
402+
403+ self .log ("\n Parameters summary :" )
404+ self .log (f"Training for { self .max_epochs } epochs" )
405+ self .log (f"Loss function is : { str (self .loss_function )} " )
406+ self .log (f"Validation is performed every { self .val_interval } epochs" )
407+ self .log (f"Batch size is { self .batch_size } " )
408+
409+ if self .sampling :
410+ self .log (
411+ f"Extracting { self .num_samples } patches of size { self .sample_size } "
412+ )
413+ else :
414+ self .log ("Using whole images as dataset" )
415+
416+ if self .do_augment :
417+ self .log ("Data augmentation is enabled" )
418+
392419 def train (self ):
393420 """Trains the Pytorch model for the given number of epochs, with the selected model and data,
394421 using the chosen batch size, validation interval, loss function, and number of samples.
@@ -430,7 +457,6 @@ def train(self):
430457 model_class = self .model_dict ["class" ]
431458
432459 if not self .sampling :
433- self .log ("Sampling is disabled" )
434460 data_check = LoadImaged (keys = ["image" ])(self .data_dicts [0 ])
435461 check = data_check ["image" ].shape
436462
@@ -458,11 +484,17 @@ def train(self):
458484 self .data_dicts [int (len (self .data_dicts ) * 0.9 ) :],
459485 )
460486 print ("Training files :" )
461- [print (f"{ train_file } \n " ) for train_file in train_files ]
487+ [
488+ print (f"{ train_file } \n " )
489+ for train_file in train_files
490+ ]
462491 print ("* " * 20 )
463492 print ("* " * 20 )
464493 print ("Validation files :" )
465- [print (f"{ val_file } \n " ) for val_file in val_files ]
494+ [
495+ print (f"{ val_file } \n " )
496+ for val_file in val_files
497+ ]
466498 # TODO : param patch ROI size
467499
468500 if self .sampling :
@@ -489,7 +521,6 @@ def train(self):
489521 )
490522
491523 if self .do_augment :
492- self .log ("Data augmentation is enabled" )
493524 train_transforms = (
494525 Compose ( # TODO : figure out which ones and values ?
495526 [
@@ -520,14 +551,14 @@ def train(self):
520551 )
521552 # self.log("Loading dataset...\n")
522553 if self .sampling :
523-
554+ print ( "train_ds" )
524555 train_ds = PatchDataset (
525556 data = train_files ,
526557 transform = train_transforms ,
527558 patch_func = sample_loader ,
528559 samples_per_image = self .num_samples ,
529560 )
530-
561+ print ( "val_ds" )
531562 val_ds = PatchDataset (
532563 data = val_files ,
533564 transform = val_transforms ,
@@ -567,7 +598,7 @@ def train(self):
567598 val_loader = DataLoader (
568599 val_ds , batch_size = self .batch_size , num_workers = 4
569600 )
570- # self.log ("\nDone")
601+ print ("\n Done" )
571602
572603 optimizer = torch .optim .Adam (model .parameters (), 1e-3 )
573604 dice_metric = DiceMetric (include_background = True , reduction = "mean" )
@@ -583,6 +614,8 @@ def train(self):
583614 else :
584615 self .log ("Using CPU" )
585616
617+ self .log_parameters ()
618+
586619 for epoch in range (self .max_epochs ):
587620 self .log ("-" * 10 )
588621 self .log (f"Epoch { epoch + 1 } /{ self .max_epochs } " )
0 commit comments