Skip to content

Commit da713e5

Browse files
committed
UI overhaul + overlap parameter
- Added overlap parameter for window - Improved UI code slightly
1 parent a6d3cce commit da713e5

File tree

4 files changed

+61
-12
lines changed

4 files changed

+61
-12
lines changed

napari_cellseg3d/interface.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from typing import Optional
2-
from typing import Union
32
from typing import List
43

54

@@ -61,6 +60,13 @@ def toggle_visibility(checkbox, widget):
6160
widget.setVisible(checkbox.isChecked())
6261

6362

63+
def add_label(widget, label, label_before=True, horizontal=True):
64+
if label_before:
65+
return combine_blocks(widget, label, horizontal=horizontal)
66+
else:
67+
return combine_blocks(label, widget, horizontal=horizontal)
68+
69+
6470
class Button(QPushButton):
6571
"""Class for a button with a title and connected to a function when clicked. Inherits from QPushButton.
6672
@@ -492,20 +498,33 @@ def __init__(
492498
step=1,
493499
parent: Optional[QWidget] = None,
494500
fixed: Optional[bool] = True,
501+
label: Optional[str] = None,
495502
):
496503
"""Args:
497504
min (Optional[int]): minimum value, defaults to 0
498505
max (Optional[int]): maximum value, defaults to 10
499506
default (Optional[int]): default value, defaults to 0
500507
step (Optional[int]): step value, defaults to 1
501508
parent: parent widget, defaults to None
502-
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed"""
509+
fixed (bool): if True, sets the QSizePolicy of the spinbox to Fixed
510+
label (Optional[str]): if provided, creates a label with the chosen title to use with the counter"""
503511

504512
super().__init__(parent)
505513
set_spinbox(self, min, max, default, step, fixed)
506514

515+
if label is not None:
516+
self.label = make_label(name=label)
517+
518+
# def setToolTip(self, a0: str) -> None:
519+
# self.setToolTip(a0)
520+
# if self.label is not None:
521+
# self.label.setToolTip(a0)
522+
523+
def get_with_label(self, horizontal=True):
524+
return add_label(self, self.label, horizontal=horizontal)
525+
507526
def set_precision(self, decimals):
508-
"""Sets the precision of the box to the speicifed number of decimals"""
527+
"""Sets the precision of the box to the specified number of decimals"""
509528
self.setDecimals(decimals)
510529

