Skip to content

Commit 315f97b

Browse files
Add test for tiled embedding mask generator and add some type annotations
1 parent 748277b commit 315f97b

File tree

2 files changed

+66
-17
lines changed

2 files changed

+66
-17
lines changed

micro_sam/instance_segmentation.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,13 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
320320
return data
321321

322322
@torch.no_grad()
323-
def initialize(self, image: np.ndarray, image_embeddings=None, i=None, verbose=False):
323+
def initialize(
324+
self,
325+
image: np.ndarray,
326+
image_embeddings=None,
327+
i: Optional[int] = None,
328+
verbose: bool = False
329+
):
324330
"""
325331
"""
326332
original_size = image.shape[:2]
@@ -403,14 +409,14 @@ class EmbeddingMaskGenerator(_AMGBase):
403409
def __init__(
404410
self,
405411
predictor: SamPredictor,
406-
offsets=None,
407-
min_initial_size=0,
408-
distance_type="l2",
409-
bias=0.0,
410-
use_box=True,
411-
use_mask=True,
412-
use_points=False,
413-
box_extension=0.05,
412+
offsets: Optional[List[List[int]]] = None,
413+
min_initial_size: int = 0,
414+
distance_type: str = "l2",
415+
bias: float = 0.0,
416+
use_box: bool = True,
417+
use_mask: bool = True,
418+
use_points: bool = False,
419+
box_extension: float = 0.05,
414420
):
415421
super().__init__()
416422

@@ -475,7 +481,13 @@ def _compute_mask_data(self, initial_segmentation, original_size, verbose):
475481
return mask_data
476482

477483
@torch.no_grad()
478-
def initialize(self, image: np.ndarray, image_embeddings=None, i=None, verbose=False):
484+
def initialize(
485+
self,
486+
image: np.ndarray,
487+
image_embeddings=None,
488+
i: Optional[int] = None,
489+
verbose: bool = False
490+
):
479491
"""
480492
"""
481493
original_size = image.shape[:2]
@@ -545,10 +557,17 @@ def get_initial_segmentation(self):
545557
class TiledEmbeddingMaskGenerator(EmbeddingMaskGenerator):
546558
"""
547559
"""
548-
def __init__(self, n_threads=mp.cpu_count(), with_background=True, **kwargs):
549-
super().__init__(**kwargs)
560+
def __init__(
561+
self,
562+
predictor: SamPredictor,
563+
n_threads: int = mp.cpu_count(),
564+
with_background: bool = True,
565+
**kwargs
566+
):
567+
super().__init__(predictor=predictor, **kwargs)
550568
self.n_threads = n_threads
551569
self.with_background = with_background
570+
# additional data for 'initialize'
552571
self._tile_shape = None
553572
self._halo = None
554573
self._stitched_initial_segmentation = None
@@ -599,14 +618,17 @@ def initialize(
599618
tile_shape: List[int],
600619
halo: List[int],
601620
image_embeddings=None,
602-
i=None,
603-
verbose=False,
621+
i: Optional[int] = None,
622+
verbose: bool = False,
623+
embedding_save_path: Optional[str] = None,
604624
):
605625
"""
606626
"""
607627
original_size = image.shape[:2]
608628
if image_embeddings is None:
609-
image_embeddings = util.precompute_image_embeddings(self.predictor, image, tile_shape=tile_shape, halo=halo)
629+
image_embeddings = util.precompute_image_embeddings(
630+
self.predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
631+
)
610632

611633
tiling = blocking([0, 0], original_size, tile_shape)
612634
n_tiles = tiling.numberOfBlocks

test/test_instance_segmentation.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import os
12
import unittest
3+
from shutil import rmtree
24

35
import micro_sam.util as util
46
import numpy as np
@@ -9,6 +11,8 @@
911

1012

1113
class TestInstanceSegmentation(unittest.TestCase):
14+
embedding_path = "./tmp_embeddings.zarr"
15+
1216
# create an input image with three objects
1317
@staticmethod
1418
def _get_input(shape=(256, 256)):
@@ -32,9 +36,12 @@ def write_object(center, radius):
3236
return mask, image
3337

3438
@staticmethod
35-
def _get_model(image):
39+
def _get_model(image, tile_shape=None, halo=None, save_path=None):
3640
predictor = util.get_sam_model(model_type="vit_b")
37-
image_embeddings = util.precompute_image_embeddings(predictor, image)
41+
42+
image_embeddings = util.precompute_image_embeddings(
43+
predictor, image, tile_shape=tile_shape, halo=halo, save_path=save_path
44+
)
3845
return predictor, image_embeddings
3946

4047
# we compute the default mask and predictor once for the class
@@ -44,6 +51,11 @@ def setUpClass(cls):
4451
cls.mask, cls.image = cls._get_input()
4552
cls.predictor, cls.image_embeddings = cls._get_model(cls.image)
4653

54+
# remove temp embeddings if any
55+
def tearDown(self):
56+
if os.path.exists(self.embedding_path):
57+
rmtree(self.embedding_path)
58+
4759
def test_automatic_mask_generator(self):
4860
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation
4961

@@ -73,6 +85,21 @@ def test_embedding_mask_generator(self):
7385
initial_seg = amg.get_initial_segmentation()
7486
self.assertEqual(initial_seg.shape, image.shape)
7587

88+
def test_tiled_embedding_mask_generator(self):
89+
from micro_sam.instance_segmentation import TiledEmbeddingMaskGenerator
90+
91+
tile_shape, halo = (576, 576), (64, 64)
92+
mask, image = self._get_input(shape=(1024, 1024))
93+
predictor, image_embeddings = self._get_model(image, tile_shape, halo, self.embedding_path)
94+
95+
amg = TiledEmbeddingMaskGenerator(predictor)
96+
amg.initialize(image, image_embeddings=image_embeddings, tile_shape=tile_shape, halo=halo)
97+
predicted = amg.generate(pred_iou_thresh=0.96)
98+
initial_seg = amg.get_initial_segmentation()
99+
100+
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
101+
self.assertEqual(initial_seg.shape, image.shape)
102+
76103

77104
if __name__ == "__main__":
78105
unittest.main()

0 commit comments

Comments
 (0)