11import warnings
2+ from typing import Optional
23
34import numpy as np
45from nifty .tools import blocking
56from skimage .feature import peak_local_max
67from skimage .filters import gaussian
78from scipy .ndimage import distance_transform_edt
89
10+ from segment_anything .predictor import SamPredictor
911from segment_anything .utils .transforms import ResizeLongestSide
1012from . import util
1113
@@ -229,7 +231,7 @@ def _tile_to_full_mask(mask, shape, tile, return_all, multimask_output):
229231
230232
231233#
232- # functions for prompted:
234+ # functions for prompted segmentation :
233235# - segment_from_points: use point prompts as input
234236# - segment_from_mask: use binary mask as input, support conversion to mask, box and point prompts
235237# - segment_from_box: use box prompt as input
@@ -238,10 +240,30 @@ def _tile_to_full_mask(mask, shape, tile, return_all, multimask_output):
238240
239241
240242def segment_from_points (
241- predictor , points , labels ,
242- image_embeddings = None ,
243- i = None , multimask_output = False , return_all = False ,
243+ predictor : SamPredictor ,
244+ points : np .ndarray ,
245+ labels : np .ndarray ,
246+ image_embeddings : Optional [util .ImageEmbeddings ] = None ,
247+ i : Optional [int ] = None ,
248+ multimask_output : bool = False ,
249+ return_all : bool = False ,
244250):
251+ """Segmentation from point prompts.
252+
253+ Args:
254+ predictor: The segment anything predictor.
255+ points: The point prompts given in the image coordinate system.
256+ labels: The labels (positive or negative) associated with the points.
257+ image_embeddings: Optional precomputed image embeddings.
258+ Has to be passed if the predictor is not yet initialized.
259+ i: Index for the image data. Required if the input data has three spatial dimensions
260+ or a time dimension and two spatial dimensions.
261+ multimask_output: Whether to return multiple or just a single mask.
262+ return_all: Whether to return the score and logits in addition to the mask.
263+
264+ Returns:
265+ The binary segmentation mask.
266+ """
245267 predictor , tile , prompts , shape = _initialize_predictor (
246268 predictor , image_embeddings , i , (points , labels ), _points_to_tile
247269 )
@@ -264,13 +286,38 @@ def segment_from_points(
264286
265287# use original_size if the mask is downscaled w.r.t. the original image size
266288def segment_from_mask (
267- predictor , mask ,
268- image_embeddings = None , i = None ,
269- use_box = True , use_mask = True , use_points = False ,
270- original_size = None , multimask_output = False ,
271- return_all = False , return_logits = False ,
272- box_extension = 0 ,
289+ predictor : SamPredictor ,
290+ mask : np .ndarray ,
291+ image_embeddings : Optional [util .ImageEmbeddings ] = None ,
292+ i : Optional [int ] = None ,
293+ use_box : bool = True ,
294+ use_mask : bool = True ,
295+ use_points : bool = False ,
296+ original_size : Optional [tuple [int , ...]] = None ,
297+ multimask_output : bool = False ,
298+ return_all : bool = False ,
299+ return_logits : bool = False ,
300+ box_extension : float = 0.0 ,
273301):
302+ """Segmentation from a mask prompt.
303+
304+ Args:
305+ predictor: The segment anything predictor.
306+ mask: The mask used to derive prompts.
307+ image_embeddings: Optional precomputed image embeddings.
308+ Has to be passed if the predictor is not yet initialized.
309+ i: Index for the image data. Required if the input data has three spatial dimensions
310+ or a time dimension and two spatial dimensions.
311+ use_box: Whether to derive the bounding box prompt from the mask.
312+ use_mask: Whether to use the mask itself as prompt.
313+ use_points: Wehter to derive point prompts from the mask.
314+ multimask_output: Whether to return multiple or just a single mask.
315+ return_all: Whether to return the score and logits in addition to the mask.
316+ box_extension: Relative factor used to enlarge the bounding box prompt.
317+
318+ Returns:
319+ The binary segmentation mask.
320+ """
274321 predictor , tile , mask , shape = _initialize_predictor (
275322 predictor , image_embeddings , i , mask , _mask_to_tile
276323 )
@@ -299,10 +346,30 @@ def segment_from_mask(
299346
300347
301348def segment_from_box (
302- predictor , box ,
303- image_embeddings = None , i = None , original_size = None ,
304- multimask_output = False , return_all = False ,
349+ predictor : SamPredictor ,
350+ box : np .ndarray ,
351+ image_embeddings : Optional [util .ImageEmbeddings ] = None ,
352+ i : Optional [int ] = None ,
353+ original_size : Optional [tuple [int , ...]] = None ,
354+ multimask_output : bool = False ,
355+ return_all : bool = False ,
305356):
357+ """Segmentation from a box prompt.
358+
359+ Args:
360+ predictor: The segment anything predictor.
361+ box: The box prompt.
362+ image_embeddings: Optional precomputed image embeddings.
363+ Has to be passed if the predictor is not yet initialized.
364+ i: Index for the image data. Required if the input data has three spatial dimensions
365+ or a time dimension and two spatial dimensions.
366+ original_size: The original image shape.
367+ multimask_output: Whether to return multiple or just a single mask.
368+ return_all: Whether to return the score and logits in addition to the mask.
369+
370+ Returns:
371+ The binary segmentation mask.
372+ """
306373 predictor , tile , box , shape = _initialize_predictor (
307374 predictor , image_embeddings , i , box , _box_to_tile
308375 )
@@ -317,10 +384,34 @@ def segment_from_box(
317384
318385
319386def segment_from_box_and_points (
320- predictor , box , points , labels ,
321- image_embeddings = None , i = None , original_size = None ,
322- multimask_output = False , return_all = False ,
387+ predictor : SamPredictor ,
388+ box : np .ndarray ,
389+ points : np .ndarray ,
390+ labels : np .ndarray ,
391+ image_embeddings : Optional [util .ImageEmbeddings ] = None ,
392+ i : Optional [int ] = None ,
393+ original_size : Optional [tuple [int , ...]] = None ,
394+ multimask_output : bool = False ,
395+ return_all : bool = False ,
323396):
397+ """Segmentation from a box prompt and point prompts.
398+
399+ Args:
400+ predictor: The segment anything predictor.
401+ box: The box prompt.
402+ points: The point prompts, given in the image coordinates system.
403+ labels: The point labels, either positive or negative.
404+ image_embeddings: Optional precomputed image embeddings.
405+ Has to be passed if the predictor is not yet initialized.
406+ i: Index for the image data. Required if the input data has three spatial dimensions
407+ or a time dimension and two spatial dimensions.
408+ original_size: The original image shape.
409+ multimask_output: Whether to return multiple or just a single mask.
410+ return_all: Whether to return the score and logits in addition to the mask.
411+
412+ Returns:
413+ The binary segmentation mask.
414+ """
324415 def box_and_points_to_tile (prompts , shape , tile_shape , halo ):
325416 box , points , labels = prompts
326417 tile_id , tile , point_prompts = _points_to_tile ((points , labels ), shape , tile_shape , halo )
0 commit comments