Skip to content

Commit 066ce68

Browse files
committed
moved instance seg + segres dims in infer + fixed popup in inference + rm test in helper
1 parent dd7d15a commit 066ce68

File tree

7 files changed

+171
-122
lines changed

7 files changed

+171
-122
lines changed

napari_cellseg3d/interface.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def add_to_group(title, widget, layout, L=7, T=20, R=7, B=11):
221221
layout.addWidget(group)
222222

223223

224-
def make_group(title, L=7, T=20, R=7, B=11):
224+
def make_group(title, L=7, T=20, R=7, B=11, parent=None):
225225
"""Creates a group widget and layout, with a header (`title`) and content margins for top/left/right/bottom `L, T, R, B` (in pixels)
226226
Group widget and layout returned will have a Fixed size policy.
227227
@@ -231,8 +231,12 @@ def make_group(title, L=7, T=20, R=7, B=11):
231231
T (int): top margin
232232
R (int): right margin
233233
B (int): bottom margin
234+
parent (QWidget) : parent widget. If None, no parent is set
234235
"""
235-
group = QGroupBox(title)
236+
if parent is None:
237+
group = QGroupBox(title)
238+
else:
239+
group = QGroupBox(title, parent=parent)
236240
group.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
237241
layout = QVBoxLayout()
238242
layout.setContentsMargins(L, T, R, B)
@@ -241,10 +245,11 @@ def make_group(title, L=7, T=20, R=7, B=11):
241245
return group, layout
242246

243247

244-
def make_container(L=0, T=0, R=1, B=11, vertical=True):
248+
def make_container(L=0, T=0, R=1, B=11, vertical=True, parent=None):
245249
"""Creates a QWidget and a layout for the purpose of containing other modules, with a Fixed layout.
246250
247251
Args:
252+
parent : parent widget. If None, no widget is set
248253
L (int): left margin of layout
249254
T (int): top margin of layout
250255
R (int): right margin of layout
@@ -255,7 +260,10 @@ def make_container(L=0, T=0, R=1, B=11, vertical=True):
255260
QWidget : widget that contains the other widgets. Fixed size.
256261
QBoxLayout : H/V Box layout to add contained widgets in. Fixed size.
257262
"""
258-
container_widget = QWidget()
263+
if parent is None:
264+
container_widget = QWidget()
265+
else:
266+
container_widget = QWidget(parent)
259267

260268
if vertical:
261269
container_layout = QVBoxLayout()

napari_cellseg3d/model_workers.py

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
from tifffile import imwrite
4141

4242
# local
43+
from napari_cellseg3d.model_instance_seg import (
44+
binary_watershed,
45+
binary_connected,
46+
)
4347
from napari_cellseg3d import utils
4448

4549
"""
@@ -64,6 +68,7 @@ class LogSignal(WorkerBaseSignals):
6468

6569
log_signal = Signal(str)
6670
"""qtpy.QtCore.Signal: signal to be sent when some text should be logged"""
71+
6772
# Should not be an instance variable but a class variable, not defined in __init__, see
6873
# https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
6974

@@ -155,6 +160,9 @@ def log(self, text):
155160

156161
def log_parameters(self):
157162

163+
self.log("-" * 20)
164+
self.log("Parameters summary :")
165+
158166
self.log(f"Model is : {self.model_dict['name']}")
159167
if self.transforms["thresh"][0]:
160168
self.log(
@@ -173,10 +181,19 @@ def log_parameters(self):
173181
else:
174182
self.log(f"Dataset loaded on {self.device}")
175183

184+
if self.transforms["zoom"][0]:
185+
self.log(
186+
f"Anisotropy parameters are : {self.transforms['zoom'][1]} microns in x,y,z"
187+
)
188+
176189
if self.instance_params["do_instance"]:
177-
# TODO move instance seg
178-
self.log(f"Instance segmentation enabled")
190+
self.log(
191+
f"Instance segmentation enabled, method : {self.instance_params['method']}\n"
192+
f"Probability threshold is {self.instance_params['threshold']:.2f}\n"
193+
f"Objects smaller than {self.instance_params['size_small']} pixels will be removed"
194+
)
179195
# self.log(f"")
196+
self.log("-" * 20)
180197

181198
def inference(self):
182199
"""
@@ -234,8 +251,7 @@ def inference(self):
234251
self.log("\nChecking dimensions...")
235252
pad = utils.get_padding_dim(check)
236253
# print(pad)
237-
dims = 128
238-
# dims = 64 # TODO
254+
dims = self.model_dict["segres_size"]
239255

