Skip to content

Commit c57d102

Browse files
Merge pull request #155 from computational-cell-analytics/extend-precomputation
Refactor state precomputation
2 parents 0b5d906 + ed8865e commit c57d102

File tree

7 files changed

+216
-86
lines changed

7 files changed

+216
-86
lines changed

micro_sam/instance_segmentation.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,6 +1091,35 @@ def set_state(self, state: Dict[str, Any]) -> None:
10911091
super().set_state(state)
10921092

10931093

1094+
def get_amg(
    predictor: SamPredictor,
    is_tiled: bool,
    embedding_based_amg: bool = False,
    **kwargs,
) -> AMGBase:
    """Create the automatic mask generator matching the requested configuration.

    Args:
        predictor: The segment anything predictor.
        is_tiled: Whether tiled embeddings are used.
        embedding_based_amg: Whether to use the embedding based instance segmentation
            functionality. This functionality is still experimental and selecting it
            emits a warning.
        kwargs: The keyword arguments forwarded to the mask generator class.

    Returns:
        The automatic mask generator.
    """
    # First choose the class, then instantiate it once at the end.
    if embedding_based_amg:
        warnings.warn("The embedding based instance segmentation functionality is experimental.")
        generator_class = TiledEmbeddingMaskGenerator if is_tiled else EmbeddingMaskGenerator
    else:
        generator_class = TiledAutomaticMaskGenerator if is_tiled else AutomaticMaskGenerator
    return generator_class(predictor, **kwargs)
1121+
1122+
10941123
#
10951124
# Experimental functionality
10961125
#

micro_sam/precompute_state.py

Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import os
2+
import pickle
3+
4+
from glob import glob
5+
from pathlib import Path
6+
from typing import Optional, Tuple, Union
7+
8+
import numpy as np
9+
from segment_anything.predictor import SamPredictor
10+
from tqdm import tqdm
11+
12+
from . import instance_segmentation, util
13+
14+
15+
def cache_amg_state(
    predictor: SamPredictor,
    raw: np.ndarray,
    image_embeddings: util.ImageEmbeddings,
    save_path: Union[str, os.PathLike],
    verbose: bool = True,
    **kwargs,
) -> instance_segmentation.AMGBase:
    """Compute and cache or load the state for the automatic mask generator.

    If <save_path>/amg_state.pickle already exists, the state is loaded from it.
    Otherwise the mask generator is initialized from the image data and embeddings
    and its state is written to that file.

    Args:
        predictor: The segment anything predictor.
        raw: The image data.
        image_embeddings: The image embeddings. Embeddings whose "input_size" entry
            is None are treated as tiled embeddings.
        save_path: The embedding save path. The AMG state will be stored in <save_path>/amg_state.pickle.
        verbose: Whether to run the computation verbose.
        kwargs: The keyword arguments for the amg class.

    Returns:
        The automatic mask generator class with the cached state.
    """
    is_tiled = image_embeddings["input_size"] is None
    amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs)

    save_path_amg = os.path.join(save_path, "amg_state.pickle")
    if os.path.exists(save_path_amg):
        if verbose:
            print("Load the AMG state from", save_path_amg)
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # `save_path` at state files that were written by this function.
        with open(save_path_amg, "rb") as f:
            amg_state = pickle.load(f)
        amg.set_state(amg_state)
        return amg

    if verbose:
        print("Precomputing the state for instance segmentation.")
    amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)

    # Write to a temporary file first and move it into place afterwards, so that
    # an interrupted run never leaves a truncated pickle that later runs would
    # try (and fail) to load.
    tmp_path_amg = save_path_amg + ".tmp"
    try:
        with open(tmp_path_amg, "wb") as f:
            pickle.dump(amg.get_state(), f)
        os.replace(tmp_path_amg, save_path_amg)
    finally:
        # Only present if writing or moving failed; clean up the leftover.
        if os.path.exists(tmp_path_amg):
            os.remove(tmp_path_amg)

    return amg
