Skip to content

Commit c3e70ee

Browse files
committed
adding swinunetr & removing extra padding for inference
1 parent fde3c71 commit c3e70ee

File tree

9 files changed

+74
-36
lines changed

9 files changed

+74
-36
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,3 @@ venv/
9999
/napari_cellseg3d/models/saved_weights/
100100
/docs/res/logo/old_logo/
101101
/reqs/
102-

napari_cellseg3d/model_framework.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from napari_cellseg3d import utils
1414
from napari_cellseg3d.log_utility import Log
1515
from napari_cellseg3d.models import model_SegResNet as SegResNet
16+
from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR
1617
from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
1718
from napari_cellseg3d.models import model_VNet as VNet
1819
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS
@@ -64,6 +65,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
6465
"SegResNet": SegResNet,
6566
"TRAILMAP": TRAILMAP,
6667
"TRAILMAP_MS": TRAILMAP_MS,
68+
"SwinUNetR": SwinUNetR,
6769
}
6870
"""dict: dictionary of available models, with string for widget display as key
6971

napari_cellseg3d/model_workers.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,6 @@ def log_parameters(self):
301301
f"Probability threshold is {self.instance_params['threshold']:.2f}\n"
302302
f"Objects smaller than {self.instance_params['size_small']} pixels will be removed\n"
303303
)
304-
# self.log(f"")
305-
# self.log("\n")
306304
self.log("-" * 20)
307305

308306
def load_folder(self):
@@ -313,11 +311,7 @@ def load_folder(self):
313311
data_check = LoadImaged(keys=["image"])(images_dict[0])
314312

315313
check = data_check["image"].shape
316-
# TODO remove
317-
# z_aniso = 5 / 1.5
318-
# if zoom is not None :
319-
# pad = utils.get_padding_dim(check, anisotropy_factor=zoom)
320-
# else:
314+
321315
self.log("\nChecking dimensions...")
322316
pad = utils.get_padding_dim(check)
323317

@@ -1027,10 +1021,26 @@ def train(self):
10271021
out_channels=1,
10281022
dropout_prob=0.3,
10291023
)
1024+
elif model_name == "SwinUNetR":
1025+
if self.sampling:
1026+
size = self.sample_size
1027+
else:
1028+
size = check
1029+
print(f"Size of image : {size}")
1030+
model = model_class.get_net()(
1031+
img_size=utils.get_padding_dim(size),
1032+
in_channels=1,
1033+
out_channels=1,
1034+
feature_size=48,
1035+
use_checkpoint=True,
1036+
)
10301037
else:
10311038
model = model_class.get_net() # get an instance of the model
10321039
model = model.to(self.device)
10331040

1041+
1042+
1043+
10341044
epoch_loss_values = []
10351045
val_metric_values = []
10361046

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from monai.networks.nets import SwinUNETR
2+
3+
4+
def get_weights_file():
5+
return ""
6+
7+
8+
def get_net():
9+
return SwinUNETR
10+
11+
12+
def get_output(model, input):
13+
out = model(input)
14+
return out
15+
16+
17+
def get_validation(model, val_inputs):
18+
return model(val_inputs)

napari_cellseg3d/models/model_VNet.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ def get_validation(model, val_inputs):
1919
roi_size = (64, 64, 64)
2020
sw_batch_size = 1
2121
val_outputs = sliding_window_inference(
22-
val_inputs, roi_size, sw_batch_size, model, mode="gaussian"
22+
val_inputs,
23+
roi_size,
24+
sw_batch_size,
25+
model,
26+
mode="gaussian",
27+
overlap=0.7,
2328
)
2429
return val_outputs

napari_cellseg3d/plugin_model_inference.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
7878
self.keep_on_cpu = False
7979
self.use_window_inference = False
8080
self.window_inference_size = None
81+
self.window_overlap_percentage = None
8182

8283
###########################
8384
# interface
@@ -134,6 +135,17 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
134135
self.window_infer_box.clicked.connect(self.toggle_display_window_size)
135136

136137
sizes_window = ["8", "16", "32", "64", "128", "256", "512"]
138+
# (
139+
# self.window_size_choice,
140+
# self.lbl_window_size_choice,
141+
# ) = ui.make_combobox(sizes_window, label="Window size and overlap")
142+
# self.window_overlap = ui.make_n_spinboxes(
143+
# max=1,
144+
# default=0.7,
145+
# step=0.05,
146+
# double=True,
147+
# )
148+
137149
self.window_size_choice = ui.DropdownMenu(
138150
sizes_window, label="Window size"
139151
)
@@ -146,6 +158,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
146158
self.lbl_window_size_choice,
147159
horizontal=False,
148160
)
161+
# self.window_infer_params = ui.combine_blocks(
162+
# self.window_overlap,
163+
# self.window_infer_params,
164+
# horizontal=False,
165+
# )
149166

150167
##################
151168
##################
@@ -216,7 +233,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
216233
"Displays the image used for inference in the viewer"
217234
)
218235
self.segres_size.setToolTip(
219-
"Image size on which the SegResNet has been trained (default : 128)"
236+
"Image size on which the model has been trained (default : 128)"
220237
)
221238

222239
thresh_desc = (
@@ -234,6 +251,11 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
234251
self.window_size_choice.setToolTip(
235252
"Size of the window to run inference with (in pixels)"
236253
)
254+
255+
# self.window_overlap.setToolTip(
256+
# "Amount of overlap between sliding windows"
257+
# )
258+
237259
self.keep_data_on_cpu_box.setToolTip(
238260
"If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA"
239261
)
@@ -281,7 +303,10 @@ def check_ready(self):
281303
return False
282304

283305
def toggle_display_segres_size(self):
284-
if self.model_choice.currentText() == "SegResNet":
306+
if (
307+
self.model_choice.currentText() == "SegResNet"
308+
or self.model_choice.currentText() == "SwinUNetR"
309+
):
285310
self.segres_size.setVisible(True)
286311
else:
287312
self.segres_size.setVisible(False)
@@ -600,6 +625,7 @@ def start(self, on_layer=False):
600625
self.window_inference_size = int(
601626
self.window_size_choice.currentText()
602627
)
628+
# self.window_overlap_percentage = self.window_overlap.value()
603629

604630
if not on_layer:
605631
self.worker = InferenceWorker(
@@ -724,8 +750,6 @@ def on_yield(data, widget):
724750

725751
zoom = widget.zoom
726752

727-
# print(data["original"].shape)
728-
# print(data["result"].shape)
729753

730754
viewer.dims.ndisplay = 3
731755
viewer.scale_bar.visible = True

napari_cellseg3d/plugin_model_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from napari_cellseg3d.model_workers import TrainingWorker
3232

3333
NUMBER_TABS = 3
34-
DEFAULT_PATCH_SIZE = 60
34+
DEFAULT_PATCH_SIZE = 64
3535

3636

3737
class Trainer(ModelFramework):

napari_cellseg3d/utils.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def resize(image, zoom_factors):
116116

117117

118118
def align_array_sizes(array_shape, target_shape):
119-
120119
index_differences = []
121120
for i in range(len(target_shape)):
122121
if target_shape[i] != array_shape[i]:
@@ -334,7 +333,6 @@ def fill_list_in_between(lst, n, elem):
334333
Returns :
335334
Filled list
336335
"""
337-
338336
new_list = []
339337
for i in range(len(lst)):
340338
temp_list = [lst[i]]
@@ -608,25 +606,6 @@ def format_Warning(message, category, filename, lineno, line=""):
608606
)
609607

610608

611-
# def dice_coeff(y_true, y_pred):
612-
# smooth = 1.
613-
# y_true_f = y_true.flatten()
614-
# y_pred_f = K.flatten(y_pred)
615-
# intersection = K.sum(y_true_f * y_pred_f)
616-
# score = (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
617-
# return score
618-
619-
620-
# def dice_loss(y_true, y_pred):
621-
# loss = 1 - dice_coeff(y_true, y_pred)
622-
# return loss
623-
624-
625-
# def bce_dice_loss(y_true, y_pred):
626-
# loss = binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
627-
# return loss
628-
629-
630609
def divide_imgs(images):
631610
H = -(-images.shape[1] // 412)
632611
W = -(-images.shape[2] // 412)

setup.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,13 @@ install_requires =
5050
tifffile>=2022.2.9
5151
imageio-ffmpeg>=0.4.5
5252
torch>=1.11
53+
monai[nibabel,scikit-image,itk,einops]>=0.9.0
5354
tqdm
5455
monai>=0.9.0
5556
nibabel
5657
scikit-image
5758
pillow
58-
itk>=5.2.0
59+
tqdm
5960
matplotlib
6061
vispy>=0.9.6
6162

0 commit comments

Comments
 (0)