Skip to content

Commit eab113d

Browse files
committed
ignore reqs+warning+window inference+docs
1 parent 75022cd commit eab113d

File tree

11 files changed

+197
-325
lines changed

11 files changed

+197
-325
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ venv/
9090

9191
########
9292
#project specific
93-
#dataset, weights, old logos
93+
#dataset, weights, old logos, requirements
9494
/napari_cellseg3d/models/dataset/
9595
/napari_cellseg3d/models/saved_weights/
9696
/docs/res/logo/old_logo/
97+
/reqs/

docs/res/guides/inference_module_guide.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ Interface and functionalities
4242
| You may also choose to **load custom weights** rather than the pre-trained ones, simply ensure they are **compatible** (e.g. produced from the training module for the same model)
4343
4444

45+
* **Inference parameters** :
46+
47+
| You can choose to use inference on the whole image at once, which generally yields better performance at the cost of more memory,
48+
| or you can use a specific window size to run inference on smaller chunks one by one, for lower memory usage.
49+
| You can also choose to keep the dataset in the RAM rather than the VRAM (cpu vs cuda device) to avoid running out of VRAM
50+
| if you have several images.
51+
52+
4553
* **Anisotropy** :
4654

4755
| If you want to see your results without **anisotropy** when you have anisotropic images, you can specify that you have anisotropic data

napari_cellseg3d/interface.py

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Union
2+
13
from qtpy.QtCore import Qt
24
from qtpy.QtCore import QUrl
35
from qtpy.QtGui import QDesktopServices
@@ -62,7 +64,7 @@ def add_blank(widget, layout):
6264

