Skip to content

Commit ba2857a

Browse files
Enable calling generate multiple times in mask generators
1 parent 315f97b commit ba2857a

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

micro_sam/instance_segmentation.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import multiprocessing as mp
22
from abc import ABC
33
from concurrent import futures
4+
from copy import deepcopy
45
from typing import List, Optional
56

67
import numpy as np
@@ -373,10 +374,12 @@ def generate(
373374
):
374375
if not self.is_initialized:
375376
raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
377+
376378
data = amg_utils.MaskData()
377379
for data_, crop_box in zip(self.crop_list, self.crop_boxes):
378380
crop_data = self._postprocess_batch(
379-
data=data_, crop_box=crop_box, original_size=self.original_size,
381+
data=deepcopy(data_),
382+
crop_box=crop_box, original_size=self.original_size,
380383
pred_iou_thresh=pred_iou_thresh,
381384
stability_score_thresh=stability_score_thresh,
382385
stability_score_offset=stability_score_offset,
@@ -527,7 +530,8 @@ def generate(
527530
raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
528531

529532
data = self._postprocess_batch(
530-
data=self.crop_list[0], crop_box=self.crop_boxes[0], original_size=self.original_size,
533+
data=deepcopy(self.crop_list[0]), crop_box=self.crop_boxes[0],
534+
original_size=self.original_size,
531535
pred_iou_thresh=pred_iou_thresh,
532536
stability_score_thresh=stability_score_thresh,
533537
stability_score_offset=stability_score_offset,
@@ -541,12 +545,12 @@ def generate(
541545
def _resize_segmentation(self, segmentation, shape):
542546
longest_size = max(shape)
543547
longest_shape = (longest_size, longest_size)
544-
segmentation_ = resize(
548+
resized_segmentation = resize(
545549
segmentation, longest_shape, order=0, preserve_range=True, anti_aliasing=False
546550
).astype(segmentation.dtype)
547551
crop = tuple(slice(0, sh) for sh in shape)
548-
segmentation_ = segmentation_[crop]
549-
return segmentation_
552+
resized_segmentation = resized_segmentation[crop]
553+
return resized_segmentation
550554

551555
def get_initial_segmentation(self):
552556
if not self.is_initialized:
@@ -667,7 +671,7 @@ def generate(
667671

668672
def segment_tile(_, tile_id):
669673
tile = tiling.getBlockWithHalo(tile_id, list(self._halo)).outerBlock
670-
mask_data = self._crop_list[tile_id]
674+
mask_data = deepcopy(self._crop_list[tile_id])
671675
crop_box = self.crop_boxes[tile_id]
672676
this_tile_shape = tuple(end - beg for beg, end in zip(tile.begin, tile.end))
673677
mask_data = self._postprocess_batch(

test/test_instance_segmentation.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,15 @@ def test_automatic_mask_generator(self):
6464

6565
amg = AutomaticMaskGenerator(predictor, points_per_side=10, points_per_batch=16)
6666
amg.initialize(image, image_embeddings=image_embeddings, verbose=False)
67+
6768
predicted = amg.generate()
6869
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
69-
7070
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
7171

72+
predicted2 = amg.generate()
73+
predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
74+
self.assertTrue(np.array_equal(predicted, predicted2))
75+
7276
def test_embedding_mask_generator(self):
7377
from micro_sam.instance_segmentation import EmbeddingMaskGenerator, mask_data_to_segmentation
7478

@@ -85,6 +89,11 @@ def test_embedding_mask_generator(self):
8589
initial_seg = amg.get_initial_segmentation()
8690
self.assertEqual(initial_seg.shape, image.shape)
8791

92+
predicted2 = amg.generate(pred_iou_thresh=0.96)
93+
predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
94+
95+
self.assertTrue(np.array_equal(predicted, predicted2))
96+
8897
def test_tiled_embedding_mask_generator(self):
8998
from micro_sam.instance_segmentation import TiledEmbeddingMaskGenerator
9099

@@ -100,6 +109,9 @@ def test_tiled_embedding_mask_generator(self):
100109
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
101110
self.assertEqual(initial_seg.shape, image.shape)
102111

112+
predicted2 = amg.generate(pred_iou_thresh=0.96)
113+
self.assertTrue(np.array_equal(predicted, predicted2))
114+
103115

104116
if __name__ == "__main__":
105117
unittest.main()

0 commit comments

Comments
 (0)