Skip to content

Commit 22e4f9e

Browse files
Implement detection widget
1 parent 2152b26 commit 22e4f9e

File tree

4 files changed

+228
-74
lines changed

4 files changed

+228
-74
lines changed

flamingo_tools/model_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def get_default_tiling() -> Dict[str, Dict[str, int]]:
154154
else:
155155
tiling = {
156156
"tile": {"x": 96, "y": 96, "z": 16},
157-
"halo": {"x": 16, "y": 16, "z": 4},
157+
"halo": {"x": 16, "y": 16, "z": 8},
158158
}
159159
print(f"Determining default tiling for CPU: {tiling}")
160160

Lines changed: 144 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,147 @@
1+
import copy
2+
3+
import napari
4+
from napari.utils.notifications import show_info
5+
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
6+
from skimage.feature import peak_local_max
7+
from torch_em.util.prediction import predict_with_halo
8+
19
from .base_widget import BaseWidget
10+
from .util import _load_custom_model, _available_devices, _get_current_tiling
11+
from ..model_utils import get_model, get_model_registry, get_device, get_default_tiling
12+
13+
14+
def _run_detection(image, model, model_type, tiling, device):
15+
block_shape = [tiling["tile"][ax] for ax in "zyx"]
16+
halo = [tiling["halo"][ax] for ax in "zyx"]
17+
prediction = predict_with_halo(
18+
image, model, gpu_ids=[device], block_shape=block_shape, halo=halo,
19+
tqdm_desc="Run prediction"
20+
).squeeze()
21+
detections = peak_local_max(prediction, min_distance=2, threshold_abs=0.5)
22+
return detections
23+
24+
25+
class DetectionWidget(BaseWidget):
26+
def __init__(self):
27+
super().__init__()
28+
29+
self.viewer = napari.current_viewer()
30+
layout = QVBoxLayout()
31+
self.tiling = {}
32+
33+
# Create the image selection dropdown.
34+
self.image_selector_name = "Image data"
35+
self.image_selector_widget = self._create_layer_selector(self.image_selector_name, layer_type="Image")
36+
37+
# Create buttons and widgets.
38+
self.predict_button = QPushButton("Run Detection")
39+
self.predict_button.clicked.connect(self.on_predict)
40+
self.model_selector_widget = self.load_model_widget()
41+
self.settings = self._create_settings_widget()
42+
43+
# Add the widgets to the layout.
44+
layout.addWidget(self.image_selector_widget)
45+
layout.addWidget(self.model_selector_widget)
46+
layout.addWidget(self.settings)
47+
layout.addWidget(self.predict_button)
48+
49+
self.setLayout(layout)
50+
51+
def load_model_widget(self):
52+
model_widget = QWidget()
53+
title_label = QLabel("Select Model:")
54+
55+
model_list = list(get_model_registry().urls.keys())
56+
57+
# Exclude the detection models.
58+
segmentation_models = ["Synapses"]
59+
model_list = [name for name in model_list if name in segmentation_models]
60+
61+
models = ["- choose -"] + model_list
62+
self.model_selector = QComboBox()
63+
self.model_selector.addItems(models)
64+
# Create a layout and add the title label and combo box
65+
layout = QVBoxLayout()
66+
layout.addWidget(title_label)
67+
layout.addWidget(self.model_selector)
68+
69+
# Set layout on the model widget
70+
model_widget.setLayout(layout)
71+
return model_widget
72+
73+
def on_predict(self):
74+
# Get the model and postprocessing settings.
75+
model_type = self.model_selector.currentText()
76+
custom_model_path = self.checkpoint_param.text()
77+
if model_type == "- choose -" and custom_model_path is None:
78+
show_info("INFO: Please choose a model.")
79+
return
80+
81+
device = get_device(self.device_dropdown.currentText())
82+
83+
# Load the model. Override if user chose custom model
84+
if custom_model_path:
85+
model = _load_custom_model(custom_model_path, device)
86+
if model:
87+
show_info(f"INFO: Using custom model from path: {custom_model_path}")
88+
model_type = "custom"
89+
else:
90+
show_info(f"ERROR: Failed to load custom model from path: {custom_model_path}")
91+
return
92+
else:
93+
model = get_model(model_type, device)
94+
95+
# Get the image data.
96+
image = self._get_layer_selector_data(self.image_selector_name)
97+
if image is None:
98+
show_info("INFO: Please choose an image.")
99+
return
100+
101+
# Get the current tiling.
102+
self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
103+
# TODO extra segmentation for filtering.
104+
detections = _run_detection(image, model=model, model_type=model_type, tiling=self.tiling, device=device)
105+
106+
self.viewer.add_points(detections, name=model_type)
107+
show_info(f"INFO: Detection of {model_type} added to layers.")
108+
109+
def _create_settings_widget(self):
110+
setting_values = QWidget()
111+
# setting_values.setToolTip(get_tooltip("embedding", "settings"))
112+
setting_values.setLayout(QVBoxLayout())
113+
114+
# Create UI for the device.
115+
device = "auto"
116+
device_options = ["auto"] + _available_devices()
117+
118+
self.device_dropdown, layout = self._add_choice_param("device", device, device_options)
119+
setting_values.layout().addLayout(layout)
120+
121+
# Create UI for the tile shape.
122+
self.default_tiling = get_default_tiling()
123+
self.tiling = copy.deepcopy(self.default_tiling)
124+
self.tiling["tile"]["x"], self.tiling["tile"]["y"], self.tiling["tile"]["z"], layout = self._add_shape_param(
125+
("tile_x", "tile_y", "tile_z"),
126+
(self.default_tiling["tile"]["x"], self.default_tiling["tile"]["y"], self.default_tiling["tile"]["z"]),
127+
min_val=0, max_val=2048, step=16,
128+
# tooltip=get_tooltip("embedding", "tiling")
129+
)
130+
setting_values.layout().addLayout(layout)
131+
132+
# Create UI for the halo.
133+
self.tiling["halo"]["x"], self.tiling["halo"]["y"], self.tiling["halo"]["z"], layout = self._add_shape_param(
134+
("halo_x", "halo_y", "halo_z"),
135+
(self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
136+
min_val=0, max_val=512,
137+
)
138+
setting_values.layout().addLayout(layout)
2139

140+
self.checkpoint_param, layout = self._add_string_param(
141+
name="checkpoint", value="", title="Load Custom Model",
142+
placeholder="path/to/checkpoint.pt",
143+
)
144+
setting_values.layout().addLayout(layout)
3145

4-
class SegmentationWidget(BaseWidget):
5-
pass
146+
settings = self._make_collapsible(widget=setting_values, title="Advanced Settings")
147+
return settings

flamingo_tools/plugin/segmentation_widget.py

Lines changed: 26 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,35 @@
11
import copy
2-
import re
3-
from typing import Optional, Union
42

53
import napari
6-
import torch
74

5+
from torch_em.util.prediction import predict_with_halo
6+
from torch_em.util.segmentation import watershed_from_center_and_boundary_distances
87
from napari.utils.notifications import show_info
98
from qtpy.QtWidgets import QWidget, QVBoxLayout, QPushButton, QLabel, QComboBox
109

1110
from .base_widget import BaseWidget
11+
from .util import _load_custom_model, _available_devices, _get_current_tiling
1212
from ..model_utils import get_model, get_model_registry, get_device, get_default_tiling
1313

1414

15-
def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
16-
model_path = _clean_filepath(model_path)
17-
if device is None:
18-
device = get_device(device)
19-
try:
20-
model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
21-
except Exception as e:
22-
print(e)
23-
print("model path", model_path)
24-
return None
25-
return model
26-
27-
28-
def _available_devices():
29-
available_devices = []
30-
for i in ["cuda", "mps", "cpu"]:
31-
try:
32-
device = get_device(i)
33-
except RuntimeError:
34-
pass
35-
else:
36-
available_devices.append(device)
37-
return available_devices
38-
39-
40-
def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
41-
# get tiling values from qt objects
42-
for k, v in tiling.items():
43-
for k2, v2 in v.items():
44-
if isinstance(v2, int):
45-
continue
46-
elif hasattr(v2, "value"): # If it's a QSpinBox, extract the value
47-
tiling[k][k2] = v2.value()
48-
else:
49-
raise TypeError(f"Unexpected type for tiling value: {type(v2)} at {k}/{k2}")
50-
show_info(f"Using tiling: {tiling}")
51-
return tiling
52-
53-
54-
def _clean_filepath(filepath):
55-
"""Cleans a given filepath by:
56-
- Removing newline characters (\n)
57-
- Removing escape sequences
58-
- Stripping the 'file://' prefix if present
59-
60-
Args:
61-
filepath (str): The original filepath
62-
63-
Returns:
64-
str: The cleaned filepath
65-
"""
66-
# Remove 'file://' prefix if present
67-
if filepath.startswith("file://"):
68-
filepath = filepath[7:]
69-
70-
# Remove escape sequences and newlines
71-
filepath = re.sub(r'\\.', '', filepath)
72-
filepath = filepath.replace('\n', '').replace('\r', '')
73-
74-
return filepath
75-
76-
77-
def _run_segmentation(image, model, model_type, tiling):
78-
# return segmentation
79-
pass
15+
# TODO Expose segmentation kwargs.
16+
def _run_segmentation(image, model, model_type, tiling, device):
17+
block_shape = [tiling["tile"][ax] for ax in "zyx"]
18+
halo = [tiling["halo"][ax] for ax in "zyx"]
19+
prediction = predict_with_halo(
20+
image, model, gpu_ids=[device], block_shape=block_shape, halo=halo,
21+
tqdm_desc="Run prediction"
22+
)
23+
foreground_map, center_distances, boundary_distances = prediction
24+
segmentation = watershed_from_center_and_boundary_distances(
25+
center_distances, boundary_distances, foreground_map,
26+
center_distance_threshold=0.5,
27+
boundary_distance_threshold=0.5,
28+
foreground_threshold=0.5,
29+
distance_smoothing=1.6,
30+
min_size=100,
31+
)
32+
return segmentation
8033

8134

8235
class SegmentationWidget(BaseWidget):
@@ -109,9 +62,12 @@ def load_model_widget(self):
10962
model_widget = QWidget()
11063
title_label = QLabel("Select Model:")
11164

112-
# Exclude the models that are only offered through the CLI and not in the plugin.
11365
model_list = list(get_model_registry().urls.keys())
11466

67+
# Exclude the detection models.
68+
segmentation_models = ["SGN", "IHC"]
69+
model_list = [name for name in model_list if name in segmentation_models]
70+
11571
models = ["- choose -"] + model_list
11672
self.model_selector = QComboBox()
11773
self.model_selector.addItems(models)
@@ -154,7 +110,7 @@ def on_predict(self):
154110

155111
# Get the current tiling.
156112
self.tiling = _get_current_tiling(self.tiling, self.default_tiling, model_type)
157-
segmentation = _run_segmentation(image, model=model, model_type=model_type, tiling=self.tiling)
113+
segmentation = _run_segmentation(image, model=model, model_type=model_type, tiling=self.tiling, device=device)
158114

159115
self.viewer.add_labels(segmentation, name=model_type)
160116
show_info(f"INFO: Segmentation of {model_type} added to layers.")
@@ -187,7 +143,6 @@ def _create_settings_widget(self):
187143
("halo_x", "halo_y", "halo_z"),
188144
(self.default_tiling["halo"]["x"], self.default_tiling["halo"]["y"], self.default_tiling["halo"]["z"]),
189145
min_val=0, max_val=512,
190-
# tooltip=get_tooltip("embedding", "halo")
191146
)
192147
setting_values.layout().addLayout(layout)
193148