6365
def open_file_dialog(
6466
widget,
65-
possible_paths=[""],
67+
possible_paths: list = [""],
6668
load_as_folder: bool = False,
6769
filetype: str = "Image file (*.tif *.tiff)",
6870
):
@@ -159,7 +161,7 @@ def make_n_spinboxes(
159161
parent=None,
160162
double=False,
161163
fixed=True,
162-
):
164+
) -> Union[list, QWidget]:
163165
"""
164166
165167
Args:
@@ -303,7 +305,7 @@ def make_combobox(
303305
"""Creates a dropdown menu with a title and adds specified entries to it
304306
305307
Args:
306-
entries array(str): Entries to add to the dropdown menu. Defaults to None, no entries if None
308+
entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None
307309
parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None
308310
label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well
309311
fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True.
@@ -367,14 +369,22 @@ def make_checkbox(
367369

368370

369371
def combine_blocks(
370-
second, first, min_spacing=0, horizontal=True, l=11, t=3, r=11, b=11
372+
right_or_below,
373+
left_or_above,
374+
min_spacing=0,
375+
horizontal=True,
376+
l=11,
377+
t=3,
378+
r=11,
379+
b=11,
371380
):
372-
"""Combines two QWidget objects and puts them side by side (label on the left and button on the right)
381+
"""Combines two QWidget objects and puts them side by side (first on the left/top and second on the right/bottom depending on "horizontal")
382+
Weird argument names due the initial implementation of it. # TODO maybe fix arg names
373383
374384
Args:
375-
horizontal (bool): whether to stack widgets laterally or horizontally
376-
second (QWidget): Second widget, to be displayed right/below of the label
377-
first (QWidget): First widget, to be added on the left/above of button
385+
horizontal (bool): whether to stack widgets vertically (False) or horizontally (True)
386+
left_or_above (QWidget): First widget, to be added on the left/above of "second"
387+
right_or_below (QWidget): Second widget, to be displayed right/below of "first"
378388
min_spacing (int): Minimum spacing between the two widgets (from the start of label to the start of button)
379389
380390
Returns:
@@ -407,9 +417,9 @@ def combine_blocks(
407417
# temp_layout.setColumnMinimumWidth(1,100)
408418
# temp_layout.setSizeConstraint(QLayout.SetMinAndMaxSize)
409419

410-
temp_layout.addWidget(first, r1, c1) # , alignment=LEFT_AL)
420+
temp_layout.addWidget(left_or_above, r1, c1) # , alignment=LEFT_AL)
411421
# temp_layout.addStretch(100)
412-
temp_layout.addWidget(second, r2, c2) # , alignment=LEFT_AL)
422+
temp_layout.addWidget(right_or_below, r2, c2) # , alignment=LEFT_AL)
413423
temp_widget.setLayout(temp_layout)
414424
return temp_widget
415425

napari_cellseg3d/model_workers.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,9 @@ def __init__(
8585
filetype,
8686
transforms,
8787
instance,
88+
use_window,
89+
window_infer_size,
90+
keep_on_cpu,
8891
):
8992
"""Initializes a worker for inference with the arguments needed by the :py:func:`~inference` function.
9093
@@ -103,7 +106,13 @@ def __init__(
103106
104107
* transforms: a dict containing transforms to perform at various times.
105108
106-
* instance : a dict containing parameters regarding instance segmentation
109+
* instance: a dict containing parameters regarding instance segmentation
110+
111+
* use_window: use window inference with specific size or whole image
112+
113+
* window_infer_size: size of window if use_window is True
114+
115+
* keep_on_cpu: keep images on CPU or no
107116
108117
Note: See :py:func:`~self.inference`
109118
"""
@@ -121,6 +130,9 @@ def __init__(
121130
self.filetype = filetype
122131
self.transforms = transforms
123132
self.instance_params = instance
133+
self.use_window = use_window
134+
self.window_infer_size = window_infer_size
135+
self.keep_on_cpu = keep_on_cpu
124136

125137
"""These attributes are all arguments of :py:func:~inference, please see that for reference"""
126138

@@ -149,9 +161,25 @@ def log_parameters(self):
149161
f"Thresholding is enabled at {self.transforms['thresh'][1]}"
150162
)
151163

164+
if self.use_window:
165+
status = "enabled"
166+
else:
167+
status="disabled"
168+
169+
self.log(f"Window inference is {status}")
170+
171+
if self.keep_on_cpu:
172+
self.log(f"Dataset loaded to CPU")
173+
else:
174+
self.log(f"Dataset loaded on {self.device}")
175+
176+
if self.instance_params["do_instance"]:
177+
# TODO move instance seg
178+
self.log(f"Instance segmentation enabled")
179+
# self.log(f"")
180+
152181
def inference(self):
153182
"""
154-
155183
Requires:
156184
* device: cuda or cpu device to use for torch
157185
@@ -167,6 +195,12 @@ def inference(self):
167195
168196
* transforms: a dict containing transforms to perform at various times.
169197
198+
* use_window: use window inference with specific size or whole image
199+
200+
* window_infer_size: size of window if use_window is True
201+
202+
* keep_on_cpu: keep images on CPU or no
203+
170204
Yields:
171205
dict: contains :
172206
* "image_id" : index of the returned image
@@ -215,8 +249,8 @@ def inference(self):
215249

216250
model.to(self.device)
217251

218-
print("FILEPATHS PRINT")
219-
print(self.images_filepaths)
252+
# print("FILEPATHS PRINT")
253+
# print(self.images_filepaths)
220254

221255
load_transforms = Compose(
222256
[
@@ -272,18 +306,31 @@ def inference(self):
272306

273307
inputs = inf_data["image"]
274308
# print(inputs.shape)
275-
inputs = inputs.to(self.device)
309+
# TODO figure out sliding window device
310+
inputs = inputs.to("cpu")
276311

277312
model_output = lambda inputs: post_process_transforms(
278313
self.model_dict["class"].get_output(model, inputs)
279314
)
315+
# TODO add more params
316+
317+
if self.keep_on_cpu:
318+
dataset_device = "cpu"
319+
else:
320+
dataset_device = self.device
321+
322+
if self.use_window:
323+
window_size = self.window_infer_size
324+
else:
325+
window_size = None
280326

281327
outputs = sliding_window_inference(
282328
inputs,
283-
roi_size=None,
329+
roi_size=window_size,
284330
sw_batch_size=1,
285331
predictor=model_output,
286-
device=self.device,
332+
sw_device=self.device,
333+
device=dataset_device,
287334
)
288335

289336
out = outputs.detach().cpu()

napari_cellseg3d/plugin_convert.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def build(self):
8383
w, layout = ui.make_container_widget()
8484

8585
results_widget = ui.combine_blocks(
86-
second=self.btn_result_path,
87-
first=self.lbl_result_path,
86+
right_or_below=self.btn_result_path,
87+
left_or_above=self.lbl_result_path,
8888
min_spacing=70,
8989
)
9090

@@ -105,8 +105,8 @@ def build(self):
105105

106106
folder_group_l.addWidget(
107107
ui.combine_blocks(
108-
second=self.btn_label_files,
109-
first=self.lbl_label_files,
108+
right_or_below=self.btn_label_files,
109+
left_or_above=self.lbl_label_files,
110110
min_spacing=70,
111111
),
112112
alignment=ui.LEFT_AL,

napari_cellseg3d/plugin_dock.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def __init__(self, parent: "napari.viewer.Viewer"):
5151
io_panel, io_layout = ui.make_container_widget(vertical=False)
5252
io_layout.addWidget(
5353
ui.combine_blocks(
54-
first=self.button, second=self.time_label, horizontal=True
54+
left_or_above=self.button,
55+
right_or_below=self.time_label,
56+
horizontal=True,
5557
)
5658
) # , alignment=utils.ABS_AL)
5759
io_panel.setLayout(io_layout)

napari_cellseg3d/plugin_metrics.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
class MetricsUtils(BasePluginFolder):
20-
"""Plugin to evaluate metrics between two sets of labels, ground truh and prediction"""
20+
"""Plugin to evaluate metrics between two sets of labels, ground truth and prediction"""
2121

2222
def __init__(self, viewer: "napari.viewer.Viewer", parent):
2323
"""Creates a MetricsUtils widget for computing and plotting dice metrics between labels.
@@ -47,7 +47,9 @@ def __init__(self, viewer: "napari.viewer.Viewer", parent):
4747

4848
self.btn_reset_plot = ui.make_button("Clear plots", self.remove_plots)
4949

50-
self.threshold_box = ui.make_n_spinboxes(min=0.1, max = 1, default=DEFAULT_THRESHOLD,step=0.1, double=True)
50+
self.threshold_box = ui.make_n_spinboxes(
51+
min=0.1, max=1, default=DEFAULT_THRESHOLD, step=0.1, double=True
52+
)
5153

5254
self.btn_result_path.setVisible(False)
5355
self.lbl_result_path.setVisible(False)
@@ -85,8 +87,8 @@ def build(self):
8587

8688
metrics_group_l.addWidget(
8789
ui.combine_blocks(
88-
second=self.btn_image_files,
89-
first=self.lbl_image_files,
90+
right_or_below=self.btn_image_files,
91+
left_or_above=self.lbl_image_files,
9092
min_spacing=70,
9193
),
9294
alignment=ui.LEFT_AL,
@@ -96,8 +98,8 @@ def build(self):
9698

9799
metrics_group_l.addWidget(
98100
ui.combine_blocks(
99-
second=self.btn_label_files,
100-
first=self.lbl_label_files,
101+
right_or_below=self.btn_label_files,
102+
left_or_above=self.lbl_label_files,
101103
min_spacing=70,
102104
),
103105
alignment=ui.LEFT_AL,
@@ -262,9 +264,11 @@ def compute_dice(self):
262264
scores.append(utils.dice_coeff(p, ground))
263265
scores.append(utils.dice_coeff(np.flip(p), ground))
264266
for i in range(3):
265-
scores.append(utils.dice_coeff(np.flip(p, axis=i), ground))
267+
scores.append(
268+
utils.dice_coeff(np.flip(p, axis=i), ground)
269+
)
266270
else:
267-
i=0
271+
i = 0
268272
scores.append(utils.dice_coeff(pred, ground))
269273
# if t < 1:
270274
# for i in range(3):

0 commit comments

Comments
 (0)