Skip to content

Commit f4bfcfd

Browse files
Merge pull request #115 from computational-cell-analytics/99_custom_models_gui_support
99_custom_models_gui_support
2 parents 980f942 + 48e4c25 commit f4bfcfd

File tree

3 files changed

+126
-67
lines changed

3 files changed

+126
-67
lines changed

micro_sam/sam_annotator/annotator.py

Lines changed: 47 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import os
55
import magicgui
66
import numpy as np
7+
import zarr
78
from magicgui.widgets import Container, Label, LineEdit, SpinBox, ComboBox
89
from magicgui.application import use_app
910
from PyQt5.QtWidgets import QFileDialog, QMessageBox
1011

11-
from ..util import load_image_data
12+
from ..util import load_image_data, get_model_names
1213
from .annotator_2d import annotator_2d
1314
from .annotator_3d import annotator_3d
1415
from .image_series_annotator import image_folder_annotator
@@ -37,6 +38,19 @@ def file_is_hirarchical(path_s):
3738
return os.path.splitext(path_s)[1] in [".hdf5", ".h5", "n5", ".zarr"]
3839

3940

41+
def _set_embeddings_file_attributes(embeddings_path: str, cb_model, re_tile_x, re_tile_y):
42+
f = zarr.open(embeddings_path, "a")
43+
if "tile_shape" in f.attrs:
44+
if f.attrs["tile_shape"] is None:
45+
re_tile_x.value = 0
46+
re_tile_y.value = 0
47+
else:
48+
re_tile_x.value = f.attrs["tile_shape"][0]
49+
re_tile_y.value = f.attrs["tile_shape"][1]
50+
if "model_type" in f.attrs:
51+
cb_model.value = f.attrs["model_type"]
52+
53+
4054
@magicgui.magicgui(call_button="2d annotator", labels=False)
4155
def _on_2d():
4256
global config_dict
@@ -46,6 +60,11 @@ def _on_2d():
4660

4761
le_file_key_img = LineEdit(value="*", label="File key input")
4862
le_file_key_segm = LineEdit(value="*", label="File key segmentation")
63+
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
64+
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
65+
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
66+
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
67+
cb_model = ComboBox(value="vit_h_lm", choices=get_model_names(), label="Model Type")
4968

5069
@magicgui.magicgui(call_button="Select image", labels=False)
5170
def on_select_image():
@@ -73,6 +92,7 @@ def on_select_embed():
7392
show_error("Precompute embeddings file does not exist or has wrong file extension.")
7493
return
7594
args["embedding_path"] = path
95+
_set_embeddings_file_attributes(path, cb_model, re_tile_x, re_tile_y)
7696
except Exception as e:
7797
show_error(str(e))
7898

@@ -93,12 +113,6 @@ def on_select_segm():
93113
pb_embed_sel = Container(widgets=[on_select_embed], layout="horizontal", labels=False)
94114
pb_seg_segm = Container(widgets=[on_select_segm], layout="horizontal", labels=False)
95115

96-
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
97-
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
98-
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
99-
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
100-
cb_model = ComboBox(value="vit_h", choices=["vit_h", "vit_l", "vit_b"], label="Model Type")
101-
102116
@magicgui.magicgui(call_button="2d annotator", labels=False)
103117
def on_start():
104118
try:
@@ -125,6 +139,7 @@ def on_start():
125139
sub_widget = Container(widgets=[Container(widgets=[on_start], layout="horizontal", labels=False),
126140
pb_img_sel, lbl_opt, pb_embed_sel, pb_seg_segm, le_file_key_img, le_file_key_segm,
127141
cb_model, re_tile_x, re_tile_y, re_halo_x, re_halo_y])
142+
sub_widget.root_native_widget.setWindowTitle("Segment Anything for Microscopy")
128143
main_widget.close()
129144
sub_widget.show()
130145

@@ -137,6 +152,12 @@ def _on_3d():
137152
le_file_key_img = LineEdit(value="*", label="File key input")
138153
le_file_key_segm = LineEdit(value="*", label="File key segmentation")
139154

155+
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
156+
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
157+
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
158+
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
159+
cb_model = ComboBox(value="vit_h_lm", choices=get_model_names(), label="Model Type")
160+
140161
@magicgui.magicgui(call_button="Select images", labels=False)
141162
def on_select_image():
142163
try:
@@ -175,6 +196,7 @@ def on_select_embed():
175196
show_error("Precompute embeddings file does not exist or has wrong file extension.")
176197
return
177198
args["embedding_path"] = path
199+
_set_embeddings_file_attributes(path, cb_model, re_tile_x, re_tile_y)
178200
except Exception as e:
179201
show_error(str(e))
180202

