Skip to content

Commit a7dd369

Browse files
Update TiledAutomaticMaskGenerator
1 parent 65833ad commit a7dd369

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

micro_sam/instance_segmentation.py

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -753,18 +753,36 @@ def _compute_tiled_embeddings(predictor, image, image_embeddings, embedding_save
753753

754754

755755
class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
756+
"""Generates an instance segmentation without prompts, using a point grid.
757+
758+
Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
759+
760+
Args:
761+
predictor: The segment anything predictor.
762+
points_per_side: The number of points to be sampled along one side of the image.
763+
If None, `point_grids` must provide explicit point sampling.
764+
points_per_batch: The number of points run simultaneously by the model.
765+
Higher numbers may be faster but use more GPU memory.
766+
point_grids: A lisst over explicit grids of points used for sampling masks.
767+
Normalized to [0, 1] with respect to the image coordinate system.
756768
"""
757-
"""
758-
# def __init__(
759-
# self,
760-
# predictor: SamPredictor,
761-
# **kwargs
762-
# ):
763-
# super().__init__(predictor=predictor, **kwargs)
764769

765-
# # additional state for 'initialize'
766-
# self._tile_shape = None
767-
# self._halo = None
770+
# We only expose the arguments that make sense for the tiled mask generator.
771+
# Anything related to crops doesn't make sense, because we re-use that functionality
772+
# for tiling, so these parameters wouldn't have any effect.
773+
def __init__(
774+
self,
775+
predictor: SamPredictor,
776+
points_per_side: Optional[int] = 32,
777+
points_per_batch: int = 64,
778+
point_grids: Optional[List[np.ndarray]] = None,
779+
) -> None:
780+
super().__init__(
781+
predictor=predictor,
782+
points_per_side=points_per_side,
783+
points_per_batch=points_per_batch,
784+
point_grids=point_grids,
785+
)
768786

769787
@torch.no_grad()
770788
def initialize(
@@ -777,7 +795,18 @@ def initialize(
777795
verbose: bool = False,
778796
embedding_save_path: Optional[str] = None,
779797
) -> None:
780-
"""
798+
"""Initialize image embeddings and masks for an image.
799+
800+
Args:
801+
image: The input image, volume or timeseries.
802+
image_embeddings: Optional precomputed image embeddings.
803+
See `util.precompute_image_embeddings` for details.
804+
i: Index for the image data. Required if `image` has three spatial dimensions
805+
or a time dimension and two spatial dimensions.
806+
tile_shape: The tile shape for embedding prediction.
807+
halo: The overlap of between tiles.
808+
verbose: Whether to print computation progress.
809+
embedding_save_path: Where to save the image embeddings.
781810
"""
782811
original_size = image.shape[:2]
783812
image_embeddings, tile_shape, halo = _compute_tiled_embeddings(

0 commit comments

Comments
 (0)