@@ -179,6 +179,7 @@ def __init__(
179179 instance ,
180180 use_window ,
181181 window_infer_size ,
182+ window_overlap_percentage ,
182183 keep_on_cpu ,
183184 stats_csv ,
184185 ):
@@ -205,6 +206,8 @@ def __init__(
205206
206207 * window_infer_size: size of window if use_window is True
207208
209+ * window_overlap_percentage: overlap of sliding windows if use_window is True
210+
208211 * keep_on_cpu: keep images on CPU or no
209212
210213 * stats_csv: compute stats on cells and save them to a csv file
@@ -228,6 +231,7 @@ def __init__(
228231 self .instance_params = instance
229232 self .use_window = use_window
230233 self .window_infer_size = window_infer_size
234+ self .window_overlap_percentage = window_overlap_percentage
231235 self .keep_on_cpu = keep_on_cpu
232236 self .stats_to_csv = stats_csv
233237 """These attributes are all arguments of :py:func:~inference, please see that for reference"""
@@ -350,8 +354,6 @@ def inference(self):
350354 # pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
351355 # else:
352356 self .log ("\n Checking dimensions..." )
353- pad = utils .get_padding_dim (check )
354- # print(pad)
355357 dims = self .model_dict ["segres_size" ]
356358
357359 model = self .model_dict ["class" ].get_net ()
@@ -365,6 +367,14 @@ def inference(self):
365367 out_channels = 1 ,
366368 # dropout_prob=0.3,
367369 )
370+ elif self .model_dict ["name" ] == "SwinUNetR" :
371+ model = self .model_dict ["class" ].get_net ()(
372+ img_size = [dims , dims , dims ],
373+ in_channels = 1 ,
374+ out_channels = 1 ,
375+ feature_size = 48 ,
376+ use_checkpoint = False ,
377+ )
368378
369379 self .log_parameters ()
370380
@@ -380,7 +390,6 @@ def inference(self):
380390 EnsureChannelFirstd (keys = ["image" ]),
381391 # Orientationd(keys=["image"], axcodes="PLI"),
382392 # anisotropic_transform,
383- SpatialPadd (keys = ["image" ], spatial_size = pad ),
384393 EnsureTyped (keys = ["image" ]),
385394 ]
386395 )
@@ -437,10 +446,18 @@ def inference(self):
437446 # print(inputs.shape)
438447
439448 inputs = inputs .to ("cpu" )
449+ print (inputs .shape )
440450
441- model_output = lambda inputs : post_process_transforms (
442- self .model_dict ["class" ].get_output (model , inputs )
443- )
451+ if self .model_dict ["name" ] == "SwinUNetR" :
452+ model_output = lambda inputs : post_process_transforms (
453+ torch .sigmoid (
454+ self .model_dict ["class" ].get_output (model , inputs )
455+ )
456+ )
457+ else :
458+ model_output = lambda inputs : post_process_transforms (
459+ self .model_dict ["class" ].get_output (model , inputs )
460+ )
444461
445462 if self .keep_on_cpu :
446463 dataset_device = "cpu"
@@ -449,22 +466,24 @@ def inference(self):
449466
450467 if self .use_window :
451468 window_size = self .window_infer_size
469+ window_overlap = self .window_overlap_percentage
452470 else :
453471 window_size = None
454-
472+ window_overlap = 0.25
455473 outputs = sliding_window_inference (
456474 inputs ,
457475 roi_size = window_size ,
458476 sw_batch_size = 1 ,
459477 predictor = model_output ,
460478 sw_device = self .device ,
461479 device = dataset_device ,
480+ overlap = window_overlap ,
462481 )
463-
482+ print ( "done window infernce" )
464483 out = outputs .detach ().cpu ()
465484 # del outputs # TODO fix memory ?
466485 # outputs = None
467-
486+ print ( out . shape )
468487 if self .transforms ["zoom" ][0 ]:
469488 zoom = self .transforms ["zoom" ][1 ]
470489 anisotropic_transform = Zoom (
@@ -474,9 +493,11 @@ def inference(self):
474493 )
475494 out = anisotropic_transform (out [0 ])
476495
477- out = post_process_transforms (out )
496+ # out = post_process_transforms(out)
478497 out = np .array (out ).astype (np .float32 )
498+ print (out .shape )
479499 out = np .squeeze (out )
500+ print (out .shape )
480501 to_instance = out # avoid post processing since thresholding is done there anyway
481502
482503 # batch_len = out.shape[1]
@@ -825,6 +846,19 @@ def train(self):
825846 out_channels = 1 ,
826847 dropout_prob = 0.3 ,
827848 )
849+ elif model_name == "SwinUNetR" :
850+ if self .sampling :
851+ size = self .sample_size
852+ else :
853+ size = check
854+ print (f"Size of image : { size } " )
855+ model = model_class .get_net ()(
856+ img_size = utils .get_padding_dim (size ),
857+ in_channels = 1 ,
858+ out_channels = 1 ,
859+ feature_size = 48 ,
860+ use_checkpoint = True ,
861+ )
828862 else :
829863 model = model_class .get_net () # get an instance of the model
830864 model = model .to (self .device )
0 commit comments