4545# Qt
4646from qtpy .QtCore import Signal
4747
48-
4948from napari_cellseg3d import utils
5049from napari_cellseg3d import log_utility
5150
@@ -165,6 +164,9 @@ def __init__(self):
165164 super ().__init__ ()
166165
167166
167+ # TODO : use dataclass for config instead ?
168+
169+
168170class InferenceWorker (GeneratorWorker ):
169171 """A custom worker to run inference jobs in.
170172 Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
@@ -180,6 +182,7 @@ def __init__(
180182 instance ,
181183 use_window ,
182184 window_infer_size ,
185+ window_overlap ,
183186 keep_on_cpu ,
184187 stats_csv ,
185188 images_filepaths = None ,
@@ -231,6 +234,7 @@ def __init__(
231234 self .instance_params = instance
232235 self .use_window = use_window
233236 self .window_infer_size = window_infer_size
237+ self .window_overlap_percentage = window_overlap
234238 self .keep_on_cpu = keep_on_cpu
235239 self .stats_to_csv = stats_csv
236240 ############################################
@@ -301,8 +305,6 @@ def log_parameters(self):
301305 f"Probability threshold is { self .instance_params ['threshold' ]:.2f} \n "
302306 f"Objects smaller than { self .instance_params ['size_small' ]} pixels will be removed\n "
303307 )
304- # self.log(f"")
305- # self.log("\n")
306308 self .log ("-" * 20 )
307309
308310 def load_folder (self ):
@@ -313,25 +315,57 @@ def load_folder(self):
313315 data_check = LoadImaged (keys = ["image" ])(images_dict [0 ])
314316
315317 check = data_check ["image" ].shape
316- # TODO remove
317- # z_aniso = 5 / 1.5
318- # if zoom is not None :
319- # pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
320- # else:
318+
321319 self .log ("\n Checking dimensions..." )
322320 pad = utils .get_padding_dim (check )
323321
324- load_transforms = Compose (
325- [
326- LoadImaged (keys = ["image" ]),
327- # AddChanneld(keys=["image"]), #already done
328- EnsureChannelFirstd (keys = ["image" ]),
329- # Orientationd(keys=["image"], axcodes="PLI"),
330- # anisotropic_transform,
331- SpatialPadd (keys = ["image" ], spatial_size = pad ),
332- EnsureTyped (keys = ["image" ]),
333- ]
334- )
322+ # dims = self.model_dict["model_input_size"]
323+ #
324+ # if self.model_dict["name"] == "SegResNet":
325+ # model = self.model_dict["class"].get_net(
326+ # input_image_size=[
327+ # dims,
328+ # dims,
329+ # dims,
330+ # ]
331+ # )
332+ # elif self.model_dict["name"] == "SwinUNetR":
333+ # model = self.model_dict["class"].get_net(
334+ # img_size=[dims, dims, dims],
335+ # use_checkpoint=False,
336+ # )
337+ # else:
338+ # model = self.model_dict["class"].get_net()
339+ #
340+ # self.log_parameters()
341+ #
342+ # model.to(self.device)
343+
344+ # print("FILEPATHS PRINT")
345+ # print(self.images_filepaths)
346+ if self .use_window :
347+ load_transforms = Compose (
348+ [
349+ LoadImaged (keys = ["image" ]),
350+ # AddChanneld(keys=["image"]), #already done
351+ EnsureChannelFirstd (keys = ["image" ]),
352+ # Orientationd(keys=["image"], axcodes="PLI"),
353+ # anisotropic_transform,
354+ EnsureTyped (keys = ["image" ]),
355+ ]
356+ )
357+ else :
358+ load_transforms = Compose (
359+ [
360+ LoadImaged (keys = ["image" ]),
361+ # AddChanneld(keys=["image"]), #already done
362+ EnsureChannelFirstd (keys = ["image" ]),
363+ # Orientationd(keys=["image"], axcodes="PLI"),
364+ # anisotropic_transform,
365+ SpatialPadd (keys = ["image" ], spatial_size = pad ),
366+ EnsureTyped (keys = ["image" ]),
367+ ]
368+ )
335369
336370 self .log ("\n Loading dataset..." )
337371 inference_ds = Dataset (data = images_dict , transform = load_transforms )
@@ -364,19 +398,32 @@ def load_layer(self):
364398
365399 # print(volume.shape)
366400 # print(volume.dtype)
367-
368- load_transforms = Compose (
369- [
370- ToTensor (),
371- # anisotropic_transform,
372- AddChannel (),
373- SpatialPad (spatial_size = pad ),
374- AddChannel (),
375- EnsureType (),
376- ],
377- map_items = False ,
378- log_stats = True ,
379- )
401+ if self .use_window :
402+ load_transforms = Compose (
403+ [
404+ ToTensor (),
405+ # anisotropic_transform,
406+ AddChannel (),
407+ # SpatialPad(spatial_size=pad),
408+ AddChannel (),
409+ EnsureType (),
410+ ],
411+ map_items = False ,
412+ log_stats = True ,
413+ )
414+ else :
415+ load_transforms = Compose (
416+ [
417+ ToTensor (),
418+ # anisotropic_transform,
419+ AddChannel (),
420+ SpatialPad (spatial_size = pad ),
421+ AddChannel (),
422+ EnsureType (),
423+ ],
424+ map_items = False ,
425+ log_stats = True ,
426+ )
380427
381428 self .log ("\n Loading dataset..." )
382429 input_image = load_transforms (volume )
@@ -405,8 +452,10 @@ def model_output(
405452
406453 if self .use_window :
407454 window_size = self .window_infer_size
455+ window_overlap = self .window_overlap_percentage
408456 else :
409457 window_size = None
458+ window_overlap = 0.25
410459
411460 outputs = sliding_window_inference (
412461 inputs ,
@@ -415,6 +464,7 @@ def model_output(
415464 predictor = model_output ,
416465 sw_device = self .device ,
417466 device = dataset_device ,
467+ overlap = window_overlap ,
418468 )
419469
420470 out = outputs .detach ().cpu ()
@@ -508,13 +558,12 @@ def save_image(
508558 )
509559
510560 imwrite (file_path , image )
561+ filename = os .path .split (file_path )[1 ]
511562
512563 if from_layer :
513- self .log (f"\n Layer prediction saved as :" )
564+ self .log (f"\n Layer prediction saved as : { filename } " )
514565 else :
515- self .log (f"\n File n°{ i + 1 } saved as :" )
516- filename = os .path .split (file_path )[1 ]
517- self .log (filename )
566+ self .log (f"\n File n°{ i + 1 } saved as : { filename } " )
518567
519568 def aniso_transform (self , image ):
520569 zoom = self .transforms ["zoom" ][1 ]
@@ -630,9 +679,13 @@ def inference_on_layer(self, image, model, post_process_transforms):
630679
631680 self .save_image (out , from_layer = True )
632681
633- instance_labels , data_dict = self .get_instance_result (out ,from_layer = True )
682+ instance_labels , data_dict = self .get_instance_result (
683+ out , from_layer = True
684+ )
634685
635- return self .create_result_dict (out , instance_labels , from_layer = True , data_dict = data_dict )
686+ return self .create_result_dict (
687+ out , instance_labels , from_layer = True , data_dict = data_dict
688+ )
636689
637690 def inference (self ):
638691 """
@@ -674,29 +727,27 @@ def inference(self):
674727 torch .set_num_threads (1 ) # required for threading on macOS ?
675728 self .log ("Number of threads has been set to 1 for macOS" )
676729
677- # if self.device =="cuda": # TODO : fix mem alloc, this does not work it seems
678- # torch.backends.cudnn.benchmark = False
679-
680- # TODO : better solution than loading first image always ?
681- # data_check = LoadImaged(keys=["image"])(images_dict[0])
682- # print(data)
683- # check = data_check["image"].shape
684- # print(check)
685-
686730 try :
687- dims = self .model_dict ["segres_size" ]
731+ dims = self .model_dict ["model_input_size" ]
732+ self .log (f"MODEL DIMS : { dims } " )
733+ self .log (self .model_dict ["name" ])
688734
689- model = self .model_dict ["class" ].get_net ()
690735 if self .model_dict ["name" ] == "SegResNet" :
691- model = self .model_dict ["class" ].get_net ()(
736+ model = self .model_dict ["class" ].get_net (
692737 input_image_size = [
693738 dims ,
694739 dims ,
695740 dims ,
696741 ], # TODO FIX ! find a better way & remove model-specific code
697- out_channels = 1 ,
698- # dropout_prob=0.3,
699742 )
743+ elif self .model_dict ["name" ] == "SwinUNetR" :
744+ model = self .model_dict ["class" ].get_net (
745+ img_size = [dims , dims , dims ],
746+ use_checkpoint = False ,
747+ )
748+ else :
749+ model = self .model_dict ["class" ].get_net ()
750+ model = model .to (self .device )
700751
701752 self .log_parameters ()
702753
@@ -722,10 +773,7 @@ def inference(self):
722773 AsDiscrete (threshold = t ), EnsureType ()
723774 )
724775
725-
726- self .log (
727- "\n Loading weights..."
728- )
776+ self .log ("\n Loading weights..." )
729777
730778 if self .weights_dict ["custom" ]:
731779 weights = self .weights_dict ["path" ]
@@ -1022,11 +1070,21 @@ def train(self):
10221070 else :
10231071 size = check
10241072 print (f"Size of image : { size } " )
1025- model = model_class .get_net ()(
1073+ model = model_class .get_net (
10261074 input_image_size = utils .get_padding_dim (size ),
10271075 out_channels = 1 ,
10281076 dropout_prob = 0.3 ,
10291077 )
1078+ elif model_name == "SwinUNetR" :
1079+ if self .sampling :
1080+ size = self .sample_size
1081+ else :
1082+ size = check
1083+ print (f"Size of image : { size } " )
1084+ model = model_class .get_net (
1085+ img_size = utils .get_padding_dim (size ),
1086+ use_checkpoint = True ,
1087+ )
10301088 else :
10311089 model = model_class .get_net () # get an instance of the model
10321090 model = model .to (self .device )
0 commit comments