flamingo_tools/plugin/util.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import re
2+
from typing import Optional, Union
3+
4+
import torch
5+
from napari.utils.notifications import show_info
6+
from ..model_utils import get_device
7+
8+
9+
def _load_custom_model(model_path: str, device: Optional[Union[str, torch.device]] = None) -> torch.nn.Module:
10+
model_path = _clean_filepath(model_path)
11+
if device is None:
12+
device = get_device(device)
13+
try:
14+
model = torch.load(model_path, map_location=torch.device(device), weights_only=False)
15+
except Exception as e:
16+
print(e)
17+
print("model path", model_path)
18+
return None
19+
return model
20+
21+
22+
def _available_devices():
23+
available_devices = []
24+
for i in ["cuda", "mps", "cpu"]:
25+
try:
26+
device = get_device(i)
27+
except RuntimeError:
28+
pass
29+
else:
30+
available_devices.append(device)
31+
return available_devices
32+
33+
34+
def _get_current_tiling(tiling: dict, default_tiling: dict, model_type: str):
35+
# get tiling values from qt objects
36+
for k, v in tiling.items():
37+
for k2, v2 in v.items():
38+
if isinstance(v2, int):
39+
continue
40+
elif hasattr(v2, "value"): # If it's a QSpinBox, extract the value
41+
tiling[k][k2] = v2.value()
42+
else:
43+
raise TypeError(f"Unexpected type for tiling value: {type(v2)} at {k}/{k2}")
44+
show_info(f"Using tiling: {tiling}")
45+
return tiling
46+
47+
48+
def _clean_filepath(filepath):
49+
# Remove 'file://' prefix if present
50+
if filepath.startswith("file://"):
51+
filepath = filepath[7:]
52+
53+
# Remove escape sequences and newlines
54+
filepath = re.sub(r'\\.', '', filepath)
55+
filepath = filepath.replace('\n', '').replace('\r', '')
56+
57+
return filepath

0 commit comments

Comments
 (0)