55+
56+
57+
def _precompute_state_for_file(
    predictor, input_path, output_path, key, ndim, tile_shape, halo, precompute_amg_state,
):
    """Precompute the embeddings (and optionally the AMG state) for a single input.

    The image data is loaded from `input_path` (using `key`), and the results are
    saved at `output_path` with its suffix replaced by ".zarr".
    """
    data = util.load_image_data(input_path, key)
    zarr_path = Path(output_path).with_suffix(".zarr")
    image_embeddings = util.precompute_image_embeddings(
        predictor, data, zarr_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
    )
    # Nothing else to do unless the AMG state was requested as well.
    if not precompute_amg_state:
        return
    cache_amg_state(predictor, data, image_embeddings, zarr_path, verbose=True)
67+
68+
69+
def _precompute_state_for_files(
    predictor, input_files, output_path, ndim, tile_shape, halo, precompute_amg_state,
):
    """Precompute the state for each file in `input_files`.

    Creates the folder `output_path` if needed and processes the files one by one,
    passing `output_path/<basename of the input file>` as the per-file output path.
    """
    os.makedirs(output_path, exist_ok=True)
    # Settings that are identical for every file; the key is always None here.
    shared_settings = dict(
        key=None, ndim=ndim, tile_shape=tile_shape, halo=halo,
        precompute_amg_state=precompute_amg_state,
    )
    for input_file in tqdm(input_files, desc="Precompute state for files."):
        file_name = os.path.basename(input_file)
        _precompute_state_for_file(
            predictor, input_file, os.path.join(output_path, file_name), **shared_settings
        )
80+
81+
82+
def precompute_state(
    input_path: Union[os.PathLike, str],
    output_path: Union[os.PathLike, str],
    model_type: str = util._DEFAULT_MODEL,
    checkpoint_path: Optional[Union[os.PathLike, str]] = None,
    key: Optional[str] = None,
    ndim: Optional[int] = None,
    tile_shape: Optional[Tuple[int, int]] = None,
    halo: Optional[Tuple[int, int]] = None,
    precompute_amg_state: bool = False,
) -> None:
    """Precompute the image embeddings and other optional state for the input image(s).

    Args:
        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
            a container file (e.g. hdf5 or zarr) or a folder with image files.
            In case of a container file the argument `key` must be given. In case of a folder
            it can be given to provide a glob pattern to subselect files from the folder.
        output_path: The output path where the embeddings and other state will be saved.
        model_type: The SegmentAnything model to use. Will use the standard vit_h model by default.
        checkpoint_path: Path to a checkpoint for a custom model.
        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
            and can be used to provide a glob pattern if the input is a folder with image files.
        ndim: The dimensionality of the data.
        tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
        halo: Overlap of the tiles for tiled prediction.
        precompute_amg_state: Whether to precompute the state for automatic instance segmentation
            in addition to the image embeddings.
    """
    predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path)
    # Check if we precompute the state for a single file or for a folder with image files.
    # Folders ending in .n5 or .zarr are container files and handled as a single input.
    if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"):
        pattern = "*" if key is None else key
        # glob returns the matches in arbitrary order; sort them so that the
        # processing order is deterministic across runs and platforms.
        input_files = sorted(glob(os.path.join(input_path, pattern)))
        _precompute_state_for_files(
            predictor, input_files, output_path,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
        )
    else:
        _precompute_state_for_file(
            predictor, input_path, output_path, key,
            ndim=ndim, tile_shape=tile_shape, halo=halo,
            precompute_amg_state=precompute_amg_state,
        )