@@ -208,12 +230,6 @@ def on_select_segm_dir():
208230
pb_embed_sel = Container(widgets=[on_select_embed], layout="horizontal", labels=False)
209231
pb_seg_segm = Container(widgets=[on_select_segm, on_select_segm_dir], layout="horizontal", labels=False)
210232

211-
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
212-
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
213-
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
214-
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
215-
cb_model = ComboBox(value="vit_h", choices=["vit_h", "vit_l", "vit_b"], label="Model Type")
216-
217233
@magicgui.magicgui(call_button="3d annotator", labels=False)
218234
def on_start():
219235
try:
@@ -240,6 +256,7 @@ def on_start():
240256
sub_widget = Container(widgets=[Container(widgets=[on_start], layout="horizontal", labels=False),
241257
pb_img_sel, lbl_opt, pb_embed_sel, pb_seg_segm, le_file_key_img, le_file_key_segm,
242258
cb_model, re_tile_x, re_tile_y, re_halo_x, re_halo_y])
259+
sub_widget.root_native_widget.setWindowTitle("Segment Anything for Microscopy")
243260
main_widget.close()
244261
sub_widget.show()
245262

@@ -250,6 +267,12 @@ def _on_series():
250267
config_dict["args"] = {}
251268
args = config_dict["args"]
252269

270+
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
271+
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
272+
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
273+
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
274+
cb_model = ComboBox(value="vit_h_lm", choices=get_model_names(), label="Model Type")
275+
253276
@magicgui.magicgui(call_button="Select input directory", labels=False)
254277
def on_select_input_dir():
255278
try:
@@ -288,19 +311,14 @@ def on_select_embed():
288311
show_error("Precompute embeddings file does not exist or has wrong file extension.")
289312
return
290313
args["embedding_path"] = path
314+
_set_embeddings_file_attributes(path, cb_model, re_tile_x, re_tile_y)
291315
except Exception as e:
292316
show_error(str(e))
293317

294318
pb_input_sel = Container(widgets=[on_select_input_dir], layout="horizontal", labels=False)
295319
pb_output_sel = Container(widgets=[on_select_output_dir], layout="horizontal", labels=False)
296320
pb_embed_sel = Container(widgets=[on_select_embed], layout="horizontal", labels=False)
297321

298-
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
299-
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
300-
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
301-
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
302-
cb_model = ComboBox(value="vit_h", choices=["vit_h", "vit_l", "vit_b"], label="Model Type")
303-
304322
@magicgui.magicgui(call_button="Image series annotator", labels=False)
305323
def on_start():
306324
try:
@@ -329,6 +347,7 @@ def on_start():
329347
sub_widget = Container(widgets=[Container(widgets=[on_start], layout="horizontal", labels=False),
330348
pb_input_sel, pb_output_sel, lbl_opt, pb_embed_sel, cb_model, re_tile_x,
331349
re_tile_y, re_halo_x, re_halo_y])
350+
sub_widget.root_native_widget.setWindowTitle("Segment Anything for Microscopy")
332351
main_widget.close()
333352
sub_widget.show()
334353

@@ -341,6 +360,12 @@ def _on_tracking():
341360
le_file_key_img = LineEdit(value="*", label="File key input")
342361
le_file_key_segm = LineEdit(value="*", label="File key segmentation")
343362

363+
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
364+
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
365+
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
366+
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
367+
cb_model = ComboBox(value="vit_h_lm", choices=get_model_names(), label="Model Type")
368+
344369
@magicgui.magicgui(call_button="Select images", labels=False)
345370
def on_select_image():
346371
try:
@@ -379,6 +404,7 @@ def on_select_embed():
379404
show_error("Precompute embeddings file does not exist or has wrong file extension.")
380405
return
381406
args["embedding_path"] = path
407+
_set_embeddings_file_attributes(path, cb_model, re_tile_x, re_tile_y)
382408
except Exception as e:
383409
show_error(str(e))
384410

@@ -412,12 +438,6 @@ def on_select_result_dir():
412438
pb_embed_sel = Container(widgets=[on_select_embed], layout="horizontal", labels=False)
413439
pb_seg_sel = Container(widgets=[on_select_results, on_select_result_dir], layout="horizontal", labels=False)
414440

