Skip to content

Commit 5295f91

Browse files
Enable passing pre-computed embeddings to AutomaticMaskGenerator
1 parent: 9300458 · commit: 5295f91

File tree

2 files changed

+47
-19
lines changed

2 files changed

+47
-19
lines changed

micro_sam/instance_segmentation.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,14 @@ def _process_batch(self, points, im_size):
288288

289289
return data
290290

291-
def _process_crop(self, image, crop_box, crop_layer_idx, verbose):
291+
def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
292292
# crop the image and calculate embeddings
293293
x0, y0, x1, y1 = crop_box
294294
cropped_im = image[y0:y1, x0:x1, :]
295295
cropped_im_size = cropped_im.shape[:2]
296-
self.predictor.set_image(cropped_im)
296+
297+
if not precomputed_embeddings:
298+
self.predictor.set_image(cropped_im)
297299

298300
# get the points for this crop
299301
points_scale = np.array(cropped_im_size)[None, ::-1]
@@ -312,23 +314,39 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose):
312314
data.cat(batch_data)
313315
del batch_data
314316

315-
self.predictor.reset_image()
317+
if not precomputed_embeddings:
318+
self.predictor.reset_image()
319+
316320
return data
317321

318-
# TODO enable initializeing with embeddings
319-
# (which can be done for only a single crop box)
320322
@torch.no_grad()
321-
def initialize(self, image: np.ndarray, verbose=False):
323+
def initialize(self, image: np.ndarray, image_embeddings=None, i=None, embedding_path=None, verbose=False):
322324
"""
323325
"""
324-
image = util._to_image(image)
325326
original_size = image.shape[:2]
326327
crop_boxes, layer_idxs = amg_utils.generate_crop_boxes(
327328
original_size, self.crop_n_layers, self.crop_overlap_ratio
328329
)
330+
331+
# we can set fixed image embeddings if we only have a single crop box
332+
# (which is the default setting)
333+
# otherwise we have to recompute the embeddings for each crop and can't precompute
334+
if len(crop_boxes) == 1:
335+
if image_embeddings is None:
336+
image_embeddings = util.precompute_image_embeddings(self.predictor, image, save_path=embedding_path)
337+
util.set_precomputed(self.predictor, image_embeddings, i=i)
338+
precomputed_embeddings = True
339+
else:
340+
precomputed_embeddings = False
341+
342+
# we need to cast to the image representation that is compatible with SAM
343+
image = util._to_image(image)
344+
329345
crop_list = []
330346
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
331-
crop_data = self._process_crop(image, crop_box, layer_idx, verbose=verbose)
347+
crop_data = self._process_crop(
348+
image, crop_box, layer_idx, verbose=verbose, precomputed_embeddings=precomputed_embeddings
349+
)
332350
crop_list.append(crop_data)
333351

334352
self._is_initialized = True

test/test_instance_segmentation.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010

1111
class TestInstanceSegmentation(unittest.TestCase):
12-
1312
# create an input image with three objects
14-
def _get_input(self, shape=(512, 512)):
13+
@staticmethod
14+
def _get_input(shape=(256, 256)):
1515
mask = np.zeros(shape, dtype="uint8")
1616

1717
def write_object(center, radius):
@@ -31,30 +31,40 @@ def write_object(center, radius):
3131
mask = label(mask)
3232
return mask, image
3333

34-
def _get_model(self):
35-
return util.get_sam_model(model_type="vit_b", return_sam=False)
34+
@staticmethod
35+
def _get_model(image):
36+
predictor = util.get_sam_model(model_type="vit_b")
37+
image_embeddings = util.precompute_image_embeddings(predictor, image)
38+
return predictor, image_embeddings
39+
40+
# we compute the default mask and predictor once for the class
41+
# so that we don't have to precompute it every time
42+
@classmethod
43+
def setUpClass(cls):
44+
cls.mask, cls.image = cls._get_input()
45+
cls.predictor, cls.image_embeddings = cls._get_model(cls.image)
3646

3747
def test_automatic_mask_generator(self):
3848
from micro_sam.instance_segmentation import AutomaticMaskGenerator, mask_data_to_segmentation
3949

40-
mask, image = self._get_input(shape=(256, 256))
41-
predictor = self._get_model()
50+
mask, image = self.mask, self.image
51+
predictor, image_embeddings = self.predictor, self.image_embeddings
4252

4353
amg = AutomaticMaskGenerator(predictor, points_per_side=10, points_per_batch=16)
44-
amg.initialize(image, verbose=False)
54+
amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
4555
predicted = amg.generate()
4656
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
4757

4858
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
4959

50-
def test_embedding_based_mask_generator(self):
60+
def test_embedding_mask_generator(self):
5161
from micro_sam.instance_segmentation import EmbeddingMaskGenerator, mask_data_to_segmentation
5262

53-
mask, image = self._get_input()
54-
predictor = self._get_model()
63+
mask, image = self.mask, self.image
64+
predictor, image_embeddings = self.predictor, self.image_embeddings
5565

5666
amg = EmbeddingMaskGenerator(predictor)
57-
amg.initialize(image, verbose=False)
67+
amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
5868
predicted = amg.generate(pred_iou_thresh=0.96)
5969
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
6070

0 commit comments

Comments (0)