Skip to content

Commit 55d2278

Browse files
C-Achardvidalmaxime
authored andcommitted
UI overhaul + overlap parameter
- Added overlap parameter for window - Improved UI code slightly
1 parent 09bdef0 commit 55d2278

File tree

4 files changed

+105
-20
lines changed

4 files changed

+105
-20
lines changed

napari_cellseg3d/interface.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import Optional
2-
from typing import Union
32
from typing import List
43

54

@@ -61,6 +60,13 @@ def toggle_visibility(checkbox, widget):
6160
widget.setVisible(checkbox.isChecked())
6261

6362

63+
def add_label(widget, label, label_before=True, horizontal=True):
64+
if label_before:
65+
return combine_blocks(widget, label, horizontal=horizontal)
66+
else:
67+
return combine_blocks(label, widget, horizontal=horizontal)
68+
69+
6470
class Button(QPushButton):
6571
"""Class for a button with a title and connected to a function when clicked. Inherits from QPushButton.
6672
@@ -494,20 +500,33 @@ def __init__(
494500
step=1,
495501
parent: Optional[QWidget] = None,
496502
fixed: Optional[bool] = True,
503+
label: Optional[str] = None,
497504
):
498505
"""Args:
499506
min (Optional[float]): minimum value, defaults to 0
500507
max (Optional[float]): maximum value, defaults to 10
501508
default (Optional[float]): default value, defaults to 0
502509
step (Optional[float]): step value, defaults to 1
503510
parent: parent widget, defaults to None
504-
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed"""
511+
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed
512+
label (Optional[str]): if provided, creates a label with the chosen title to use with the counter"""
505513

506514
super().__init__(parent)
507515
set_spinbox(self, min, max, default, step, fixed)
508516

517+
if label is not None:
518+
self.label = make_label(name=label)
519+
520+
# def setToolTip(self, a0: str) -> None:
521+
# self.setToolTip(a0)
522+
# if self.label is not None:
523+
# self.label.setToolTip(a0)
524+
525+
def get_with_label(self, horizontal=True):
526+
return add_label(self, self.label, horizontal=horizontal)
527+
509528
def set_precision(self, decimals):
510-
"""Sets the precision of the box to the speicifed number of decimals"""
529+
"""Sets the precision of the box to the specified number of decimals"""
511530
self.setDecimals(decimals)
512531

513532
@classmethod
@@ -535,6 +554,7 @@ def __init__(
535554
step=1,
536555
parent: Optional[QWidget] = None,
537556
fixed: Optional[bool] = True,
557+
label: Optional[str] = None,
538558
):
539559
"""Args:
540560
min (Optional[int]): minimum value, defaults to 0
@@ -546,6 +566,9 @@ def __init__(
546566

547567
super().__init__(parent)
548568
set_spinbox(self, min, max, default, step, fixed)
569+
self.label = None
570+
if label is not None:
571+
self.label = make_label(label, self)
549572

550573
@classmethod
551574
def make_n(

napari_cellseg3d/model_framework.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from napari_cellseg3d.log_utility import Log
1515
from napari_cellseg3d.models import model_SegResNet as SegResNet
1616
from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR
17+
1718
# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
1819
from napari_cellseg3d.models import model_VNet as VNet
1920
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS

napari_cellseg3d/model_workers.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ def __init__(self):
164164
super().__init__()
165165

166166

167+
# TODO : use dataclass for config instead ?
168+
169+
167170
class InferenceWorker(GeneratorWorker):
168171
"""A custom worker to run inference jobs in.
169172
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
@@ -179,6 +182,7 @@ def __init__(
179182
instance,
180183
use_window,
181184
window_infer_size,
185+
window_overlap,
182186
keep_on_cpu,
183187
stats_csv,
184188
images_filepaths=None,
@@ -230,7 +234,7 @@ def __init__(
230234
self.instance_params = instance
231235
self.use_window = use_window
232236
self.window_infer_size = window_infer_size
233-
self.window_overlap_percentage = 0.8,
237+
self.window_overlap_percentage = window_overlap
234238
self.keep_on_cpu = keep_on_cpu
235239
self.stats_to_csv = stats_csv
236240
############################################
@@ -315,17 +319,53 @@ def load_folder(self):
315319
self.log("\nChecking dimensions...")
316320
pad = utils.get_padding_dim(check)
317321

318-
load_transforms = Compose(
319-
[
320-
LoadImaged(keys=["image"]),
321-
# AddChanneld(keys=["image"]), #already done
322-
EnsureChannelFirstd(keys=["image"]),
323-
# Orientationd(keys=["image"], axcodes="PLI"),
324-
# anisotropic_transform,
325-
SpatialPadd(keys=["image"], spatial_size=pad),
326-
EnsureTyped(keys=["image"]),
327-
]
328-
)
322+
dims = self.model_dict["model_input_size"]
323+
324+
if self.model_dict["name"] == "SegResNet":
325+
model = self.model_dict["class"].get_net(
326+
input_image_size=[
327+
dims,
328+
dims,
329+
dims,
330+
]
331+
)
332+
elif self.model_dict["name"] == "SwinUNetR":
333+
model = self.model_dict["class"].get_net(
334+
img_size=[dims, dims, dims],
335+
use_checkpoint=False,
336+
)
337+
else:
338+
model = self.model_dict["class"].get_net()
339+
340+
self.log_parameters()
341+
342+
model.to(self.device)
343+
344+
# print("FILEPATHS PRINT")
345+
# print(self.images_filepaths)
346+
if self.use_window:
347+
load_transforms = Compose(
348+
[
349+
LoadImaged(keys=["image"]),
350+
# AddChanneld(keys=["image"]), #already done
351+
EnsureChannelFirstd(keys=["image"]),
352+
# Orientationd(keys=["image"], axcodes="PLI"),
353+
# anisotropic_transform,
354+
EnsureTyped(keys=["image"]),
355+
]
356+
)
357+
else:
358+
load_transforms = Compose(
359+
[
360+
LoadImaged(keys=["image"]),
361+
# AddChanneld(keys=["image"]), #already done
362+
EnsureChannelFirstd(keys=["image"]),
363+
# Orientationd(keys=["image"], axcodes="PLI"),
364+
# anisotropic_transform,
365+
SpatialPadd(keys=["image"], spatial_size=pad),
366+
EnsureTyped(keys=["image"]),
367+
]
368+
)
329369

330370
self.log("\nLoading dataset...")
331371
inference_ds = Dataset(data=images_dict, transform=load_transforms)

napari_cellseg3d/plugin_model_inference.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
7878
self.keep_on_cpu = False
7979
self.use_window_inference = False
8080
self.window_inference_size = None
81-
self.window_overlap_percentage = None
81+
self.window_overlap = 0.25
8282

8383
###########################
8484
# interface
@@ -131,7 +131,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
131131
T=7, parent=self
132132
)
133133

134-
self.window_infer_box = ui.make_checkbox("Use window inference")
134+
self.window_infer_box = ui.CheckBox(title="Use window inference")
135135
self.window_infer_box.clicked.connect(self.toggle_display_window_size)
136136

137137
sizes_window = ["8", "16", "32", "64", "128", "256", "512"]
@@ -151,9 +151,18 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
151151
)
152152
self.lbl_window_size_choice = self.window_size_choice.label
153153