415-
re_halo_x = SpinBox(value=0, max=10000, label="Halo x")
416-
re_halo_y = SpinBox(value=0, max=10000, label="Halo y")
417-
re_tile_x = SpinBox(value=0, max=10000, label="Tile x")
418-
re_tile_y = SpinBox(value=0, max=10000, label="Tile y")
419-
cb_model = ComboBox(value="vit_h", choices=["vit_h", "vit_l", "vit_b"], label="Model Type")
420-
421441
@magicgui.magicgui(call_button="Tracking annotator", labels=False)
422442
def on_start():
423443
try:
@@ -444,6 +464,7 @@ def on_start():
444464
sub_widget = Container(widgets=[Container(widgets=[on_start], layout="horizontal", labels=False),
445465
pb_img_sel, lbl_opt, pb_embed_sel, pb_seg_sel, le_file_key_img, le_file_key_segm,
446466
cb_model, re_tile_x, re_tile_y, re_halo_x, re_halo_y])
467+
sub_widget.root_native_widget.setWindowTitle("Segment Anything for Microscopy")
447468
main_widget.close()
448469
sub_widget.show()
449470

@@ -460,6 +481,7 @@ def annotator():
460481
sub_container2 = Container(widgets=[_on_3d, _on_tracking], labels=False)
461482
sub_container3 = Container(widgets=[sub_container1, sub_container2], layout="horizontal", labels=False)
462483
main_widget = Container(widgets=[Label(value="Segment Anything for Microscopy"), sub_container3], labels=False)
484+
main_widget.root_native_widget.setWindowTitle("Segment Anything for Microscopy")
463485
main_widget.show(run=True)
464486

465487
if config_dict["workflow"] == "2d":
Lines changed: 57 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
11
import os
2+
import magicgui
23
from shutil import rmtree
34
from typing import Union
45

5-
from PyQt5 import QtCore, QtWidgets
6+
from magicgui.widgets import Container
7+
from magicgui.application import use_app
8+
9+
from PyQt5 import QtWidgets
610

711

