Skip to content

Commit 9809835

Browse files
Merge pull request #78 from computational-cell-analytics/refactor-gui
Refactor the wrong file dialog gui
2 parents 47dff10 + d785ad7 commit 9809835

File tree

5 files changed

+83
-58
lines changed

5 files changed

+83
-58
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .. import util
1010
from .. import instance_segmentation
1111
from ..visualization import project_embeddings_for_visualization
12+
from .gui_utils import show_wrong_file_warning
1213
from .util import (
1314
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
1415
prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation, toggle_label, LABEL_COLOR_CYCLE,
@@ -215,7 +216,8 @@ def annotator_2d(
215216
else:
216217
PREDICTOR = predictor
217218
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
218-
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo
219+
PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo,
220+
wrong_file_callback=show_wrong_file_warning
219221
)
220222

221223
# we set the pre-computed image embeddings if we don't use tiling

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .. import util
99
from ..prompt_based_segmentation import segment_from_mask
1010
from ..visualization import project_embeddings_for_visualization
11+
from .gui_utils import show_wrong_file_warning
1112
from .util import (
1213
clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
1314
prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation,
@@ -195,7 +196,8 @@ def annotator_3d(
195196
global PREDICTOR, IMAGE_EMBEDDINGS
196197
PREDICTOR = util.get_sam_model(model_type=model_type)
197198
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
198-
PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo
199+
PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
200+
wrong_file_callback=show_wrong_file_warning,
199201
)
200202

201203
#

micro_sam/sam_annotator/annotator_tracking.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from .. import util
1414
from ..prompt_based_segmentation import segment_from_mask
15+
from .gui_utils import show_wrong_file_warning
1516
from .util import (
1617
create_prompt_menu, clear_all_prompts,
1718
prompt_layer_to_boxes, prompt_layer_to_points,
@@ -358,7 +359,8 @@ def annotator_tracking(
358359

359360
PREDICTOR = util.get_sam_model(model_type=model_type)
360361
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
361-
PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo
362+
PREDICTOR, raw, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
363+
wrong_file_callback=show_wrong_file_warning,
362364
)
363365

364366
CURRENT_TRACK_ID = 1
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import os
2+
from shutil import rmtree
3+
4+
from PyQt5 import QtCore, QtWidgets
5+
6+
7+
def show_wrong_file_warning(file_path):
8+
"""If the data signature does not match to the signature,
9+
the user can choose from the following options in this dialog:
10+
- Ignore: continue with input file (return file_path).
11+
- Overwrite: delete file_path and recompute the embeddings at same location.
12+
- Select a different file
13+
- Select a new file
14+
15+
Arguments:
16+
file_path (string or os.path): path of the problematic file
17+
18+
Returns:
19+
string or os.path: path to a file (new or old) depending on user decision
20+
"""
21+
msgbox = QtWidgets.QMessageBox()
22+
msgbox.setWindowFlags(QtCore.Qt.CustomizeWindowHint | QtCore.Qt.WindowTitleHint)
23+
msgbox.setWindowTitle("Warning")
24+
msgbox.setText("The input data does not match the embeddings file.")
25+
ignore_btn = msgbox.addButton("Ignore", QtWidgets.QMessageBox.RejectRole)
26+
overwrite_btn = msgbox.addButton("Overwrite file", QtWidgets.QMessageBox.DestructiveRole)
27+
select_btn = msgbox.addButton("Select different file", QtWidgets.QMessageBox.AcceptRole)
28+
create_btn = msgbox.addButton("Create new file", QtWidgets.QMessageBox.AcceptRole)
29+
msgbox.setDefaultButton(create_btn)
30+
31+
msgbox.exec()
32+
msgbox.clickedButton()
33+
if msgbox.clickedButton() == ignore_btn:
34+
return file_path
35+
elif msgbox.clickedButton() == overwrite_btn:
36+
rmtree(file_path)
37+
return file_path
38+
elif msgbox.clickedButton() == create_btn:
39+
# unfortunately there exists no dialog to create a directory so we have
40+
# to use "create new file" dialog with some adjustments.
41+
dialog = QtWidgets.QFileDialog(None)
42+
dialog.setFileMode(QtWidgets.QFileDialog.AnyFile)
43+
dialog.setOption(QtWidgets.QFileDialog.ShowDirsOnly)
44+
dialog.setNameFilter("Archives (*.zarr)")
45+
new_path = ""
46+
while os.path.splitext(new_path)[1] != ".zarr":
47+
dialog.exec()
48+
new_path = dialog.selectedFiles()[0]
49+
os.makedirs(new_path)
50+
return(new_path)
51+
elif msgbox.clickedButton() == select_btn:
52+
return QtWidgets.QFileDialog.getExistingDirectory(
53+
None, "Open a folder", os.path.split(file_path)[0], QtWidgets.QFileDialog.ShowDirsOnly
54+
)

micro_sam/util.py

Lines changed: 20 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import hashlib
22
import os
33
import warnings
4-
from shutil import copyfileobj, rmtree
4+
from shutil import copyfileobj
55

66
import numpy as np
77
import requests
88
import torch
99
import vigra
1010
import zarr
1111

12-
from PyQt5 import QtCore, QtWidgets
1312
from elf.io import open_file
1413
from nifty.tools import blocking
1514
from skimage.measure import regionprops
@@ -316,54 +315,12 @@ def _precompute_3d(input_, predictor, save_path, lazy_loading, tile_shape=None,
316315
}
317316
return image_embeddings
318317

319-
def show_wrong_file_warning(file_path):
320-
"""If the data signature does not match to the signature, user will can choose from the following options in this dialog:
321-
- Ignore: continue with input file (return file_path).
322-
- Overwrite: delete file_path and recompute the embeddings at same location.
323-
- Select a different file
324-
- Select a new file
325-
326-
Arguments:
327-
file_path (string or os.path): path of the problematic file
328-
329-
Returns:
330-
string or os.path: path to a file (new or old) depending on user decision
331-
"""
332-
msgbox = QtWidgets.QMessageBox()
333-
msgbox.setWindowFlags(QtCore.Qt.CustomizeWindowHint | QtCore.Qt.WindowTitleHint)
334-
msgbox.setWindowTitle("Warning")
335-
msgbox.setText('The input data does not match the embeddings file.')
336-
ignore_btn = msgbox.addButton("Ignore" , QtWidgets.QMessageBox.RejectRole)
337-
overwrite_btn = msgbox.addButton("Overwrite file" , QtWidgets.QMessageBox.DestructiveRole)
338-
select_btn = msgbox.addButton("Select different file" ,QtWidgets.QMessageBox.AcceptRole)
339-
create_btn = msgbox.addButton("Create new file" ,QtWidgets.QMessageBox.AcceptRole)
340-
msgbox.setDefaultButton(create_btn)
341-
342-
msgbox.exec()
343-
msgbox.clickedButton()
344-
if msgbox.clickedButton() == ignore_btn:
345-
return file_path
346-
elif msgbox.clickedButton() == overwrite_btn:
347-
rmtree(file_path)
348-
return file_path
349-
elif msgbox.clickedButton() == create_btn:
350-
# unfortunately there exists no dialog to create a directory so we have to use "create new file" dialog with some adjustments.
351-
dialog = QtWidgets.QFileDialog(None)
352-
dialog.setFileMode(QtWidgets.QFileDialog.AnyFile)
353-
dialog.setOption(QtWidgets.QFileDialog.ShowDirsOnly)
354-
dialog.setNameFilter("Archives (*.zarr)")
355-
new_path = ""
356-
while os.path.splitext(new_path)[1] != ".zarr":
357-
dialog.exec()
358-
new_path = dialog.selectedFiles()[0]
359-
os.makedirs(new_path)
360-
return(new_path)
361-
elif msgbox.clickedButton() == select_btn:
362-
return QtWidgets.QFileDialog.getExistingDirectory(None, "Open a folder", os.path.split(file_path)[0], QtWidgets.QFileDialog.ShowDirsOnly)
363-
364318

365319
def precompute_image_embeddings(
366-
predictor, input_, save_path=None, lazy_loading=False, ndim=None, tile_shape=None, halo=None
320+
predictor, input_,
321+
save_path=None, lazy_loading=False,
322+
ndim=None, tile_shape=None, halo=None,
323+
wrong_file_callback=None,
367324
):
368325
"""Compute the image embeddings (output of the encoder) for the input.
369326
@@ -380,23 +337,31 @@ def precompute_image_embeddings(
380337
tile_shape [tuple] - shape of tiles for tiled prediction.
381338
By default prediction is run without tiling. (default: None)
382339
halo [tuple] - additional overlap of the tiles for tiled prediction. (default: None)
340+
wrong_file_callback [callable] - function to call when an embedding file with wrong file signature
341+
is passed. If none is given a wrong file signature will cause a warning.
342+
If passed, the callback should have the signature 'def callback(save_path): return str',
343+
where the return value is the (potentially updated) embedding save path (default: None)
383344
"""
384345
ndim = input_.ndim if ndim is None else ndim
385346
if tile_shape is not None:
386347
assert save_path is not None, "Tiled prediction is only supported when the embeddings are saved to file."
387-
348+
388349
if save_path is not None:
389350
data_signature = hashlib.sha1(input_.tobytes()).hexdigest()
390-
351+
391352
f = zarr.open(save_path, "a")
392-
if "input_size" in f.attrs:
353+
if "input_size" in f.attrs: # we have computed the embeddings already
354+
355+
# data signature does not match or is not in the file
393356
if "data_signature" not in f.attrs or f.attrs["data_signature"] != data_signature:
394-
warnings.warn("Embeddings file is invalid. Please recompute embeddings to new file.")
395-
save_path = show_wrong_file_warning(save_path)
357+
warnings.warn("Embeddings file is invalid. Please recompute embeddings in a new file.")
358+
if wrong_file_callback is not None:
359+
save_path = wrong_file_callback(save_path)
396360
f = zarr.open(save_path, "a")
397361
if "data_signature" not in f.attrs:
398-
f.attrs["data_signature"] = data_signature
399-
else:
362+
f.attrs["data_signature"] = data_signature
363+
364+
else: # embeddings have not yet been computed
400365
f.attrs["data_signature"] = data_signature
401366

402367
if ndim == 2:

0 commit comments

Comments
 (0)