Skip to content

Commit 58b831e

Browse files
committed
:wrench move inherent model params where they are instantiated and fixed padding only for sliding window inference
1 parent 43807ec commit 58b831e

File tree

4 files changed

+18
-17
lines changed

4 files changed

+18
-17
lines changed

napari_cellseg3d/model_workers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
# Qt
4646
from qtpy.QtCore import Signal
4747

48-
4948
from napari_cellseg3d import utils
5049
from napari_cellseg3d import log_utility
5150

@@ -231,6 +230,7 @@ def __init__(
231230
self.instance_params = instance
232231
self.use_window = use_window
233232
self.window_infer_size = window_infer_size
233+
self.window_overlap_percentage = 0.8,
234234
self.keep_on_cpu = keep_on_cpu
235235
self.stats_to_csv = stats_csv
236236
############################################
@@ -399,8 +399,10 @@ def model_output(
399399

400400
if self.use_window:
401401
window_size = self.window_infer_size
402+
window_overlap = self.window_overlap_percentage
402403
else:
403404
window_size = None
405+
window_overlap = 0.25
404406

405407
outputs = sliding_window_inference(
406408
inputs,
@@ -409,6 +411,7 @@ def model_output(
409411
predictor=model_output,
410412
sw_device=self.device,
411413
device=dataset_device,
414+
overlap=window_overlap,
412415
)
413416

414417
out = outputs.detach().cpu()
@@ -1029,9 +1032,6 @@ def train(self):
10291032
print(f"Size of image : {size}")
10301033
model = model_class.get_net()(
10311034
img_size=utils.get_padding_dim(size),
1032-
in_channels=1,
1033-
out_channels=1,
1034-
feature_size=48,
10351035
use_checkpoint=True,
10361036
)
10371037
else:

napari_cellseg3d/models/model_SegResNet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from monai.networks.nets import SegResNetVAE
22

33

4-
def get_net():
5-
return SegResNetVAE
4+
def get_net(input_image_size, dropout_prob=None):
5+
return SegResNetVAE(input_image_size, out_channels=1, dropout_prob=dropout_prob)
66

77

88
def get_weights_file():

napari_cellseg3d/models/model_SwinUNetR.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
import torch
12
from monai.networks.nets import SwinUNETR
23

34

45
def get_weights_file():
56
return ""
67

78

8-
def get_net():
9-
return SwinUNETR
9+
def get_net(img_size, use_checkpoint=True):
10+
return SwinUNETR(img_size, in_channels=1, out_channels=1, feature_size=48, use_checkpoint=use_checkpoint)
1011

1112

1213
def get_output(model, input):
1314
out = model(input)
14-
return out
15+
return torch.sigmoid(out)
1516

1617

1718
def get_validation(model, val_inputs):

napari_cellseg3d/plugin_model_inference.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
100100
######################
101101
######################
102102
# TODO : better way to handle SegResNet size reqs ?
103-
self.segres_size = ui.IntIncrementCounter(min=1, max=1024, default=128)
103+
self.model_input_size = ui.IntIncrementCounter(min=1, max=1024, default=128)
104104
self.model_choice.currentIndexChanged.connect(
105-
self.toggle_display_segres_size
105+
self.toggle_display_model_input_size
106106
)
107107
self.model_choice.setCurrentIndex(0)
108108

@@ -232,7 +232,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
232232
self.show_original_checkbox.setToolTip(
233233
"Displays the image used for inference in the viewer"
234234
)
235-
self.segres_size.setToolTip(
235+
self.model_input_size.setToolTip(
236236
"Image size on which the model has been trained (default : 128)"
237237
)
238238

@@ -302,14 +302,14 @@ def check_ready(self):
302302
warnings.warn("Image and label paths are not correctly set")
303303
return False
304304

305-
def toggle_display_segres_size(self):
305+
def toggle_display_model_input_size(self):
306306
if (
307307
self.model_choice.currentText() == "SegResNet"
308308
or self.model_choice.currentText() == "SwinUNetR"
309309
):
310-
self.segres_size.setVisible(True)
310+
self.model_input_size.setVisible(True)
311311
else:
312-
self.segres_size.setVisible(False)
312+
self.model_input_size.setVisible(False)
313313

314314
def toggle_display_number(self):
315315
"""Shows the choices for viewing results depending on whether :py:attr:`self.view_checkbox` is checked"""
@@ -418,7 +418,7 @@ def build(self):
418418
self.model_choice,
419419
self.custom_weights_choice,
420420
self.weights_path_container,
421-
self.segres_size,
421+
self.model_input_size,
422422
],
423423
)
424424
self.weights_path_container.setVisible(False)
@@ -576,7 +576,7 @@ def start(self, on_layer=False):
576576
model_dict = { # gather model info
577577
"name": model_key,
578578
"class": self.get_model(model_key),
579-
"segres_size": self.segres_size.value(),
579+
"model_input_size": self.model_input_size.value(),
580580
}
581581

582582
if self.custom_weights_choice.isChecked():

0 commit comments

Comments
 (0)