127+
128+
129+
def main():
    """@private

    Command line entry point: parse the arguments and run `precompute_state`.
    """
    import argparse

    def _str_to_bool(value):
        # Explicit parsing is required: any non-empty string (including "False")
        # is truthy, so passing the raw string on would always enable the option.
        normalized = value.strip().lower()
        if normalized in ("1", "true", "t", "yes", "y", "on"):
            return True
        if normalized in ("0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
    parser.add_argument("-i", "--input_path", required=True)
    parser.add_argument("-o", "--output_path", required=True)
    parser.add_argument("-m", "--model_type", default="vit_h")
    parser.add_argument("-c", "--checkpoint_path", default=None)
    parser.add_argument("-k", "--key")
    parser.add_argument(
        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
    )
    parser.add_argument(
        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
    )
    # Parse the dimensionality as an integer instead of handing a string to the library.
    parser.add_argument("-n", "--ndim", type=int, default=None, help="The dimensionality of the data")
    # Accepts both the bare flag ("-p") and an explicit value ("-p true" / "-p false").
    parser.add_argument(
        "-p", "--precompute_amg_state", nargs="?", const=True, default=False, type=_str_to_bool,
        help="Whether to precompute the state for automatic instance segmentation",
    )

    args = parser.parse_args()
    precompute_state(
        args.input_path, args.output_path, args.model_type, args.checkpoint_path,
        key=args.key, tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim,
        precompute_amg_state=args.precompute_amg_state,
    )
154+
155+
156+
if __name__ == "__main__":
157+
main()

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from segment_anything import SamPredictor
1010

1111
from .. import instance_segmentation, util
12+
from ..precompute_state import cache_amg_state
1213
from ..visualization import project_embeddings_for_visualization
1314
from . import util as vutil
1415
from .gui_utils import show_wrong_file_warning
@@ -57,7 +58,7 @@ def _autosegment_widget(
5758
global AMG
5859
is_tiled = IMAGE_EMBEDDINGS["input_size"] is None
5960
if AMG is None:
60-
AMG = vutil.get_amg(PREDICTOR, is_tiled)
61+
AMG = instance_segmentation.get_amg(PREDICTOR, is_tiled)
6162

6263
if not AMG.is_initialized:
6364
AMG.initialize(v.layers["raw"].data, image_embeddings=IMAGE_EMBEDDINGS, verbose=True)
@@ -230,7 +231,7 @@ def annotator_2d(
230231
wrong_file_callback=show_wrong_file_warning
231232
)
232233
if precompute_amg_state and (embedding_path is not None):
233-
AMG = vutil.cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path)
234+
AMG = cache_amg_state(PREDICTOR, raw, IMAGE_EMBEDDINGS, embedding_path)
234235

235236
# we set the pre-computed image embeddings if we don't use tiling
236237
# (if we use tiling we cannot directly set it because the tile will be chosen dynamically)

micro_sam/sam_annotator/image_series_annotator.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,19 @@
11
import os
22
import warnings
3+
34
from glob import glob
5+
from pathlib import Path
46
from typing import List, Optional, Union
57

68
import imageio.v3 as imageio
79
import napari
810

911
from magicgui import magicgui
10-
from napari.utils import progress as tqdm
1112
from segment_anything import SamPredictor
1213

1314
from .. import util
15+
from ..precompute_state import _precompute_state_for_files
1416
from .annotator_2d import annotator_2d
15-
from .util import cache_amg_state
16-
17-
18-
def _precompute_embeddings_for_image_series(
19-
predictor,
20-
image_files,
21-
embedding_root,
22-
tile_shape,
23-
halo,
24-
precompute_amg_state,
25-
):
26-
os.makedirs(embedding_root, exist_ok=True)
27-
embedding_paths = []
28-
for image_file in tqdm(image_files, desc="Precompute embeddings"):
29-
fname = os.path.basename(image_file)
30-
fname = os.path.splitext(fname)[0] + ".zarr"
31-
embedding_path = os.path.join(embedding_root, fname)
32-
image = imageio.imread(image_file)
33-
embeddings = util.precompute_image_embeddings(
34-
predictor, image, save_path=embedding_path, ndim=2,
35-
tile_shape=tile_shape, halo=halo
36-
)
37-
if precompute_amg_state:
38-
cache_amg_state(predictor, image, embeddings, embedding_path)
39-
embedding_paths.append(embedding_path)
40-
return embedding_paths
4117

4218

4319
def image_series_annotator(
@@ -73,12 +49,16 @@ def image_series_annotator(
7349
if embedding_path is None:
7450
embedding_paths = None
7551
else:
76-
embedding_paths = _precompute_embeddings_for_image_series(
77-
predictor, image_files, embedding_path,
52+
_precompute_state_for_files(
53+
predictor, image_files, embedding_path, ndim=2,
7854
tile_shape=kwargs.get("tile_shape", None),
7955
halo=kwargs.get("halo", None),
8056
precompute_amg_state=kwargs.get("precompute_amg_state", False),
8157
)
58+
embedding_paths = [
59+
os.path.join(embedding_path, f"{Path(path).stem}.zarr") for path in image_files
60+
]
61+
assert all(os.path.exists(emb_path) for emb_path in embedding_paths)
8262

8363
def _save_segmentation(image_path, segmentation):
8464
fname = os.path.basename(image_path)
@@ -151,7 +131,7 @@ def main():
151131
"""@private"""
152132
import argparse
153133

154-
available_models = list(util._MODEL_URLS.keys())
134+
available_models = list(util.get_model_names())
155135
available_models = ", ".join(available_models)
156136

157137
parser = argparse.ArgumentParser(description="Annotate a series of images from a folder.")

micro_sam/sam_annotator/util.py

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -364,39 +364,9 @@ def toggle_label(prompts):
364364
prompts.refresh_colors()
365365

366366

367-
def get_amg(predictor, is_tiled):
368-
"""@private
369-
"""
370-
if is_tiled:
371-
amg = instance_segmentation.TiledAutomaticMaskGenerator(predictor)
372-
else:
373-
amg = instance_segmentation.AutomaticMaskGenerator(predictor)
374-
return amg
375-
376-
377-
def cache_amg_state(predictor, raw, image_embeddings, save_path, verbose=True):
378-
"""@private"""
379-
is_tiled = image_embeddings["input_size"] is None
380-
amg = get_amg(predictor, is_tiled)
381-
382-
save_path_amg = os.path.join(save_path, "amg_state.pickle")
383-
if os.path.exists(save_path_amg):
384-
with open(save_path_amg, "rb") as f:
385-
amg_state = pickle.load(f)
386-
amg.set_state(amg_state)
387-
return amg
388-
389-
print("Precomputing the state for instance segmentation.")
390-
amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose)
391-
with open(save_path_amg, "wb") as f:
392-
pickle.dump(amg.get_state(), f)
393-
394-
return amg
395-
396-
397367
def _initialize_parser(description, with_segmentation_result=True, with_show_embeddings=True):
398368

399-
available_models = list(util._MODEL_URLS.keys())
369+
available_models = list(util.get_model_names())
400370
available_models = ", ".join(available_models)
401371

402372
parser = argparse.ArgumentParser(description=description)

micro_sam/util.py

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@
6868
"""@private"""
6969

7070

71+
#
72+
# Functionality for model download and export
73+
#
74+
75+
7176
def _download(url, path, model_type):
7277
with requests.get(url, stream=True, verify=True) as r:
7378
if r.status_code != 200:
@@ -246,6 +251,11 @@ def get_model_names() -> Iterable:
246251
return _MODEL_URLS.keys()
247252

248253

254+
#
255+
# Functionality for precomputing embeddings and other state
256+
#
257+
258+
249259
def _to_image(input_):
250260
# we require the input to be uint8
251261
if input_.dtype != np.dtype("uint8"):
@@ -570,6 +580,11 @@ def set_precomputed(
570580
return predictor
571581

572582

583+
#
584+
# Misc functionality
585+
#
586+
587+
573588
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
574589
"""Compute the intersection over union of two masks.
575590
@@ -642,25 +657,3 @@ def load_image_data(
642657
if not lazy_loading:
643658
image_data = image_data[:]
644659
return image_data
645-
646-
647-
def main():
648-
"""@private"""
649-
import argparse
650-
651-
parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
652-
parser.add_argument("-i", "--input_path", required=True)
653-
parser.add_argument("-o", "--output_path", required=True)
654-
parser.add_argument("-m", "--model_type", default="vit_h")
655-
parser.add_argument("-c", "--checkpoint_path", default=None)
656-
parser.add_argument("-k", "--key")
657-
args = parser.parse_args()
658-
659-
predictor = get_sam_model(model_type=args.model_type, checkpoint_path=args.checkpoint_path)
660-
with open_file(args.input_path, mode="r") as f:
661-
data = f[args.key]
662-
precompute_image_embeddings(predictor, data, save_path=args.output_path)
663-
664-
665-
if __name__ == "__main__":
666-
main()

0 commit comments

Comments
 (0)