@@ -72,8 +72,7 @@ class _AMGBase(ABC):
7272 """
7373 """
7474 def __init__ (self ):
75- # the data that has to be computed by the 'initialize' method
76- # of the child classes
75+ # the state that has to be computed by the 'initialize' method of the child classes
7776 self ._is_initialized = False
7877 self ._crop_list = None
7978 self ._crop_boxes = None
@@ -234,6 +233,17 @@ def _postprocess_masks(self, mask_data, min_mask_region_area, box_nms_thresh, cr
234233
235234 return curr_anns
236235
236+ def get_state (self ):
237+ if not self .is_initialized :
238+ raise RuntimeError ("The state has not been computed yet. Call initialize first." )
239+ return {"crop_list" : self .crop_list , "crop_boxes" : self .crop_boxes , "original_size" : self .original_size }
240+
241+ def set_state (self , state ):
242+ self ._crop_list = state ["crop_list" ]
243+ self ._crop_boxes = state ["crop_boxes" ]
244+ self ._original_size = state ["original_size" ]
245+ self ._is_initialized = True
246+
237247
238248class AutomaticMaskGenerator (_AMGBase ):
239249 """
@@ -434,7 +444,7 @@ def __init__(
434444 self .use_points = use_points
435445 self .box_extension = box_extension
436446
437- # additional data that is computed by 'initialize'
447+ # additional state that is set 'initialize'
438448 self ._initial_segmentation = None
439449
440450 def _compute_initial_segmentation (self ):
@@ -558,6 +568,15 @@ def get_initial_segmentation(self):
558568 raise RuntimeError ("AutomaticMaskGenerator has not been initialized. Call initialize first." )
559569 return self ._resize_segmentation (self ._initial_segmentation , self .original_size )
560570
571+ def get_state (self ):
572+ state = super ().get_state ()
573+ state ["initial_segmentation" ] = self ._initial_segmentation
574+ return state
575+
576+ def set_state (self , state ):
577+ self ._initial_segmentation = state ["initial_segmentation" ]
578+ super ().set_state (state )
579+
561580
562581class TiledEmbeddingMaskGenerator (EmbeddingMaskGenerator ):
563582 """
@@ -572,9 +591,13 @@ def __init__(
572591 super ().__init__ (predictor = predictor , ** kwargs )
573592 self .n_threads = n_threads
574593 self .with_background = with_background
575- # additional data for 'initialize'
594+
595+ # additional state for 'initialize'
576596 self ._tile_shape = None
577597 self ._halo = None
598+
599+ # state for saving the stitched initial segmentation
600+ # (this is quite complex, so we save it to only compute once)
578601 self ._stitched_initial_segmentation = None
579602
580603 def _compute_initial_segmentations (self , image_embeddings , i , n_tiles , verbose ):
@@ -747,6 +770,17 @@ def segment_tile(_, tile_id):
747770 self ._stitched_initial_segmentation = initial_segmentation
748771 return initial_segmentation
749772
773+ def get_state (self ):
774+ state = super ().get_state ()
775+ state ["tile_shape" ] = self ._tile_shape
776+ state ["halo" ] = self ._halo
777+ return state
778+
779+ def set_state (self , state ):
780+ self ._tile_shape = state ["tile_shape" ]
781+ self ._halo = state ["halo" ]
782+ super ().set_state (state )
783+
750784
751785#
752786# Functional interfaces to run instance segmentation
0 commit comments