Skip to content

Commit b1be111

Browse files
committed
adding swinunetr & removing extra padding for inference
1 parent e462c4e commit b1be111

File tree

9 files changed

+101
-40
lines changed

9 files changed

+101
-40
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,3 @@ venv/
9999
/napari_cellseg3d/models/saved_weights/
100100
/docs/res/logo/old_logo/
101101
/reqs/
102-

napari_cellseg3d/model_framework.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from napari_cellseg3d import utils
1414
from napari_cellseg3d.log_utility import Log
1515
from napari_cellseg3d.models import model_SegResNet as SegResNet
16+
from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR
1617
from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
1718
from napari_cellseg3d.models import model_VNet as VNet
1819
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS
@@ -64,6 +65,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6465
"SegResNet": SegResNet,
6566
"TRAILMAP": TRAILMAP,
6667
"TRAILMAP_MS": TRAILMAP_MS,
68+
"SwinUNetR": SwinUNetR,
6769
}
6870
"""dict: dictionary of available models, with string for widget display as key
6971

napari_cellseg3d/model_workers.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def __init__(
179179
instance,
180180
use_window,
181181
window_infer_size,
182+
window_overlap_percentage,
182183
keep_on_cpu,
183184
stats_csv,
184185
):
@@ -205,6 +206,8 @@ def __init__(
205206
206207
* window_infer_size: size of window if use_window is True
207208
209+
* window_overlap_percentage: overlap of sliding windows if use_window is True
210+
208211
* keep_on_cpu: keep images on CPU or no
209212
210213
* stats_csv: compute stats on cells and save them to a csv file
@@ -228,6 +231,7 @@ def __init__(
228231
self.instance_params = instance
229232
self.use_window = use_window
230233
self.window_infer_size = window_infer_size
234+
self.window_overlap_percentage = window_overlap_percentage
231235
self.keep_on_cpu = keep_on_cpu
232236
self.stats_to_csv = stats_csv
233237
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
@@ -350,8 +354,6 @@ def inference(self):
350354
# pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
351355
# else:
352356
self.log("\nChecking dimensions...")
353-
pad = utils.get_padding_dim(check)
354-
# print(pad)
355357
dims = self.model_dict["segres_size"]
356358

357359
model = self.model_dict["class"].get_net()
@@ -365,6 +367,14 @@ def inference(self):
365367
out_channels=1,
366368
# dropout_prob=0.3,
367369
)
370+
elif self.model_dict["name"] == "SwinUNetR":
371+
model = self.model_dict["class"].get_net()(
372+
img_size=[dims, dims, dims],
373+
in_channels=1,
374+
out_channels=1,
375+
feature_size=48,
376+
use_checkpoint=False,
377+
)
368378

369379
self.log_parameters()
370380

@@ -380,7 +390,6 @@ def inference(self):
380390
EnsureChannelFirstd(keys=["image"]),
381391
# Orientationd(keys=["image"], axcodes="PLI"),
382392
# anisotropic_transform,
383-
SpatialPadd(keys=["image"], spatial_size=pad),
384393
EnsureTyped(keys=["image"]),
385394
]
386395
)
@@ -437,10 +446,18 @@ def inference(self):
437446
# print(inputs.shape)
438447

439448
inputs = inputs.to("cpu")
449+
print(inputs.shape)
440450

441-
model_output = lambda inputs: post_process_transforms(
442-
self.model_dict["class"].get_output(model, inputs)
443-
)
451+
if self.model_dict["name"] == "SwinUNetR":
452+
model_output = lambda inputs: post_process_transforms(
453+
torch.sigmoid(
454+
self.model_dict["class"].get_output(model, inputs)
455+
)
456+
)
457+
else:
458+
model_output = lambda inputs: post_process_transforms(
459+
self.model_dict["class"].get_output(model, inputs)
460+
)
444461

445462
if self.keep_on_cpu:
446463
dataset_device = "cpu"
@@ -449,22 +466,24 @@ def inference(self):
449466

450467
if self.use_window:
451468
window_size = self.window_infer_size
469+
window_overlap = self.window_overlap_percentage
452470
else:
453471
window_size = None
454-
472+
window_overlap = 0.25
455473
outputs = sliding_window_inference(
456474
inputs,
457475
roi_size=window_size,
458476
sw_batch_size=1,
459477
predictor=model_output,
460478
sw_device=self.device,
461479
device=dataset_device,
480+
overlap=window_overlap,
462481
)
463-
482+
print("done window inference")
464483
out = outputs.detach().cpu()
465484
# del outputs # TODO fix memory ?
466485
# outputs = None
467-
486+
print(out.shape)
468487
if self.transforms["zoom"][0]:
469488
zoom = self.transforms["zoom"][1]
470489
anisotropic_transform = Zoom(
@@ -474,9 +493,11 @@ def inference(self):
474493
)
475494
out = anisotropic_transform(out[0])
476495

477-
out = post_process_transforms(out)
496+
# out = post_process_transforms(out)
478497
out = np.array(out).astype(np.float32)
498+
print(out.shape)
479499
out = np.squeeze(out)
500+
print(out.shape)
480501
to_instance = out # avoid post processing since thresholding is done there anyway
481502

482503
# batch_len = out.shape[1]
@@ -825,6 +846,19 @@ def train(self):
825846
out_channels=1,
826847
dropout_prob=0.3,
827848
)
849+
elif model_name == "SwinUNetR":
850+
if self.sampling:
851+
size = self.sample_size
852+
else:
853+
size = check
854+
print(f"Size of image : {size}")
855+
model = model_class.get_net()(
856+
img_size=utils.get_padding_dim(size),
857+
in_channels=1,
858+
out_channels=1,
859+
feature_size=48,
860+
use_checkpoint=True,
861+
)
828862
else:
829863
model = model_class.get_net() # get an instance of the model
830864
model = model.to(self.device)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from monai.networks.nets import SwinUNETR
2+
3+
4+
def get_weights_file():
5+
return ""
6+
7+
8+
def get_net():
9+
return SwinUNETR
10+
11+
12+
def get_output(model, input):
13+
out = model(input)
14+
return out
15+
16+
17+
def get_validation(model, val_inputs):
18+
return model(val_inputs)

napari_cellseg3d/models/model_VNet.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def get_validation(model, val_inputs):
1919
roi_size = (64, 64, 64)
2020
sw_batch_size = 1
2121
val_outputs = sliding_window_inference(
22-
val_inputs, roi_size, sw_batch_size, model, mode="gaussian"
22+
val_inputs,
23+
roi_size,
24+
sw_batch_size,
25+
model,
26+
mode="gaussian",
27+
overlap=0.7,
2328
)
2429
return val_outputs

napari_cellseg3d/plugin_model_inference.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
7676
self.keep_on_cpu = False
7777
self.use_window_inference = False
7878
self.window_inference_size = None
79+
self.window_overlap_percentage = None
7980

8081
###########################
8182
# interface
@@ -132,6 +133,17 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
132133
self.window_infer_box.clicked.connect(self.toggle_display_window_size)
133134

134135
sizes_window = ["8", "16", "32", "64", "128", "256", "512"]
136+
# (
137+
# self.window_size_choice,
138+
# self.lbl_window_size_choice,
139+
# ) = ui.make_combobox(sizes_window, label="Window size and overlap")
140+
# self.window_overlap = ui.make_n_spinboxes(
141+
# max=1,
142+
# default=0.7,
143+
# step=0.05,
144+
# double=True,
145+
# )
146+
135147
self.window_size_choice = ui.DropdownMenu(
136148
sizes_window, label="Window size"
137149
)
@@ -144,6 +156,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
144156
self.lbl_window_size_choice,
145157
horizontal=False,
146158
)
159+
# self.window_infer_params = ui.combine_blocks(
160+
# self.window_overlap,
161+
# self.window_infer_params,
162+
# horizontal=False,
163+
# )
147164

148165
##################
149166
##################
@@ -210,7 +227,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
210227
"Displays the image used for inference in the viewer"
211228
)
212229
self.segres_size.setToolTip(
213-
"Image size on which the SegResNet has been trained (default : 128)"
230+
"Image size on which the model has been trained (default : 128)"
214231
)
215232

216233
thresh_desc = "Thresholding : all values in the image below the chosen probability threshold will be set to 0, and all others to 1."
@@ -224,6 +241,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
224241
self.window_size_choice.setToolTip(
225242
"Size of the window to run inference with (in pixels)"
226243
)
244+
245+
# self.window_overlap.setToolTip(
246+
# "Amount of overlap between sliding windows"
247+
# )
248+
227249
self.keep_data_on_cpu_box.setToolTip(
228250
"If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA"
229251
)
@@ -263,7 +285,10 @@ def check_ready(self):
263285
return False
264286

265287
def toggle_display_segres_size(self):
266-
if self.model_choice.currentText() == "SegResNet":
288+
if (
289+
self.model_choice.currentText() == "SegResNet"
290+
or self.model_choice.currentText() == "SwinUNetR"
291+
):
267292
self.segres_size.setVisible(True)
268293
else:
269294
self.segres_size.setVisible(False)
@@ -575,6 +600,7 @@ def start(self):
575600
self.window_inference_size = int(
576601
self.window_size_choice.currentText()
577602
)
603+
# self.window_overlap_percentage = self.window_overlap.value()
578604

579605
self.worker = InferenceWorker(
580606
device=device,
@@ -587,6 +613,7 @@ def start(self):
587613
instance=self.instance_params,
588614
use_window=self.use_window_inference,
589615
window_infer_size=self.window_inference_size,
616+
# window_overlap_percentage=self.window_overlap_percentage,
590617
keep_on_cpu=self.keep_on_cpu,
591618
stats_csv=self.stats_to_csv,
592619
)

napari_cellseg3d/plugin_model_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from napari_cellseg3d.model_workers import TrainingWorker
3232

3333
NUMBER_TABS = 3
34-
DEFAULT_PATCH_SIZE = 60
34+
DEFAULT_PATCH_SIZE = 64
3535

3636

3737
class Trainer(ModelFramework):

napari_cellseg3d/utils.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ def resize(image, zoom_factors):
113113

114114

115115
def align_array_sizes(array_shape, target_shape):
116-
117116
index_differences = []
118117
for i in range(len(target_shape)):
119118
if target_shape[i] != array_shape[i]:
@@ -331,7 +330,6 @@ def fill_list_in_between(lst, n, elem):
331330
Returns :
332331
Filled list
333332
"""
334-
335333
new_list = []
336334
for i in range(len(lst)):
337335
temp_list = [lst[i]]
@@ -605,25 +603,6 @@ def format_Warning(message, category, filename, lineno, line=""):
605603
)
606604

607605

608-
# def dice_coeff(y_true, y_pred):
609-
# smooth = 1.
610-
# y_true_f = y_true.flatten()
611-
# y_pred_f = K.flatten(y_pred)
612-
# intersection = K.sum(y_true_f * y_pred_f)
613-
# score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
614-
# return score
615-
616-
617-
# def dice_loss(y_true, y_pred):
618-
# loss = 1 - dice_coeff(y_true, y_pred)
619-
# return loss
620-
621-
622-
# def bce_dice_loss(y_true, y_pred):
623-
# loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
624-
# return loss
625-
626-
627606
def divide_imgs(images):
628607
H = -(-images.shape[1] // 412)
629608
W = -(-images.shape[2] // 412)

setup.cfg

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,8 @@ install_requires =
5050
tifffile>=2022.2.9
5151
imageio-ffmpeg>=0.4.5
5252
torch>=1.11
53-
monai>=0.9.0
54-
nibabel
55-
scikit-image
53+
monai[nibabel,scikit-image,itk,einops]>=0.9.0
5654
pillow
57-
itk>=5.2.0
5855
matplotlib
5956
vispy>=0.9.6
6057

0 commit comments

Comments
 (0)