 # Qt
 from qtpy.QtCore import Signal
 
-
 from napari_cellseg3d import utils
 from napari_cellseg3d import log_utility
 
@@ -168,20 +167,19 @@ class InferenceWorker(GeneratorWorker):
     Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
 
     def __init__(
-        self,
-        device,
-        model_dict,
-        weights_dict,
-        images_filepaths,
-        results_path,
-        filetype,
-        transforms,
-        instance,
-        use_window,
-        window_infer_size,
-        window_overlap_percentage,
-        keep_on_cpu,
-        stats_csv,
+        self,
+        device,
+        model_dict,
+        weights_dict,
+        images_filepaths,
+        results_path,
+        filetype,
+        transforms,
+        instance,
+        use_window,
+        window_infer_size,
+        keep_on_cpu,
+        stats_csv,
     ):
         """Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function.
 
@@ -206,8 +204,6 @@ def __init__(
 
         * window_infer_size: size of window if use_window is True
 
-        * window_overlap_percentage: overlap of sliding windows if use_window is True
-
         * keep_on_cpu: whether to keep images on the CPU
 
         * stats_csv: compute stats on cells and save them to a csv file
@@ -231,7 +227,7 @@ def __init__(
         self.instance_params = instance
         self.use_window = use_window
         self.window_infer_size = window_infer_size
-        self.window_overlap_percentage = window_overlap_percentage
+        self.window_overlap_percentage = 0.8
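+        # fraction of overlap between sliding windows (used when use_window is True)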
         self.keep_on_cpu = keep_on_cpu
         self.stats_to_csv = stats_csv
         """These attributes are all arguments of :py:func:`~inference`, please see that function for reference"""
@@ -343,36 +339,25 @@ def inference(self):
         # if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems
         # torch.backends.cudnn.benchmark = False
 
-        # TODO : better solution than loading first image always ?
+        self.log("\nChecking dimensions...")
         data_check = LoadImaged(keys=["image"])(images_dict[0])
-        # print(data)
         check = data_check["image"].shape
-        # print(check)
-        # TODO remove
-        # z_aniso = 5 / 1.5
-        # if zoom is not None :
-        #     pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
-        # else:
-        self.log("\nChecking dimensions...")
-        dims = self.model_dict["segres_size"]
+        pad = utils.get_padding_dim(check)
+
+        dims = self.model_dict["model_input_size"]
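+        # pad: spatial size used to pad whole volumes below; dims: input size given to the model constructors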
 
         model = self.model_dict["class"].get_net()
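+        # SegResNet and SwinUNetR need their input size at construction time; other models use get_net() as-is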
         if self.model_dict["name"] == "SegResNet":
-            model = self.model_dict["class"].get_net()(
+            model = self.model_dict["class"].get_net(
                 input_image_size=[
                     dims,
                     dims,
                     dims,
-                ],  # TODO FIX ! find a better way & remove model-specific code
-                out_channels=1,
-                # dropout_prob=0.3,
+                ]
             )
         elif self.model_dict["name"] == "SwinUNetR":
-            model = self.model_dict["class"].get_net()(
+            model = self.model_dict["class"].get_net(
                 img_size=[dims, dims, dims],
-                in_channels=1,
-                out_channels=1,
-                feature_size=48,
                 use_checkpoint=False,
             )
 
@@ -382,17 +367,29 @@ def inference(self):
 
         # print("FILEPATHS PRINT")
         # print(self.images_filepaths)
-
-        load_transforms = Compose(
-            [
-                LoadImaged(keys=["image"]),
-                # AddChanneld(keys=["image"]), #already done
-                EnsureChannelFirstd(keys=["image"]),
-                # Orientationd(keys=["image"], axcodes="PLI"),
-                # anisotropic_transform,
-                EnsureTyped(keys=["image"]),
-            ]
-        )
+        if self.use_window:
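+            # sliding-window inference handles arbitrary volume sizes, so no padding transform is added here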
+            load_transforms = Compose(
+                [
+                    LoadImaged(keys=["image"]),
+                    # AddChanneld(keys=["image"]), #already done
+                    EnsureChannelFirstd(keys=["image"]),
+                    # Orientationd(keys=["image"], axcodes="PLI"),
+                    # anisotropic_transform,
+                    EnsureTyped(keys=["image"]),
+                ]
+            )
+        else:
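+            # whole-volume inference: pad each image to the size computed by get_padding_dim above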
+            load_transforms = Compose(
+                [
+                    LoadImaged(keys=["image"]),
+                    # AddChanneld(keys=["image"]), #already done
+                    EnsureChannelFirstd(keys=["image"]),
+                    # Orientationd(keys=["image"], axcodes="PLI"),
+                    # anisotropic_transform,
+                    SpatialPadd(keys=["image"], spatial_size=pad),
+                    EnsureTyped(keys=["image"]),
+                ]
+            )
 
         if not self.transforms["thresh"][0]:
             post_process_transforms = EnsureType()
@@ -448,16 +445,9 @@ def inference(self):
             inputs = inputs.to("cpu")
             print(inputs.shape)
 
-            if self.model_dict["name"] == "SwinUNetR":
-                model_output = lambda inputs: post_process_transforms(
-                    torch.sigmoid(
-                        self.model_dict["class"].get_output(model, inputs)
-                    )
-                )
-            else:
-                model_output = lambda inputs: post_process_transforms(
-                    self.model_dict["class"].get_output(model, inputs)
-                )
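+            # post-process the raw network output the same way for every architecture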
+            model_output = lambda inputs: post_process_transforms(
+                self.model_dict["class"].get_output(model, inputs)
+            )
 
             if self.keep_on_cpu:
                 dataset_device = "cpu"
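+                # keep data on the CPU to reduce GPU memory usage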
@@ -479,7 +469,6 @@ def inference(self):
                     device=dataset_device,
                     overlap=window_overlap,
                 )
-                print("done window infernce")
                 out = outputs.detach().cpu()
                 # del outputs # TODO fix memory ?
                 # outputs = None
@@ -519,14 +508,14 @@ def inference(self):
 
             # File output save name : original-name_model_date+time_number.filetype
             file_path = (
-                self.results_path
-                + "/"
-                + f"Prediction_{image_id}_"
-                + original_filename
-                + "_"
-                + self.model_dict["name"]
-                + f"_{time}_"
-                + self.filetype
+                self.results_path
+                + "/"
+                + f"Prediction_{image_id}_"
+                + original_filename
+                + "_"
+                + self.model_dict["name"]
+                + f"_{time}_"
+                + self.filetype
             )
 
             # print(filename)
@@ -567,14 +556,14 @@ def method(image):
                 instance_labels = method(to_instance)
 
                 instance_filepath = (
-                    self.results_path
-                    + "/"
-                    + f"Instance_seg_labels_{image_id}_"
-                    + original_filename
-                    + "_"
-                    + self.model_dict["name"]
-                    + f"_{time}_"
-                    + self.filetype
+                    self.results_path
+                    + "/"
+                    + f"Instance_seg_labels_{image_id}_"
+                    + original_filename
+                    + "_"
+                    + self.model_dict["name"]
+                    + f"_{time}_"
+                    + self.filetype
                 )
 
                 imwrite(instance_filepath, instance_labels)
@@ -617,23 +606,23 @@ class TrainingWorker(GeneratorWorker):
     Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
 
     def __init__(
-        self,
-        device,
-        model_dict,
-        weights_path,
-        data_dicts,
-        validation_percent,
-        max_epochs,
-        loss_function,
-        learning_rate,
-        val_interval,
-        batch_size,
-        results_path,
-        sampling,
-        num_samples,
-        sample_size,
-        do_augmentation,
-        deterministic,
+        self,
+        device,
+        model_dict,
+        weights_path,
+        data_dicts,
+        validation_percent,
+        max_epochs,
+        loss_function,
+        learning_rate,
+        val_interval,
+        batch_size,
+        results_path,
+        sampling,
+        num_samples,
+        sample_size,
+        do_augmentation,
+        deterministic,
     ):
638627 """Initializes a worker for inference with the arguments needed by the :py:func:`~train` function. Note: See :py:func:`~train`
639628
@@ -841,9 +830,8 @@ def train(self):
             else:
                 size = check
             print(f"Size of image : {size}")
-            model = model_class.get_net()(
+            model = model_class.get_net(
                 input_image_size=utils.get_padding_dim(size),
-                out_channels=1,
                 dropout_prob=0.3,
             )
         elif model_name == "SwinUNetR":
@@ -852,11 +840,8 @@ def train(self):
             else:
                 size = check
             print(f"Size of image : {size}")
-            model = model_class.get_net()(
+            model = model_class.get_net(
                 img_size=utils.get_padding_dim(size),
-                in_channels=1,
-                out_channels=1,
-                feature_size=48,
                 use_checkpoint=True,
             )
         else:
@@ -868,10 +853,10 @@ def train(self):
 
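+        # the first validation_percent fraction of data_dicts is used for training, the remainder for validation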
         self.train_files, self.val_files = (
             self.data_dicts[
-                0 : int(len(self.data_dicts) * self.validation_percent)
+                0 : int(len(self.data_dicts) * self.validation_percent)
             ],
             self.data_dicts[
-                int(len(self.data_dicts) * self.validation_percent) :
+                int(len(self.data_dicts) * self.validation_percent):
             ],
         )
 
@@ -1032,10 +1017,10 @@ def train(self):
                 if self.device.type == "cuda":
                     self.log("Memory Usage:")
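+                    # report GPU memory converted from bytes to gigabytes (1024**3)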
                     alloc_mem = round(
-                        torch.cuda.memory_allocated(0) / 1024 ** 3, 1
+                        torch.cuda.memory_allocated(0) / 1024**3, 1
                     )
                     reserved_mem = round(
-                        torch.cuda.memory_reserved(0) / 1024 ** 3, 1
+                        torch.cuda.memory_reserved(0) / 1024**3, 1
                     )
                     self.log(f"Allocated: {alloc_mem} GB")
                     self.log(f"Cached: {reserved_mem} GB")
@@ -1117,7 +1102,7 @@ def train(self):
                 yield train_report
 
                 weights_filename = (
-                    f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth"
+                    f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth"
                 )
 
                 if metric > best_metric:
@@ -1158,7 +1143,6 @@ def train(self):
 
         # self.close()
 
-
 # def this_is_fine(self):
 #     import numpy as np
 #