Skip to content

Commit 1e4c43b

Browse files
Update state precomputation and extend test (#551)
Update state precomputation and extend test
1 parent 11129a4 commit 1e4c43b

File tree

6 files changed

+156
-60
lines changed

6 files changed

+156
-60
lines changed

micro_sam/precompute_state.py

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,9 @@ def cache_is_state(
9393
save_path: Union[str, os.PathLike],
9494
verbose: bool = True,
9595
i: Optional[int] = None,
96+
skip_load: bool = False,
9697
**kwargs,
97-
) -> instance_segmentation.AMGBase:
98+
) -> Optional[instance_segmentation.AMGBase]:
9899
"""Compute and cache or load the state for the automatic mask generator.
99100
100101
Args:
@@ -105,6 +106,7 @@ def cache_is_state(
105106
save_path: The embedding save path. The AMG state will be stored in 'save_path/amg_state.pickle'.
106107
verbose: Whether to run the computation verbose.
107108
i: The index for which to cache the state.
109+
skip_load: Skip loading the state if it is precomputed.
108110
kwargs: The keyword arguments for the amg class.
109111
110112
Returns:
@@ -120,6 +122,9 @@ def cache_is_state(
120122

121123
with h5py.File(save_path, "a") as f:
122124
if save_key in f:
125+
if skip_load: # Skip loading to speed this up for cases where we don't need the return val.
126+
return
127+
123128
if verbose:
124129
print("Load instance segmentation state from", save_path, ":", save_key)
125130
g = f[save_key]
@@ -169,6 +174,7 @@ def _precompute_state_for_files(
169174
predictor: SamPredictor,
170175
input_files: Union[List[Union[os.PathLike, str]], List[np.ndarray]],
171176
output_path: Union[os.PathLike, str],
177+
key: Optional[str] = None,
172178
ndim: Optional[int] = None,
173179
tile_shape: Optional[Tuple[int, int]] = None,
174180
halo: Optional[Tuple[int, int]] = None,
@@ -185,14 +191,15 @@ def _precompute_state_for_files(
185191

186192
_precompute_state_for_file(
187193
predictor, file_path, out_path,
188-
key=None, ndim=ndim, tile_shape=tile_shape, halo=halo,
194+
key=key, ndim=ndim, tile_shape=tile_shape, halo=halo,
189195
precompute_amg_state=precompute_amg_state, decoder=decoder,
190196
)
191197

192198

193199
def precompute_state(
194200
input_path: Union[os.PathLike, str],
195201
output_path: Union[os.PathLike, str],
202+
pattern: Optional[str] = None,
196203
model_type: str = util._DEFAULT_MODEL,
197204
checkpoint_path: Optional[Union[os.PathLike, str]] = None,
198205
key: Optional[str] = None,
@@ -209,31 +216,41 @@ def precompute_state(
209216
In case of a container file the argument `key` must be given. In case of a folder
210217
it can be given to provide a glob pattern to subselect files from the folder.
211218
output_path: The output path were the embeddings and other state will be saved.
219+
pattern: Glob pattern to select files in a folder. The embeddings will be computed
220+
for each of these files. To select all files in a folder pass "*".
212221
model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
213222
checkpoint_path: Path to a checkpoint for a custom model.
214223
key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr)
215-
and can be used to provide a glob pattern if the input is a folder with image files.
224+
or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
216225
ndim: The dimensionality of the data.
217226
tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
218227
halo: Overlap of the tiles for tiled prediction.
219228
precompute_amg_state: Whether to precompute the state for automatic instance segmentation
220229
in addition to the image embeddings.
221230
"""
222-
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
223-
# check if we precompute the state for a single file or for a folder with image files
224-
if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"):
225-
pattern = "*" if key is None else key
226-
input_files = glob(os.path.join(input_path, pattern))
227-
_precompute_state_for_files(
228-
predictor, input_files, output_path,
231+
predictor, state = util.get_sam_model(
232+
model_type=model_type, checkpoint_path=checkpoint_path, return_state=True,
233+
)
234+
if "decoder_state" in state:
235+
decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])
236+
else:
237+
decoder = None
238+
239+
# Check if we precompute the state for a single file or for a folder with image files.
240+
if pattern is None:
241+
_precompute_state_for_file(
242+
predictor, input_path, output_path, key,
229243
ndim=ndim, tile_shape=tile_shape, halo=halo,
230244
precompute_amg_state=precompute_amg_state,
245+
decoder=decoder,
231246
)
232247
else:
233-
_precompute_state_for_file(
234-
predictor, input_path, output_path, key,
248+
input_files = glob(os.path.join(input_path, pattern))
249+
_precompute_state_for_files(
250+
predictor, input_files, output_path, key=key,
235251
ndim=ndim, tile_shape=tile_shape, halo=halo,
236252
precompute_amg_state=precompute_amg_state,
253+
decoder=decoder,
237254
)
238255

239256

@@ -253,11 +270,16 @@ def main():
253270
parser.add_argument(
254271
"-e", "--embedding_path", required=True, help="The path where the embeddings will be saved."
255272
)
273+
274+
parser.add_argument(
275+
"--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
276+
)
256277
parser.add_argument(
257278
"-k", "--key",
258279
help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
259-
"for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
280+
"for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
260281
)
282+
261283
parser.add_argument(
262284
"-m", "--model_type", default=util._DEFAULT_MODEL,
263285
help=f"The segment anything model that will be used, one of {available_models}."
@@ -284,8 +306,10 @@ def main():
284306

285307
args = parser.parse_args()
286308
precompute_state(
287-
args.input_path, args.embedding_path, args.model_type, args.checkpoint,
288-
key=args.key, tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim,
309+
args.input_path, args.embedding_path,
310+
model_type=args.model_type, checkpoint_path=args.checkpoint,
311+
pattern=args.pattern, key=args.key,
312+
tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim,
289313
precompute_amg_state=args.precompute_amg_state,
290314
)
291315

micro_sam/sam_annotator/_state.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,9 @@ def initialize_predictor(
126126
if save_path is None:
127127
raise RuntimeError("Require a save path to precompute the amg state")
128128

129-
cache_state = cache_amg_state if self.decoder is None else partial(cache_is_state, decoder=self.decoder)
129+
cache_state = cache_amg_state if self.decoder is None else partial(
130+
cache_is_state, decoder=self.decoder, skip_load=True,
131+
)
130132

131133
if ndim == 2:
132134
self.amg = cache_state(

micro_sam/sam_annotator/_widgets.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,11 +891,22 @@ def _update_model(self):
891891
tile_shape=[self.tile_x, self.tile_y],
892892
halo=[self.halo_x, self.halo_y]
893893
)
894+
895+
# Set the default settings for this model in the autosegment widget if it is part of
896+
# the currently used plugin.
894897
if "autosegment" in state.widgets:
895898
with_decoder = state.decoder is not None
896899
vutil._sync_autosegment_widget(
897900
state.widgets["autosegment"], self.model_type, self.custom_weights, update_decoder=with_decoder
898901
)
902+
# Load the AMG/AIS state if we have a 3d segmentation plugin.
903+
if state.widgets["autosegment"].volumetric and with_decoder:
904+
state.amg_state = vutil._load_is_state(state.embedding_path)
905+
elif state.widgets["autosegment"].volumetric and not with_decoder:
906+
state.amg_state = vutil._load_amg_state(state.embedding_path)
907+
908+
# Set the default settings for this model in the nd-segmentation widget if it is part of
909+
# the currently used plugin.
899910
if "segment_nd" in state.widgets:
900911
vutil._sync_ndsegment_widget(state.widgets["segment_nd"], self.model_type, self.custom_weights)
901912

micro_sam/sam_annotator/annotator_3d.py

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,16 @@
1-
import os
2-
import pickle
3-
from glob import glob
4-
from pathlib import Path
51
from typing import Optional, Tuple, Union
62

7-
import h5py
83
import napari
94
import numpy as np
105
import torch
116

127
from ._annotator import _AnnotatorBase
138
from ._state import AnnotatorState
149
from . import _widgets as widgets
15-
from .util import _initialize_parser, _sync_embedding_widget
10+
from .util import _initialize_parser, _sync_embedding_widget, _load_amg_state, _load_is_state
1611
from .. import util
1712

1813

19-
def _load_amg_state(embedding_path):
20-
if embedding_path is None or not os.path.exists(embedding_path):
21-
return {"cache_folder": None}
22-
23-
cache_folder = os.path.join(embedding_path, "amg_state")
24-
os.makedirs(cache_folder, exist_ok=True)
25-
amg_state = {"cache_folder": cache_folder}
26-
27-
state_paths = glob(os.path.join(cache_folder, "*.pkl"))
28-
for path in state_paths:
29-
with open(path, "rb") as f:
30-
state = pickle.load(f)
31-
i = int(Path(path).stem.split("-")[-1])
32-
amg_state[i] = state
33-
return amg_state
34-
35-
36-
def _load_is_state(embedding_path):
37-
if embedding_path is None or not os.path.exists(embedding_path):
38-
return {"cache_path": None}
39-
40-
cache_path = os.path.join(embedding_path, "is_state.h5")
41-
is_state = {"cache_path": cache_path}
42-
43-
with h5py.File(cache_path, "a") as f:
44-
for name, g in f.items():
45-
i = int(name.split("-")[-1])
46-
state = {
47-
"foreground": g["foreground"][:],
48-
"boundary_distances": g["boundary_distances"][:],
49-
"center_distances": g["center_distances"][:],
50-
}
51-
is_state[i] = state
52-
53-
return is_state
54-
55-
5614
class Annotator3d(_AnnotatorBase):
5715
def _get_widgets(self):
5816
autosegment = widgets.AutoSegmentWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)

micro_sam/sam_annotator/util.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import argparse
2+
import os
3+
import pickle
24
import warnings
5+
6+
from glob import glob
7+
from pathlib import Path
38
from typing import List, Optional, Tuple
49

10+
import h5py
511
import napari
612
import numpy as np
713

@@ -729,3 +735,40 @@ def _sync_ndsegment_widget(widget, model_type, checkpoint_path):
729735
for param in params:
730736
if param in settings:
731737
getattr(widget, f"{param}_param").setValue(settings[param])
738+
739+
740+
def _load_amg_state(embedding_path):
741+
if embedding_path is None or not os.path.exists(embedding_path):
742+
return {"cache_folder": None}
743+
744+
cache_folder = os.path.join(embedding_path, "amg_state")
745+
os.makedirs(cache_folder, exist_ok=True)
746+
amg_state = {"cache_folder": cache_folder}
747+
748+
state_paths = glob(os.path.join(cache_folder, "*.pkl"))
749+
for path in state_paths:
750+
with open(path, "rb") as f:
751+
state = pickle.load(f)
752+
i = int(Path(path).stem.split("-")[-1])
753+
amg_state[i] = state
754+
return amg_state
755+
756+
757+
def _load_is_state(embedding_path):
758+
if embedding_path is None or not os.path.exists(embedding_path):
759+
return {"cache_path": None}
760+
761+
cache_path = os.path.join(embedding_path, "is_state.h5")
762+
is_state = {"cache_path": cache_path}
763+
764+
with h5py.File(cache_path, "a") as f:
765+
for name, g in f.items():
766+
i = int(name.split("-")[-1])
767+
state = {
768+
"foreground": g["foreground"][:],
769+
"boundary_distances": g["boundary_distances"][:],
770+
"center_distances": g["center_distances"][:],
771+
}
772+
is_state[i] = state
773+
774+
return is_state

test/test_sam_annotator/test_cli.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1+
import os
2+
import platform
13
import unittest
2-
from shutil import which
4+
from shutil import which, rmtree
5+
from subprocess import run
6+
7+
import imageio.v3 as imageio
8+
import micro_sam.util as util
9+
import zarr
10+
from skimage.data import binary_blobs
311

412

513
class TestCLI(unittest.TestCase):
14+
model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
15+
tmp_folder = "tmp-files"
16+
17+
def setUp(self):
18+
os.makedirs(self.tmp_folder, exist_ok=True)
19+
20+
def tearDown(self):
21+
rmtree(self.tmp_folder)
22+
623
def _test_command(self, cmd):
724
self.assertTrue(which(cmd) is not None)
825

@@ -21,6 +38,47 @@ def test_image_series_annotator(self):
2138
def test_precompute_embeddings(self):
2239
self._test_command("micro_sam.precompute_embeddings")
2340

41+
# The filepaths can't be found on windows, probably due different filepath conventions.
42+
# The actual functionality likely works despite this issue.
43+
if platform.system() == "Windows":
44+
return
45+
46+
# Create 3 images as testdata.
47+
for i in range(3):
48+
im_path = os.path.join(self.tmp_folder, f"image-{i}.tif")
49+
image_data = binary_blobs(512).astype("uint8") * 255
50+
imageio.imwrite(im_path, image_data)
51+
52+
# Test precomputation with a single image.
53+
emb_path1 = os.path.join(self.tmp_folder, "embedddings1.zarr")
54+
run([
55+
"micro_sam.precompute_embeddings", "-i", im_path, "-e", emb_path1,
56+
"-m", self.model_type
57+
])
58+
self.assertTrue(os.path.exists(emb_path1))
59+
with zarr.open(emb_path1, "r") as f:
60+
self.assertIn("features", f)
61+
62+
# Test precomputation with image stack.
63+
emb_path2 = os.path.join(self.tmp_folder, "embedddings2.zarr")
64+
run([
65+
"micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path2,
66+
"-m", self.model_type, "-k", "*.tif"
67+
])
68+
self.assertTrue(os.path.exists(emb_path2))
69+
with zarr.open(emb_path2, "r") as f:
70+
self.assertIn("features", f)
71+
self.assertEqual(f["features"].shape[0], 3)
72+
73+
# Test precomputation with pattern to process multiple image.
74+
emb_path3 = os.path.join(self.tmp_folder, "embedddings3")
75+
run([
76+
"micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path3,
77+
"-m", self.model_type, "--pattern", "*.tif"
78+
])
79+
for i in range(3):
80+
self.assertTrue(os.path.exists(os.path.join(emb_path3, f"image-{i}.zarr")))
81+
2482

2583
if __name__ == "__main__":
2684
unittest.main()

0 commit comments

Comments
 (0)