Skip to content

Commit a79c5b3

Browse files
Fix state precomputation for instance segmentation (#552)
Fix state precomputation for instance segmentation
1 parent 1e4c43b commit a79c5b3

File tree

3 files changed

+33
-8
lines changed

3 files changed

+33
-8
lines changed

micro_sam/instance_segmentation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -887,7 +887,7 @@ def initialize(
887887
pbar_update: Callback to update an external progress bar.
888888
"""
889889
_, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
890-
pbar_init(1, "Initialize instannce segmentation with decoder")
890+
pbar_init(1, "Initialize instance segmentation with decoder")
891891

892892
if image_embeddings is None:
893893
image_embeddings = util.precompute_image_embeddings(self._predictor, image)
@@ -1071,7 +1071,7 @@ def initialize(
10711071
tiling = blocking([0, 0], original_size, tile_shape)
10721072

10731073
_, pbar_init, pbar_update, pbar_close = util.handle_pbar(verbose, pbar_init, pbar_update)
1074-
pbar_init(tiling.numberOfBlocks, "Initialize tiled instannce segmentation with decoder")
1074+
pbar_init(tiling.numberOfBlocks, "Initialize tiled instance segmentation with decoder")
10751075

10761076
foreground = np.zeros(original_size, dtype="float32")
10771077
center_distances = np.zeros(original_size, dtype="float32")

micro_sam/precompute_state.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import pickle
66

7+
from functools import partial
78
from glob import glob
89
from pathlib import Path
910
from typing import Optional, Tuple, Union, List
@@ -159,15 +160,33 @@ def _precompute_state_for_file(
159160
else:
160161
image_data = util.load_image_data(input_path, key)
161162

163+
# Precompute the image embeddings.
162164
output_path = Path(output_path).with_suffix(".zarr")
163165
embeddings = util.precompute_image_embeddings(
164166
predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo,
165167
)
168+
169+
# Precompute the state for automatic instance segmnetaiton (AMG or AIS).
166170
if precompute_amg_state:
167171
if decoder is None:
168-
cache_amg_state(predictor, image_data, embeddings, output_path, verbose=True)
172+
cache_function = partial(
173+
cache_amg_state, predictor=predictor, image_embeddings=embeddings, save_path=output_path
174+
)
175+
else:
176+
cache_function = partial(
177+
cache_is_state, predictor=predictor, decoder=decoder,
178+
image_embeddings=embeddings, save_path=output_path
179+
)
180+
181+
if ndim is None:
182+
ndim = image_data.ndim
183+
184+
if ndim == 2:
185+
cache_function(raw=image_data, verbose=True)
169186
else:
170-
cache_is_state(predictor, decoder, image_data, embeddings, output_path, verbose=True)
187+
n = image_data.shape[0]
188+
for i in tqdm(range(n), total=n, desc="Precompute instance segmentation state"):
189+
cache_function(raw=image_data, i=i, verbose=False)
171190

172191

173192
def _precompute_state_for_files(

test/test_sam_annotator/test_cli.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class TestCLI(unittest.TestCase):
14-
model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
14+
model_type = "vit_t_lm" if util.VIT_T_SUPPORT else "vit_b_lm"
1515
tmp_folder = "tmp-files"
1616

1717
def setUp(self):
@@ -53,31 +53,37 @@ def test_precompute_embeddings(self):
5353
emb_path1 = os.path.join(self.tmp_folder, "embedddings1.zarr")
5454
run([
5555
"micro_sam.precompute_embeddings", "-i", im_path, "-e", emb_path1,
56-
"-m", self.model_type
56+
"-m", self.model_type, "--precompute_amg_state"
5757
])
5858
self.assertTrue(os.path.exists(emb_path1))
5959
with zarr.open(emb_path1, "r") as f:
6060
self.assertIn("features", f)
61+
ais_path = os.path.join(emb_path1, "is_state.h5")
62+
self.assertTrue(os.path.exists(ais_path))
6163

6264
# Test precomputation with image stack.
6365
emb_path2 = os.path.join(self.tmp_folder, "embedddings2.zarr")
6466
run([
6567
"micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path2,
66-
"-m", self.model_type, "-k", "*.tif"
68+
"-m", self.model_type, "-k", "*.tif", "--precompute_amg_state"
6769
])
6870
self.assertTrue(os.path.exists(emb_path2))
6971
with zarr.open(emb_path2, "r") as f:
7072
self.assertIn("features", f)
7173
self.assertEqual(f["features"].shape[0], 3)
74+
ais_path = os.path.join(emb_path2, "is_state.h5")
75+
self.assertTrue(os.path.exists(ais_path))
7276

7377
# Test precomputation with pattern to process multiple image.
7478
emb_path3 = os.path.join(self.tmp_folder, "embedddings3")
7579
run([
7680
"micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path3,
77-
"-m", self.model_type, "--pattern", "*.tif"
81+
"-m", self.model_type, "--pattern", "*.tif", "--precompute_amg_state"
7882
])
7983
for i in range(3):
8084
self.assertTrue(os.path.exists(os.path.join(emb_path3, f"image-{i}.zarr")))
85+
ais_path = os.path.join(emb_path3, f"image-{i}.zarr", "is_state.h5")
86+
self.assertTrue(os.path.exists(ais_path))
8187

8288

8389
if __name__ == "__main__":

0 commit comments

Comments
 (0)