@@ -753,18 +753,36 @@ def _compute_tiled_embeddings(predictor, image, image_embeddings, embedding_save
753753
754754
755755class TiledAutomaticMaskGenerator (AutomaticMaskGenerator ):
756+ """Generates an instance segmentation without prompts, using a point grid.
757+
758+ Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
759+
760+ Args:
761+ predictor: The segment anything predictor.
762+ points_per_side: The number of points to be sampled along one side of the image.
763+ If None, `point_grids` must provide explicit point sampling.
764+ points_per_batch: The number of points run simultaneously by the model.
765+ Higher numbers may be faster but use more GPU memory.
766+ point_grids: A list over explicit grids of points used for sampling masks.
767+ Normalized to [0, 1] with respect to the image coordinate system.
756768 """
757- """
758- # def __init__(
759- # self,
760- # predictor: SamPredictor,
761- # **kwargs
762- # ):
763- # super().__init__(predictor=predictor, **kwargs)
764769
765- # # additional state for 'initialize'
766- # self._tile_shape = None
767- # self._halo = None
770+ # We only expose the arguments that make sense for the tiled mask generator.
771+ # Anything related to crops doesn't make sense, because we re-use that functionality
772+ # for tiling, so these parameters wouldn't have any effect.
773+ def __init__ (
774+ self ,
775+ predictor : SamPredictor ,
776+ points_per_side : Optional [int ] = 32 ,
777+ points_per_batch : int = 64 ,
778+ point_grids : Optional [List [np .ndarray ]] = None ,
779+ ) -> None :
780+ super ().__init__ (
781+ predictor = predictor ,
782+ points_per_side = points_per_side ,
783+ points_per_batch = points_per_batch ,
784+ point_grids = point_grids ,
785+ )
768786
769787 @torch .no_grad ()
770788 def initialize (
@@ -777,7 +795,18 @@ def initialize(
777795 verbose : bool = False ,
778796 embedding_save_path : Optional [str ] = None ,
779797 ) -> None :
780- """
798+ """Initialize image embeddings and masks for an image.
799+
800+ Args:
801+ image: The input image, volume or timeseries.
802+ image_embeddings: Optional precomputed image embeddings.
803+ See `util.precompute_image_embeddings` for details.
804+ i: Index for the image data. Required if `image` has three spatial dimensions
805+ or a time dimension and two spatial dimensions.
806+ tile_shape: The tile shape for embedding prediction.
807+ halo: The overlap between tiles.
808+ verbose: Whether to print computation progress.
809+ embedding_save_path: Where to save the image embeddings.
781810 """
782811 original_size = image .shape [:2 ]
783812 image_embeddings , tile_shape , halo = _compute_tiled_embeddings (
0 commit comments