Skip to content

Commit 3aac00d

Browse files
Add doc strings and type annotation to prompt_based_segmentation
1 parent bae896e commit 3aac00d

File tree

1 file changed

+107
-16
lines changed

1 file changed

+107
-16
lines changed

micro_sam/prompt_based_segmentation.py

Lines changed: 107 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import warnings
2+
from typing import Optional
23

34
import numpy as np
45
from nifty.tools import blocking
56
from skimage.feature import peak_local_max
67
from skimage.filters import gaussian
78
from scipy.ndimage import distance_transform_edt
89

10+
from segment_anything.predictor import SamPredictor
911
from segment_anything.utils.transforms import ResizeLongestSide
1012
from . import util
1113

@@ -229,7 +231,7 @@ def _tile_to_full_mask(mask, shape, tile, return_all, multimask_output):
229231

230232

231233
#
232-
# functions for prompted:
234+
# functions for prompted segmentation:
233235
# - segment_from_points: use point prompts as input
234236
# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
235237
# - segment_from_box: use box prompt as input
@@ -238,10 +240,30 @@ def _tile_to_full_mask(mask, shape, tile, return_all, multimask_output):
238240

239241

240242
def segment_from_points(
241-
predictor, points, labels,
242-
image_embeddings=None,
243-
i=None, multimask_output=False, return_all=False,
243+
predictor: SamPredictor,
244+
points: np.ndarray,
245+
labels: np.ndarray,
246+
image_embeddings: Optional[util.ImageEmbeddings] = None,
247+
i: Optional[int] = None,
248+
multimask_output: bool = False,
249+
return_all: bool = False,
244250
):
251+
"""Segmentation from point prompts.
252+
253+
Args:
254+
predictor: The segment anything predictor.
255+
points: The point prompts given in the image coordinate system.
256+
labels: The labels (positive or negative) associated with the points.
257+
image_embeddings: Optional precomputed image embeddings.
258+
Has to be passed if the predictor is not yet initialized.
259+
i: Index for the image data. Required if the input data has three spatial dimensions
260+
or a time dimension and two spatial dimensions.
261+
multimask_output: Whether to return multiple or just a single mask.
262+
return_all: Whether to return the score and logits in addition to the mask.
263+
264+
Returns:
265+
The binary segmentation mask.
266+
"""
245267
predictor, tile, prompts, shape = _initialize_predictor(
246268
predictor, image_embeddings, i, (points, labels), _points_to_tile
247269
)
@@ -264,13 +286,38 @@ def segment_from_points(
264286

265287
# use original_size if the mask is downscaled w.r.t. the original image size
266288
def segment_from_mask(
267-
predictor, mask,
268-
image_embeddings=None, i=None,
269-
use_box=True, use_mask=True, use_points=False,
270-
original_size=None, multimask_output=False,
271-
return_all=False, return_logits=False,
272-
box_extension=0,
289+
predictor: SamPredictor,
290+
mask: np.ndarray,
291+
image_embeddings: Optional[util.ImageEmbeddings] = None,
292+
i: Optional[int] = None,
293+
use_box: bool = True,
294+
use_mask: bool = True,
295+
use_points: bool = False,
296+
original_size: Optional[tuple[int, ...]] = None,
297+
multimask_output: bool = False,
298+
return_all: bool = False,
299+
return_logits: bool = False,
300+
box_extension: float = 0.0,
273301
):
302+
"""Segmentation from a mask prompt.
303+
304+
Args:
305+
predictor: The segment anything predictor.
306+
mask: The mask used to derive prompts.
307+
image_embeddings: Optional precomputed image embeddings.
308+
Has to be passed if the predictor is not yet initialized.
309+
i: Index for the image data. Required if the input data has three spatial dimensions
310+
or a time dimension and two spatial dimensions.
311+
use_box: Whether to derive the bounding box prompt from the mask.
312+
use_mask: Whether to use the mask itself as prompt.
313+
use_points: Whether to derive point prompts from the mask.
314+
multimask_output: Whether to return multiple or just a single mask.
315+
return_all: Whether to return the score and logits in addition to the mask.
316+
box_extension: Relative factor used to enlarge the bounding box prompt.
317+
318+
Returns:
319+
The binary segmentation mask.
320+
"""
274321
predictor, tile, mask, shape = _initialize_predictor(
275322
predictor, image_embeddings, i, mask, _mask_to_tile
276323
)
@@ -299,10 +346,30 @@ def segment_from_mask(
299346

300347

301348
def segment_from_box(
302-
predictor, box,
303-
image_embeddings=None, i=None, original_size=None,
304-
multimask_output=False, return_all=False,
349+
predictor: SamPredictor,
350+
box: np.ndarray,
351+
image_embeddings: Optional[util.ImageEmbeddings] = None,
352+
i: Optional[int] = None,
353+
original_size: Optional[tuple[int, ...]] = None,
354+
multimask_output: bool = False,
355+
return_all: bool = False,
305356
):
357+
"""Segmentation from a box prompt.
358+
359+
Args:
360+
predictor: The segment anything predictor.
361+
box: The box prompt.
362+
image_embeddings: Optional precomputed image embeddings.
363+
Has to be passed if the predictor is not yet initialized.
364+
i: Index for the image data. Required if the input data has three spatial dimensions
365+
or a time dimension and two spatial dimensions.
366+
original_size: The original image shape.
367+
multimask_output: Whether to return multiple or just a single mask.
368+
return_all: Whether to return the score and logits in addition to the mask.
369+
370+
Returns:
371+
The binary segmentation mask.
372+
"""
306373
predictor, tile, box, shape = _initialize_predictor(
307374
predictor, image_embeddings, i, box, _box_to_tile
308375
)
@@ -317,10 +384,34 @@ def segment_from_box(
317384

318385

319386
def segment_from_box_and_points(
320-
predictor, box, points, labels,
321-
image_embeddings=None, i=None, original_size=None,
322-
multimask_output=False, return_all=False,
387+
predictor: SamPredictor,
388+
box: np.ndarray,
389+
points: np.ndarray,
390+
labels: np.ndarray,
391+
image_embeddings: Optional[util.ImageEmbeddings] = None,
392+
i: Optional[int] = None,
393+
original_size: Optional[tuple[int, ...]] = None,
394+
multimask_output: bool = False,
395+
return_all: bool = False,
323396
):
397+
"""Segmentation from a box prompt and point prompts.
398+
399+
Args:
400+
predictor: The segment anything predictor.
401+
box: The box prompt.
402+
points: The point prompts, given in the image coordinate system.
403+
labels: The point labels, either positive or negative.
404+
image_embeddings: Optional precomputed image embeddings.
405+
Has to be passed if the predictor is not yet initialized.
406+
i: Index for the image data. Required if the input data has three spatial dimensions
407+
or a time dimension and two spatial dimensions.
408+
original_size: The original image shape.
409+
multimask_output: Whether to return multiple or just a single mask.
410+
return_all: Whether to return the score and logits in addition to the mask.
411+
412+
Returns:
413+
The binary segmentation mask.
414+
"""
324415
def box_and_points_to_tile(prompts, shape, tile_shape, halo):
325416
box, points, labels = prompts
326417
tile_id, tile, point_prompts = _points_to_tile((points, labels), shape, tile_shape, halo)

0 commit comments

Comments
 (0)