154-
self.keep_data_on_cpu_box = ui.make_checkbox("Keep data on CPU")
154+
self.window_overlap_counter = ui.DoubleIncrementCounter(
155+
min=0,
156+
max=1,
157+
default=0.25,
158+
step=0.05,
159+
parent=self,
160+
label="Overlap %",
161+
)
155162

156-
self.window_infer_params = ui.combine_blocks(
163+
self.keep_data_on_cpu_box = ui.CheckBox(title="Keep data on CPU")
164+
165+
window_size_widgets = ui.combine_blocks(
157166
self.window_size_choice,
158167
self.lbl_window_size_choice,
159168
horizontal=False,
@@ -164,6 +173,12 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
164173
# horizontal=False,
165174
# )
166175

176+
self.window_infer_params = ui.combine_blocks(
177+
window_size_widgets,
178+
self.window_overlap_counter.get_with_label(horizontal=False),
179+
horizontal=False,
180+
)
181+
167182
##################
168183
##################
169184
# instance segmentation widgets
@@ -252,6 +267,10 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
252267
"Size of the window to run inference with (in pixels)"
253268
)
254269

270+
self.window_overlap_counter.setToolTip(
271+
"Percentage of overlap between windows to use when using sliding window"
272+
)
273+
255274
# self.window_overlap.setToolTip(
256275
# "Amount of overlap between sliding windows"
257276
# )
@@ -625,7 +644,7 @@ def start(self, on_layer=False):
625644
self.window_inference_size = int(
626645
self.window_size_choice.currentText()
627646
)
628-
# self.window_overlap_percentage = self.window_overlap.value()
647+
self.window_overlap = self.window_overlap_counter.value()
629648

630649
if not on_layer:
631650
self.worker = InferenceWorker(
@@ -639,6 +658,7 @@ def start(self, on_layer=False):
639658
instance=self.instance_params,
640659
use_window=self.use_window_inference,
641660
window_infer_size=self.window_inference_size,
661+
window_overlap=self.window_overlap,
642662
keep_on_cpu=self.keep_on_cpu,
643663
stats_csv=self.stats_to_csv,
644664
)
@@ -655,6 +675,7 @@ def start(self, on_layer=False):
655675
use_window=self.use_window_inference,
656676
window_infer_size=self.window_inference_size,
657677
keep_on_cpu=self.keep_on_cpu,
678+
window_overlap=self.window_overlap,
658679
stats_csv=self.stats_to_csv,
659680
layer=layer,
660681
)

0 commit comments

Comments
 (0)