@@ -82,8 +82,8 @@ def mask_data_to_segmentation(
8282#
8383
8484
85- class _AMGBase (ABC ):
86- """
85+ class AMGBase (ABC ):
86+ """Base class for the automatic mask generators.
8787 """
8888 def __init__ (self ):
8989 # the state that has to be computed by the 'initialize' method of the child classes
@@ -277,7 +277,7 @@ def set_state(self, state: Dict[str, Any]) -> None:
277277 self ._is_initialized = True
278278
279279
280- class AutomaticMaskGenerator (_AMGBase ):
280+ class AutomaticMaskGenerator (AMGBase ):
281281 """Generates an instance segmentation without prompts, using a point grid.
282282
283283 This class implements the same logic as
@@ -358,8 +358,11 @@ def _process_batch(self, points, im_size):
358358
359359 def _process_crop (self , image , crop_box , crop_layer_idx , verbose , precomputed_embeddings ):
360360 # crop the image and calculate embeddings
361- x0 , y0 , x1 , y1 = crop_box
362- cropped_im = image [y0 :y1 , x0 :x1 , :]
361+ if crop_box is None :
362+ cropped_im = image
363+ else :
364+ x0 , y0 , x1 , y1 = crop_box
365+ cropped_im = image [y0 :y1 , x0 :x1 , :]
363366 cropped_im_size = cropped_im .shape [:2 ]
364367
365368 if not precomputed_embeddings :
@@ -477,7 +480,7 @@ def generate(
477480 )
478481 data .cat (crop_data )
479482
480- if len (self .crop_boxes ) > 1 :
483+ if len (self .crop_boxes ) > 1 and len ( data [ "crop_boxes" ]) > 0 :
481484 # Prefer masks from smaller crops
482485 scores = 1 / box_area (data ["crop_boxes" ])
483486 scores = scores .to (data ["boxes" ].device )
@@ -494,7 +497,7 @@ def generate(
494497 return masks
495498
496499
497- class EmbeddingMaskGenerator (_AMGBase ):
500+ class EmbeddingMaskGenerator (AMGBase ):
498501 """Generates an instance segmentation without prompts, using an initial segmentations derived from image embeddings.
499502
500503 Uses an intial segmentation derived from the image embeddings via the Mutex Watershed,
@@ -718,6 +721,133 @@ def set_state(self, state: Dict[str, Any]) -> None:
718721 super ().set_state (state )
719722
720723
724+ def _compute_tiled_embeddings (predictor , image , image_embeddings , embedding_save_path , tile_shape , halo ):
725+ have_tiling_params = (tile_shape is not None ) and (halo is not None )
726+ if image_embeddings is None and have_tiling_params :
727+ if embedding_save_path is None :
728+ raise ValueError (
729+ "You have passed neither pre-computed embeddings nor a path for saving embeddings."
730+ "Embeddings with tiling can only be computed if a save path is given."
731+ )
732+ image_embeddings = util .precompute_image_embeddings (
733+ predictor , image , tile_shape = tile_shape , halo = halo , save_path = embedding_save_path
734+ )
735+ elif image_embeddings is None and not have_tiling_params :
736+ raise ValueError ("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)" )
737+ else :
738+ feats = image_embeddings ["features" ]
739+ tile_shape_ , halo_ = feats .attrs ["tile_shape" ], feats .attrs ["halo" ]
740+ if have_tiling_params and (
741+ (list (tile_shape ) != list (tile_shape_ )) or
742+ (list (halo ) != list (halo_ ))
743+ ):
744+ warnings .warn (
745+ "You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and"
746+ "the values of the tiling parameters from the embeddings disagree with the ones that were passed."
747+ "The tiling parameters you have passed wil be ignored."
748+ )
749+ tile_shape = tile_shape_
750+ halo = halo_
751+
752+ return image_embeddings , tile_shape , halo
753+
754+
755+ class TiledAutomaticMaskGenerator (AutomaticMaskGenerator ):
756+ """Generates an instance segmentation without prompts, using a point grid.
757+
758+ Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
759+
760+ Args:
761+ predictor: The segment anything predictor.
762+ points_per_side: The number of points to be sampled along one side of the image.
763+ If None, `point_grids` must provide explicit point sampling.
764+ points_per_batch: The number of points run simultaneously by the model.
765+ Higher numbers may be faster but use more GPU memory.
766+ point_grids: A lisst over explicit grids of points used for sampling masks.
767+ Normalized to [0, 1] with respect to the image coordinate system.
768+ """
769+
770+ # We only expose the arguments that make sense for the tiled mask generator.
771+ # Anything related to crops doesn't make sense, because we re-use that functionality
772+ # for tiling, so these parameters wouldn't have any effect.
773+ def __init__ (
774+ self ,
775+ predictor : SamPredictor ,
776+ points_per_side : Optional [int ] = 32 ,
777+ points_per_batch : int = 64 ,
778+ point_grids : Optional [List [np .ndarray ]] = None ,
779+ ) -> None :
780+ super ().__init__ (
781+ predictor = predictor ,
782+ points_per_side = points_per_side ,
783+ points_per_batch = points_per_batch ,
784+ point_grids = point_grids ,
785+ )
786+
787+ @torch .no_grad ()
788+ def initialize (
789+ self ,
790+ image : np .ndarray ,
791+ image_embeddings : Optional [util .ImageEmbeddings ] = None ,
792+ i : Optional [int ] = None ,
793+ tile_shape : Optional [Tuple [int , int ]] = None ,
794+ halo : Optional [Tuple [int , int ]] = None ,
795+ verbose : bool = False ,
796+ embedding_save_path : Optional [str ] = None ,
797+ ) -> None :
798+ """Initialize image embeddings and masks for an image.
799+
800+ Args:
801+ image: The input image, volume or timeseries.
802+ image_embeddings: Optional precomputed image embeddings.
803+ See `util.precompute_image_embeddings` for details.
804+ i: Index for the image data. Required if `image` has three spatial dimensions
805+ or a time dimension and two spatial dimensions.
806+ tile_shape: The tile shape for embedding prediction.
807+ halo: The overlap of between tiles.
808+ verbose: Whether to print computation progress.
809+ embedding_save_path: Where to save the image embeddings.
810+ """
811+ original_size = image .shape [:2 ]
812+ image_embeddings , tile_shape , halo = _compute_tiled_embeddings (
813+ self ._predictor , image , image_embeddings , embedding_save_path , tile_shape , halo
814+ )
815+
816+ tiling = blocking ([0 , 0 ], original_size , tile_shape )
817+ n_tiles = tiling .numberOfBlocks
818+
819+ mask_data = []
820+ for tile_id in tqdm (range (n_tiles ), total = n_tiles , desc = "Compute masks for tile" , disable = not verbose ):
821+ # get the bounding box for this tile and crop the image data
822+ tile = tiling .getBlockWithHalo (tile_id , list (halo )).outerBlock
823+ tile_bb = tuple (slice (beg , end ) for beg , end in zip (tile .begin , tile .end ))
824+ tile_data = image [tile_bb ]
825+
826+ # set the pre-computed embeddings for this tile
827+ features = image_embeddings ["features" ][tile_id ]
828+ tile_embeddings = {
829+ "features" : features ,
830+ "input_size" : features .attrs ["input_size" ],
831+ "original_size" : features .attrs ["original_size" ],
832+ }
833+ util .set_precomputed (self ._predictor , tile_embeddings , i )
834+
835+ # compute the mask data for this tile and append it
836+ this_mask_data = self ._process_crop (
837+ tile_data , crop_box = None , crop_layer_idx = 0 , verbose = verbose , precomputed_embeddings = True
838+ )
839+ mask_data .append (this_mask_data )
840+
841+ # set the initialized data
842+ self ._is_initialized = True
843+ self ._crop_list = mask_data
844+ self ._original_size = original_size
845+
846+ # the crop box is always the full local tile
847+ tiles = [tiling .getBlockWithHalo (tile_id , list (halo )).outerBlock for tile_id in range (n_tiles )]
848+ self ._crop_boxes = [[tile .begin [1 ], tile .begin [0 ], tile .end [1 ], tile .end [0 ]] for tile in tiles ]
849+
850+
721851class TiledEmbeddingMaskGenerator (EmbeddingMaskGenerator ):
722852 """Generates an instance segmentation without prompts, using an initial segmentations derived from image embeddings.
723853
@@ -812,33 +942,9 @@ def initialize(
812942 embedding_save_path: Where to save the image embeddings.
813943 """
814944 original_size = image .shape [:2 ]
815-
816- have_tiling_params = (tile_shape is not None ) and (halo is not None )
817- if image_embeddings is None and have_tiling_params :
818- if embedding_save_path is None :
819- raise ValueError (
820- "You have passed neither pre-computed embeddings nor a path for saving embeddings."
821- "Embeddings with tiling can only be computed if a save path is given."
822- )
823- image_embeddings = util .precompute_image_embeddings (
824- self ._predictor , image , tile_shape = tile_shape , halo = halo , save_path = embedding_save_path
825- )
826- elif image_embeddings is None and not have_tiling_params :
827- raise ValueError ("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)" )
828- else :
829- feats = image_embeddings ["features" ]
830- tile_shape_ , halo_ = feats .attrs ["tile_shape" ], feats .attrs ["halo" ]
831- if have_tiling_params and (
832- (list (tile_shape ) != list (tile_shape_ )) or
833- (list (halo ) != list (halo_ ))
834- ):
835- warnings .warn (
836- "You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and"
837- "the values of the tiling parameters from the embeddings disagree with the ones that were passed."
838- "The tiling parameters you have passed wil be ignored."
839- )
840- tile_shape = tile_shape_
841- halo = halo_
945+ image_embeddings , tile_shape , halo = _compute_tiled_embeddings (
946+ self ._predictor , image , image_embeddings , embedding_save_path , tile_shape , halo
947+ )
842948
843949 tiling = blocking ([0 , 0 ], original_size , tile_shape )
844950 n_tiles = tiling .numberOfBlocks
0 commit comments