Skip to content

Commit e4ef8a8

Browse files
Implement state (de)serialization for the mask generators
1 parent 0e11316 commit e4ef8a8

File tree

2 files changed

+63
-5
lines changed

2 files changed

+63
-5
lines changed

micro_sam/instance_segmentation.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ class _AMGBase(ABC):
7272
"""
7373
"""
7474
def __init__(self):
75-
# the data that has to be computed by the 'initialize' method
76-
# of the child classes
75+
# the state that has to be computed by the 'initialize' method of the child classes
7776
self._is_initialized = False
7877
self._crop_list = None
7978
self._crop_boxes = None
@@ -234,6 +233,17 @@ def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, cr
234233

235234
return curr_anns
236235

236+
def get_state(self):
237+
if not self.is_initialized:
238+
raise RuntimeError("The state has not been computed yet. Call initialize first.")
239+
return {"crop_list": self.crop_list, "crop_boxes": self.crop_boxes, "original_size": self.original_size}
240+
241+
def set_state(self, state):
242+
self._crop_list = state["crop_list"]
243+
self._crop_boxes = state["crop_boxes"]
244+
self._original_size = state["original_size"]
245+
self._is_initialized = True
246+
237247

238248
class AutomaticMaskGenerator(_AMGBase):
239249
"""
@@ -434,7 +444,7 @@ def __init__(
434444
self.use_points = use_points
435445
self.box_extension = box_extension
436446

437-
# additional data that is computed by 'initialize'
447+
# additional state that is set 'initialize'
438448
self._initial_segmentation = None
439449

440450
def _compute_initial_segmentation(self):
@@ -558,6 +568,15 @@ def get_initial_segmentation(self):
558568
raise RuntimeError("AutomaticMaskGenerator has not been initialized. Call initialize first.")
559569
return self._resize_segmentation(self._initial_segmentation, self.original_size)
560570

571+
def get_state(self):
572+
state = super().get_state()
573+
state["initial_segmentation"] = self._initial_segmentation
574+
return state
575+
576+
def set_state(self, state):
577+
self._initial_segmentation = state["initial_segmentation"]
578+
super().set_state(state)
579+
561580

562581
class TiledEmbeddingMaskGenerator(EmbeddingMaskGenerator):
563582
"""
@@ -572,9 +591,13 @@ def __init__(
572591
super().__init__(predictor=predictor, **kwargs)
573592
self.n_threads = n_threads
574593
self.with_background = with_background
575-
# additional data for 'initialize'
594+
595+
# additional state for 'initialize'
576596
self._tile_shape = None
577597
self._halo = None
598+
599+
# state for saving the stitched initial segmentation
600+
# (this is quite complex, so we save it to only compute once)
578601
self._stitched_initial_segmentation = None
579602

580603
def _compute_initial_segmentations(self, image_embeddings, i, n_tiles, verbose):
@@ -747,6 +770,17 @@ def segment_tile(_, tile_id):
747770
self._stitched_initial_segmentation = initial_segmentation
748771
return initial_segmentation
749772

773+
def get_state(self):
774+
state = super().get_state()
775+
state["tile_shape"] = self._tile_shape
776+
state["halo"] = self._halo
777+
return state
778+
779+
def set_state(self, state):
780+
self._tile_shape = state["tile_shape"]
781+
self._halo = state["halo"]
782+
super().set_state(state)
783+
750784

751785
#
752786
# Functional interfaces to run instance segmentation

test/test_instance_segmentation.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,10 +69,19 @@ def test_automatic_mask_generator(self):
6969
predicted = mask_data_to_segmentation(predicted, image.shape, with_background=True)
7070
self.assertGreater(matching(predicted, mask, threshold=0.75)["precision"], 0.99)
7171

72+
# check that regenerating the segmentation works
7273
predicted2 = amg.generate()
7374
predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
7475
self.assertTrue(np.array_equal(predicted, predicted2))
7576

77+
# check that serializing and reserializing the state works
78+
state = amg.get_state()
79+
amg = AutomaticMaskGenerator(predictor, points_per_side=10, points_per_batch=16)
80+
amg.set_state(state)
81+
predicted3 = amg.generate()
82+
predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
83+
self.assertTrue(np.array_equal(predicted, predicted3))
84+
7685
def test_embedding_mask_generator(self):
7786
from micro_sam.instance_segmentation import EmbeddingMaskGenerator, mask_data_to_segmentation
7887

@@ -89,11 +98,19 @@ def test_embedding_mask_generator(self):
8998
initial_seg = amg.get_initial_segmentation()
9099
self.assertEqual(initial_seg.shape, image.shape)
91100

101+
# check that regenerating the segmentation works
92102
predicted2 = amg.generate(pred_iou_thresh=0.96)
93103
predicted2 = mask_data_to_segmentation(predicted2, image.shape, with_background=True)
94-
95104
self.assertTrue(np.array_equal(predicted, predicted2))
96105

106+
# check that serializing and reserializing the state works
107+
state = amg.get_state()
108+
amg = EmbeddingMaskGenerator(predictor)
109+
amg.set_state(state)
110+
predicted3 = amg.generate(pred_iou_thresh=0.96)
111+
predicted3 = mask_data_to_segmentation(predicted3, image.shape, with_background=True)
112+
self.assertTrue(np.array_equal(predicted, predicted3))
113+
97114
def test_tiled_embedding_mask_generator(self):
98115
from micro_sam.instance_segmentation import TiledEmbeddingMaskGenerator
99116

@@ -112,6 +129,13 @@ def test_tiled_embedding_mask_generator(self):
112129
predicted2 = amg.generate(pred_iou_thresh=0.96)
113130
self.assertTrue(np.array_equal(predicted, predicted2))
114131

132+
# check that serializing and reserializing the state works
133+
state = amg.get_state()
134+
amg = TiledEmbeddingMaskGenerator(predictor)
135+
amg.set_state(state)
136+
predicted3 = amg.generate(pred_iou_thresh=0.96)
137+
self.assertTrue(np.array_equal(predicted, predicted3))
138+
115139

116140
if __name__ == "__main__":
117141
unittest.main()

0 commit comments

Comments
 (0)