511530
@classmethod
@@ -533,6 +552,7 @@ def __init__(
533552
step=1,
534553
parent: Optional[QWidget] = None,
535554
fixed: Optional[bool] = True,
555+
label: Optional[str] = None,
536556
):
537557
"""Args:
538558
min (Optional[int]): minimum value, defaults to 0
@@ -544,6 +564,9 @@ def __init__(
544564

545565
super().__init__(parent)
546566
set_spinbox(self, min, max, default, step, fixed)
567+
self.label = None
568+
if label is not None:
569+
self.label = make_label(label, self)
547570

548571
@classmethod
549572
def make_n(

napari_cellseg3d/model_framework.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from napari_cellseg3d.log_utility import Log
1515
from napari_cellseg3d.models import model_SegResNet as SegResNet
1616
from napari_cellseg3d.models import model_SwinUNetR as SwinUNetR
17+
1718
# from napari_cellseg3d.models import model_TRAILMAP as TRAILMAP
1819
from napari_cellseg3d.models import model_VNet as VNet
1920
from napari_cellseg3d.models import model_TRAILMAP_MS as TRAILMAP_MS

napari_cellseg3d/model_workers.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ def __init__(self):
162162
super().__init__()
163163

164164

165+
# TODO : use dataclass for config instead ?
166+
167+
165168
class InferenceWorker(GeneratorWorker):
166169
"""A custom worker to run inference jobs in.
167170
Inherits from :py:class:`napari.qt.threading.GeneratorWorker`"""
@@ -178,6 +181,7 @@ def __init__(
178181
instance,
179182
use_window,
180183
window_infer_size,
184+
window_overlap,
181185
keep_on_cpu,
182186
stats_csv,
183187
):
@@ -227,7 +231,7 @@ def __init__(
227231
self.instance_params = instance
228232
self.use_window = use_window
229233
self.window_infer_size = window_infer_size
230-
self.window_overlap_percentage = 0.8
234+
self.window_overlap_percentage = window_overlap
231235
self.keep_on_cpu = keep_on_cpu
232236
self.stats_to_csv = stats_csv
233237
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
@@ -346,7 +350,6 @@ def inference(self):
346350

347351
dims = self.model_dict["model_input_size"]
348352

349-
350353
if self.model_dict["name"] == "SegResNet":
351354
model = self.model_dict["class"].get_net(
352355
input_image_size=[
@@ -445,7 +448,7 @@ def inference(self):
445448
# print(inputs.shape)
446449

447450
inputs = inputs.to("cpu")
448-
print(inputs.shape)
451+
# print(inputs.shape)
449452

450453
# self.log("output")
451454
model_output = lambda inputs: post_process_transforms(
@@ -477,7 +480,7 @@ def inference(self):
477480
out = outputs.detach().cpu()
478481
# del outputs # TODO fix memory ?
479482
# outputs = None
480-
print(out.shape)
483+
# print(out.shape)
481484
if self.transforms["zoom"][0]:
482485
zoom = self.transforms["zoom"][1]
483486
anisotropic_transform = Zoom(
@@ -489,9 +492,9 @@ def inference(self):
489492

490493
# out = post_process_transforms(out)
491494
out = np.array(out).astype(np.float32)
492-
print(out.shape)
495+
# print(out.shape)
493496
out = np.squeeze(out)
494-
print(out.shape)
497+
# print(out.shape)
495498
to_instance = out # avoid post processing since thresholding is done there anyway
496499

497500
# batch_len = out.shape[1]

napari_cellseg3d/plugin_model_inference.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
7676
self.keep_on_cpu = False
7777
self.use_window_inference = False
7878
self.window_inference_size = None
79+
self.window_overlap = 0.25
7980

8081
###########################
8182
# interface
@@ -130,7 +131,7 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
130131
T=7, parent=self
131132
)
132133

133-
self.window_infer_box = ui.make_checkbox("Use window inference")
134+
self.window_infer_box = ui.CheckBox(title="Use window inference")
134135
self.window_infer_box.clicked.connect(self.toggle_display_window_size)
135136

136137
sizes_window = ["8", "16", "32", "64", "128", "256", "512"]
@@ -140,14 +141,29 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
140141
)
141142
self.lbl_window_size_choice = self.window_size_choice.label
142143

143-
self.keep_data_on_cpu_box = ui.make_checkbox("Keep data on CPU")
144+
self.window_overlap_counter = ui.DoubleIncrementCounter(
145+
min=0,
146+
max=1,
147+
default=0.25,
148+
step=0.05,
149+
parent=self,
150+
label="Overlap %",
151+
)
144152

145-
self.window_infer_params = ui.combine_blocks(
153+
self.keep_data_on_cpu_box = ui.CheckBox(title="Keep data on CPU")
154+
155+
window_size_widgets = ui.combine_blocks(
146156
self.window_size_choice,
147157
self.lbl_window_size_choice,
148158
horizontal=False,
149159
)
150160

161+
self.window_infer_params = ui.combine_blocks(
162+
window_size_widgets,
163+
self.window_overlap_counter.get_with_label(horizontal=False),
164+
horizontal=False,
165+
)
166+
151167
##################
152168
##################
153169
# instance segmentation widgets
@@ -228,6 +244,10 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
228244
"Size of the window to run inference with (in pixels)"
229245
)
230246

247+
self.window_overlap_counter.setToolTip(
248+
"Percentage of overlap between windows to use when using sliding window"
249+
)
250+
231251
self.keep_data_on_cpu_box.setToolTip(
232252
"If enabled, data will be kept on the RAM rather than the VRAM.\nCan avoid out of memory issues with CUDA"
233253
)
@@ -582,6 +602,7 @@ def start(self):
582602
self.window_inference_size = int(
583603
self.window_size_choice.currentText()
584604
)
605+
self.window_overlap = self.window_overlap_counter.value()
585606

586607
self.worker = InferenceWorker(
587608
device=device,
@@ -594,6 +615,7 @@ def start(self):
594615
instance=self.instance_params,
595616
use_window=self.use_window_inference,
596617
window_infer_size=self.window_inference_size,
618+
window_overlap=self.window_overlap,
597619
keep_on_cpu=self.keep_on_cpu,
598620
stats_csv=self.stats_to_csv,
599621
)

0 commit comments

Comments
 (0)