11import hashlib
22import os
33import warnings
4+ from collections .abc import Mapping
45from shutil import copyfileobj
5- from typing import Any
6+ from typing import Any , Optional
67
78import numpy as np
89import requests
@@ -86,20 +87,26 @@ def _get_checkpoint(model_type, checkpoint_path=None):
8687 return checkpoint_path
8788
8889
89- def get_sam_model (device = None , model_type = "vit_h" , checkpoint_path = None , return_sam = False ):
90+ def get_sam_model (
91+ device : Optional [str ] = None ,
92+ model_type : str = "vit_h" ,
93+ checkpoint_path : Optional [str ] = None ,
94+ return_sam : bool = False
95+ ) -> SamPredictor :
9096 """Get the SegmentAnything Predictor.
9197
9298 This function will download the required model checkpoint or load it from file if it
93- was already downloaded. By default the models are downloaded to ~/.sam_models.
99+ was already downloaded. By default the models are downloaded to ' ~/.sam_models' .
94100 This location can be changed by setting the environment variable SAM_MODELS.
95101
96- Arguments:
97- device [str, torch.device] - the device for the model. If none is given will use GPU if available.
98- (default: None)
99- model_type [str] - the SegmentAnything model to use. (default: vit_h)
100- checkpoint_path [str] - the path to the corresponding checkpoint if it is already present
101- and not in the default model folder. (default: None)
102- return_sam [bool] - return the sam model object as well as the predictor (default: False)
102+ Args:
103+ device: The device for the model. If none is given will use GPU if available.
104+ model_type: The SegmentAnything model to use.
105+ checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
106+ return_sam: Return the sam model object as well as the predictor.
107+
108+ Returns:
109+ The segment anything predictor.
103110 """
104111 checkpoint = _get_checkpoint (model_type , checkpoint_path )
105112 device = "cuda" if torch .cuda .is_available () else "cpu"
@@ -321,43 +328,46 @@ def _precompute_3d(input_, predictor, save_path, lazy_loading, tile_shape=None,
321328 return image_embeddings
322329
323330
324- def compute_data_signature (input_ ):
331+ def _compute_data_signature (input_ ):
325332 data_signature = hashlib .sha1 (np .asarray (input_ ).tobytes ()).hexdigest ()
326333 return data_signature
327334
328335
329336def precompute_image_embeddings (
330- predictor , input_ ,
331- save_path = None , lazy_loading = False ,
332- ndim = None , tile_shape = None , halo = None ,
333- wrong_file_callback = None ,
334- ):
337+ predictor : SamPredictor ,
338+ input_ : np .ndarray ,
339+ save_path : Optional [str ] = None ,
340+ lazy_loading : bool = False ,
341+ ndim : Optional [int ] = None ,
342+ tile_shape : Optional [tuple [int ]] = None ,
343+ halo : Optional [tuple [int ]] = None ,
344+ wrong_file_callback : Optional [callable ] = None ,
345+ ) -> ImageEmbeddings :
335346 """Compute the image embeddings (output of the encoder) for the input.
336347
337- If save_path is given the embeddings will be loaded/saved in a zarr container.
348+ If ' save_path' is given the embeddings will be loaded/saved in a zarr container.
338349
339- Arguments:
340- predictor - the SegmentAnything predictor
341- input_ [np.ndarray] - the input. Can be 2D or 3D.
342- save_path [str] - path to save the embeddings in a zarr container (default: None)
343- lazy_loading [bool] - whether to load all embeddings into memory or return an
344- object to load them on demand when required. This only has an effect if 'save_path'
345- is given and if the input is 3D. (default: False)
346- ndim [int] - the dimensionality of the data. If not given will be deduced from the input data. (default: None)
347- tile_shape [tuple] - shape of tiles for tiled prediction.
348- By default prediction is run without tiling. (default: None)
349- halo [tuple] - additional overlap of the tiles for tiled prediction. (default: None)
350- wrong_file_callback [callable] - function to call when an embedding file with wrong file signature
350+ Args:
351+ predictor: The SegmentAnything predictor
352+ input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
353+ save_path: Path to save the embeddings in a zarr container.
354+ lazy_loading: Whether to load all embeddings into memory or return an
355+ object to load them on demand when required. This only has an effect if 'save_path' is given
356+ and if the input is 3 dimensional.
357+ ndim: The dimensionality of the data. If not given will be deduced from the input data.
358+ tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
359+ halo: Overlap of the tiles for tiled prediction.
360+ wrong_file_callback: Function to call when an embedding file with wrong file signature
351361 is passed. If none is given a wrong file signature will cause a warning.
352- If passed, the callback should have the signature 'def callback(save_path): return str',
353- where the return value is the (potentially updated) embedding save path (default: None)
362+ The callback must have the signature 'def callback(save_path: str) -> str',
363+ where the return value is the (potentially updated) embedding save path.
354364 """
355365 ndim = input_ .ndim if ndim is None else ndim
356366 if tile_shape is not None :
357367 assert save_path is not None , "Tiled prediction is only supported when the embeddings are saved to file."
358368
359369 if save_path is not None :
360- data_signature = compute_data_signature (input_ )
370+ data_signature = _compute_data_signature (input_ )
361371
362372 f = zarr .open (save_path , "a" )
363373 if "input_size" in f .attrs : # we have computed the embeddings already
@@ -388,15 +398,18 @@ def precompute_image_embeddings(
388398 return image_embeddings
389399
390400
391- def set_precomputed (predictor , image_embeddings , i = None ):
392- """Set the precomputed image embeddings.
401+ def set_precomputed (
402+ predictor : SamPredictor ,
403+ image_embeddings : ImageEmbeddings ,
404+ i : Optional [int ] = None
405+ ):
406+ """Set the precomputed image embeddings for a predictor.
393407
394408 Arguments:
395- predictor - the SegmentAnything predictor
396- image_embeddings [dict] - the precomputed image embeddings.
397- This object is returned by 'precomputed_image_embeddings'.
398- i [int] - the index for the image embeddings for 3D data.
399- Only needs to be passed for 3d data. (default: None)
409+ predictor: The SegmentAnything predictor.
410+ image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
411+ i: Index for the image data. Required if the input data has three spatial dimensions
412+ or a time dimension and two spatial dimensions.
400413 """
401414 device = "cuda" if torch .cuda .is_available () else "cpu"
402415 features = image_embeddings ["features" ]
@@ -420,8 +433,15 @@ def set_precomputed(predictor, image_embeddings, i=None):
420433 return predictor
421434
422435
423- def compute_iou (mask1 , mask2 ) :
436+ def compute_iou (mask1 : np . ndarray , mask2 : np . ndarray ) -> float :
424437 """Compute the intersection over union of two masks.
438+
439+ Args:
440+ mask1: The first mask.
441+ mask2: The second mask.
442+
443+ Returns:
444+ The intersection over union of the two masks.
425445 """
426446 overlap = np .logical_and (mask1 == 1 , mask2 == 1 ).sum ()
427447 union = np .logical_or (mask1 == 1 , mask2 == 1 ).sum ()
@@ -437,8 +457,10 @@ def get_bounding_boxes_and_centers(
437457 """Returns the center coordinates of the foreground instances in the ground-truth.
438458
439459 Args:
440- segmentation:
441- mode:
460+ segmentation: The segmentation.
461+ mode: Determines the functionality used for computing the centers.
462+ If 'v', the object's eccentricity centers computed by vigra are used.
463+ If 'p' the object's centroids computed by skimage are used.
442464
443465 Returns:
444466 A dictionary that maps object ids to the corresponding centroid.
@@ -460,7 +482,23 @@ def get_bounding_boxes_and_centers(
460482 return center_coordinates , bbox_coordinates
461483
462484
463- def load_image_data (path , ndim , key = None , lazy_loading = False ):
485+ def load_image_data (
486+ path : str ,
487+ ndim : int ,
488+ key : Optional [str ] = None ,
489+ lazy_loading : bool = False
490+ ) -> np .ndarray :
491+ """Helper function to load image data from file.
492+
493+ Args:
494+ path: The filepath to the image data.
495+ ndim: The data dimensionality.
496+ key: The internal filepath for complex data formats like hdf5.
497+ lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
498+
499+ Returns:
500+ The image data.
501+ """
464502 if key is None :
465503 image_data = imageio .imread (path ) if ndim == 2 else imageio .volread (path )
466504 else :
0 commit comments