240256
model = self.model_dict["class"].get_net()
241257
if self.model_dict["name"] == "SegResNet":
@@ -304,7 +320,7 @@ def inference(self):
304320
for i, inf_data in enumerate(inference_loader):
305321

306322
self.log("-" * 10)
307-
self.log(f"Inference started on image n°{i+1}...")
323+
self.log(f"Inference started on image n°{i + 1}...")
308324

309325
inputs = inf_data["image"]
310326
# print(inputs.shape)
@@ -350,6 +366,7 @@ def inference(self):
350366
out = post_process_transforms(out)
351367
out = np.array(out).astype(np.float32)
352368
out = np.squeeze(out)
369+
to_instance = out # avoid post processing since thresholding is done there anyway
353370

354371
# batch_len = out.shape[1]
355372
# print("trying to check len")
@@ -391,8 +408,31 @@ def inference(self):
391408
self.log(
392409
f"\nRunning instance segmentation for image n°{image_id}"
393410
)
394-
method = self.instance_params["method"]
395-
instance_labels = method(out)
411+
412+
threshold = self.instance_params["threshold"]
413+
size_small = self.instance_params["size_small"]
414+
method_name = self.instance_params["method"]
415+
416+
if method_name == "Watershed":
417+
418+
def method(image):
419+
return binary_watershed(
420+
image, threshold, size_small
421+
)
422+
423+
elif method_name == "Connected components":
424+
425+
def method(image):
426+
return binary_connected(
427+
image, threshold, size_small
428+
)
429+
430+
else:
431+
raise NotImplementedError(
432+
"Selected instance segmentation method is not defined"
433+
)
434+
435+
instance_labels = method(to_instance)
396436

397437
instance_filepath = (
398438
self.results_path
@@ -526,10 +566,11 @@ def log(self, text):
526566

527567
def log_parameters(self):
528568

529-
self.log("\nParameters summary :\n")
569+
self.log("-" * 20)
570+
self.log("Parameters summary :\n")
530571

531572
self.log(
532-
f"Percentage of dataset used for validation : {self.validation_percent*100}%"
573+
f"Percentage of dataset used for validation : {self.validation_percent * 100}%"
533574
)
534575
self.log("-" * 10)
535576
self.log("Training files :\n")
@@ -892,7 +933,7 @@ def train(self):
892933
yield train_report
893934

894935
weights_filename = (
895-
f"{model_name}_best_metric" + f"_epoch_{epoch+1}.pth"
936+
f"{model_name}_best_metric" + f"_epoch_{epoch + 1}.pth"
896937
)
897938

898939
if metric > best_metric:

napari_cellseg3d/plugin_convert.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def build(self):
101101
ui.add_blank(layout=layout, widget=self)
102102
#############################################################
103103
folder_group_w, folder_group_l = ui.make_group(
104-
"Convert folder", l, t, r, b
104+
"Convert folder", l, t, r, b, parent=None
105105
)
106106

107107
folder_group_l.addWidget(
@@ -127,7 +127,7 @@ def build(self):
127127
ui.add_blank(layout=layout, widget=self)
128128
#############################################################
129129
layer_group_w, layer_group_l = ui.make_group(
130-
"Convert selected layer", l, t, r, b
130+
"Convert selected layer", l, t, r, b, parent=None
131131
)
132132

133133
ui.add_widgets(
@@ -142,7 +142,7 @@ def build(self):
142142
ui.add_blank(layout=layout, widget=self)
143143
#############################################################
144144
small_group_w, small_group_l = ui.make_group(
145-
"Remove small objects", l, t, r, b
145+
"Remove small objects", l, t, r, b,parent=None
146146
)
147147

148148
ui.add_widgets(

napari_cellseg3d/plugin_helper.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ class Helper(QWidget):
1414
def __init__(self, viewer: "napari.viewer.Viewer"):
1515
super().__init__()
1616

17-
self.help_url = "https://adaptivemotorcontrollab.github.io/cellseg3d-docs/"
17+
self.help_url = (
18+
"https://adaptivemotorcontrollab.github.io/cellseg3d-docs/"
19+
)
1820

1921
self.about_url = "https://wysscenter.ch/advances/3d-computer-vision-for-brain-analysis"
2022
self._viewer = viewer
@@ -35,14 +37,11 @@ def build(self):
3537
vbox = QVBoxLayout()
3638

3739
widgets = [
40+
self.info_label,
3841
self.btn1,
3942
self.btn2,
4043
self.btnc,
4144
]
42-
#################
43-
if self.test:
44-
widgets.append(self.epoch)
45-
#################
4645
ui.add_widgets(vbox, widgets)
4746
self.setLayout(vbox)
4847
# self.show()

0 commit comments

Comments
 (0)