812
def show_wrong_file_warning(file_path: Union[str, os.PathLike]) -> Union[str, os.PathLike]:
@@ -20,37 +24,62 @@ def show_wrong_file_warning(file_path: Union[str, os.PathLike]) -> Union[str, os
2024
Returns:
2125
Path to a file (new or old) depending on user decision
2226
"""
23-
msgbox = QtWidgets.QMessageBox()
24-
msgbox.setWindowFlags(QtCore.Qt.CustomizeWindowHint | QtCore.Qt.WindowTitleHint)
25-
msgbox.setWindowTitle("Warning")
26-
msgbox.setText("The input data does not match the embeddings file.")
27-
ignore_btn = msgbox.addButton("Ignore", QtWidgets.QMessageBox.RejectRole)
28-
overwrite_btn = msgbox.addButton("Overwrite file", QtWidgets.QMessageBox.DestructiveRole)
29-
select_btn = msgbox.addButton("Select different file", QtWidgets.QMessageBox.AcceptRole)
30-
create_btn = msgbox.addButton("Create new file", QtWidgets.QMessageBox.AcceptRole)
31-
msgbox.setDefaultButton(create_btn)
32-
33-
msgbox.exec()
34-
msgbox.clickedButton()
35-
if msgbox.clickedButton() == ignore_btn:
36-
return file_path
37-
elif msgbox.clickedButton() == overwrite_btn:
27+
# q_app = QtWidgets.QApplication([])
28+
# msgbox = QtWidgets.QMessageBox()
29+
# msgbox.setWindowFlags(QtCore.Qt.CustomizeWindowHint | QtCore.Qt.WindowTitleHint)
30+
# msgbox.setWindowTitle("Warning")
31+
# msgbox.setText("The input data does not match the embeddings file.")
32+
# create_btn = msgbox.addButton("Create new file", QtWidgets.QMessageBox.AcceptRole)
33+
34+
msg_box = None
35+
new_path = {"value": ""}
36+
37+
@magicgui.magicgui(call_button="Ignore", labels=False)
38+
def _ignore():
39+
msg_box.close()
40+
new_path["value"] = file_path
41+
42+
@magicgui.magicgui(call_button="Overwrite file", labels=False)
43+
def _overwrite():
44+
msg_box.close()
3845
rmtree(file_path)
39-
return file_path
40-
elif msgbox.clickedButton() == create_btn:
46+
new_path["value"] = file_path
47+
48+
@magicgui.magicgui(call_button="Create new file", labels=False)
49+
def _create():
50+
msg_box.close()
4151
# unfortunately there exists no dialog to create a directory so we have
4252
# to use "create new file" dialog with some adjustments.
4353
dialog = QtWidgets.QFileDialog(None)
4454
dialog.setFileMode(QtWidgets.QFileDialog.AnyFile)
4555
dialog.setOption(QtWidgets.QFileDialog.ShowDirsOnly)
4656
dialog.setNameFilter("Archives (*.zarr)")
47-
new_path = ""
48-
while os.path.splitext(new_path)[1] != ".zarr":
49-
dialog.exec()
50-
new_path = dialog.selectedFiles()[0]
51-
os.makedirs(new_path)
52-
return(new_path)
53-
elif msgbox.clickedButton() == select_btn:
54-
return QtWidgets.QFileDialog.getExistingDirectory(
55-
None, "Open a folder", os.path.split(file_path)[0], QtWidgets.QFileDialog.ShowDirsOnly
56-
)
57+
try_cnt = 0
58+
while os.path.splitext(new_path["value"])[1] != ".zarr":
59+
if try_cnt > 3:
60+
new_path["value"] = file_path
61+
return
62+
dialog.exec_()
63+
res = dialog.selectedFiles()
64+
new_path["value"] = res[0] if len(res) > 0 else ""
65+
try_cnt += 1
66+
os.makedirs(new_path["value"])
67+
68+
@magicgui.magicgui(call_button="Select different file", labels=False)
69+
def _select():
70+
msg_box.close()
71+
try_cnt = 0
72+
while not os.path.exists(new_path["value"]):
73+
if try_cnt > 3:
74+
new_path["value"] = file_path
75+
return
76+
new_path["value"] = QtWidgets.QFileDialog.getExistingDirectory(
77+
None, "Open a folder", os.path.split(file_path)[0], QtWidgets.QFileDialog.ShowDirsOnly
78+
)
79+
try_cnt += 1
80+
81+
msg_box = Container(widgets=[_select, _ignore, _overwrite, _create], layout='horizontal', labels=False)
82+
msg_box.root_native_widget.setWindowTitle("The input data does not match the embeddings file")
83+
msg_box.show(run=True)
84+
use_app().quit()
85+
return new_path["value"]

micro_sam/util.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import warnings
88
from shutil import copyfileobj
9-
from typing import Any, Callable, Dict, Optional, Tuple
9+
from typing import Any, Callable, Dict, Optional, Tuple, Iterable
1010

1111
import imageio.v3 as imageio
1212
import numpy as np
@@ -135,11 +135,16 @@ def get_sam_model(
135135
sam = sam_model_registry[model_type_](checkpoint=checkpoint)
136136
sam.to(device=device)
137137
predictor = SamPredictor(sam)
138+
predictor.model_type = model_type
138139
if return_sam:
139140
return predictor, sam
140141
return predictor
141142

142143

144+
def get_model_names() -> Iterable:
145+
return _MODEL_URLS.keys()
146+
147+
143148
def _to_image(input_):
144149
# we require the input to be uint8
145150
if input_.dtype != np.dtype("uint8"):
@@ -392,19 +397,22 @@ def precompute_image_embeddings(
392397
data_signature = _compute_data_signature(input_)
393398

394399
f = zarr.open(save_path, "a")
395-
if "input_size" in f.attrs: # we have computed the embeddings already
396-
397-
# data signature does not match or is not in the file
398-
if "data_signature" not in f.attrs or f.attrs["data_signature"] != data_signature:
399-
warnings.warn("Embeddings file is invalid. Please recompute embeddings in a new file.")
400-
if wrong_file_callback is not None:
401-
save_path = wrong_file_callback(save_path)
402-
f = zarr.open(save_path, "a")
403-
if "data_signature" not in f.attrs:
404-
f.attrs["data_signature"] = data_signature
405-
406-
else: # embeddings have not yet been computed
407-
f.attrs["data_signature"] = data_signature
400+
key_vals = [("data_signature", data_signature),
401+
("tile_shape", tile_shape), ("model_type", predictor.model_type)]
402+
for key, val in key_vals:
403+
if "input_size" in f.attrs: # we have computed the embeddings already
404+
# key signature does not match or is not in the file
405+
if key not in f.attrs or f.attrs[key] != val:
406+
warnings.warn(f"Embeddings file is invalid due to unmatching {key}. \
407+
Please recompute embeddings in a new file.")
408+
if wrong_file_callback is not None:
409+
save_path = wrong_file_callback(save_path)
410+
f = zarr.open(save_path, "a")
411+
break
412+
413+
for key, val in key_vals:
414+
if key not in f.attrs:
415+
f.attrs[key] = val
408416

409417
if ndim == 2:
410418
image_embeddings = _compute_2d(input_, predictor) if save_path is None else\

0 commit comments

Comments
 (0)