Skip to content

Commit 7c736f6

Browse files
Add doc strings and type annotations for util
1 parent f2b20e0 commit 7c736f6

File tree

1 file changed

+81
-43
lines changed

1 file changed

+81
-43
lines changed

micro_sam/util.py

Lines changed: 81 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import hashlib
22
import os
33
import warnings
4+
from collections.abc import Mapping
45
from shutil import copyfileobj
5-
from typing import Any
6+
from typing import Any, Optional
67

78
import numpy as np
89
import requests
@@ -86,20 +87,26 @@ def _get_checkpoint(model_type, checkpoint_path=None):
8687
return checkpoint_path
8788

8889

89-
def get_sam_model(device=None, model_type="vit_h", checkpoint_path=None, return_sam=False):
90+
def get_sam_model(
91+
device: Optional[str] = None,
92+
model_type: str = "vit_h",
93+
checkpoint_path: Optional[str] = None,
94+
return_sam: bool = False
95+
) -> SamPredictor:
9096
"""Get the SegmentAnything Predictor.
9197
9298
This function will download the required model checkpoint or load it from file if it
93-
was already downloaded. By default the models are downloaded to ~/.sam_models.
99+
was already downloaded. By default the models are downloaded to '~/.sam_models'.
94100
This location can be changed by setting the environment variable SAM_MODELS.
95101
96-
Arguments:
97-
device [str, torch.device] - the device for the model. If none is given will use GPU if available.
98-
(default: None)
99-
model_type [str] - the SegmentAnything model to use. (default: vit_h)
100-
checkpoint_path [str] - the path to the corresponding checkpoint if it is already present
101-
and not in the default model folder. (default: None)
102-
return_sam [bool] - return the sam model object as well as the predictor (default: False)
102+
Args:
103+
device: The device for the model. If none is given will use GPU if available.
104+
model_type: The SegmentAnything model to use.
105+
checkpoint_path: The path to the corresponding checkpoint if not in the default model folder.
106+
return_sam: Return the sam model object as well as the predictor.
107+
108+
Returns:
109+
The segment anything predictor.
103110
"""
104111
checkpoint = _get_checkpoint(model_type, checkpoint_path)
105112
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -321,43 +328,46 @@ def _precompute_3d(input_, predictor, save_path, lazy_loading, tile_shape=None,
321328
return image_embeddings
322329

323330

324-
def compute_data_signature(input_):
331+
def _compute_data_signature(input_):
325332
data_signature = hashlib.sha1(np.asarray(input_).tobytes()).hexdigest()
326333
return data_signature
327334

328335

329336
def precompute_image_embeddings(
330-
predictor, input_,
331-
save_path=None, lazy_loading=False,
332-
ndim=None, tile_shape=None, halo=None,
333-
wrong_file_callback=None,
334-
):
337+
predictor: SamPredictor,
338+
input_: np.ndarray,
339+
save_path: Optional[str] = None,
340+
lazy_loading: bool = False,
341+
ndim: Optional[int] = None,
342+
tile_shape: Optional[tuple[int]] = None,
343+
halo: Optional[tuple[int]] = None,
344+
wrong_file_callback: Optional[callable] = None,
345+
) -> ImageEmbeddings:
335346
"""Compute the image embeddings (output of the encoder) for the input.
336347
337-
If save_path is given the embeddings will be loaded/saved in a zarr container.
348+
If 'save_path' is given the embeddings will be loaded/saved in a zarr container.
338349
339-
Arguments:
340-
predictor - the SegmentAnything predictor
341-
input_ [np.ndarray] - the input. Can be 2D or 3D.
342-
save_path [str] - path to save the embeddings in a zarr container (default: None)
343-
lazy_loading [bool] - whether to load all embeddings into memory or return an
344-
object to load them on demand when required. This only has an effect if 'save_path'
345-
is given and if the input is 3D. (default: False)
346-
ndim [int] - the dimensionality of the data. If not given will be deduced from the input data. (default: None)
347-
tile_shape [tuple] - shape of tiles for tiled prediction.
348-
By default prediction is run without tiling. (default: None)
349-
halo [tuple] - additional overlap of the tiles for tiled prediction. (default: None)
350-
wrong_file_callback [callable] - function to call when an embedding file with wrong file signature
350+
Args:
351+
predictor: The SegmentAnything predictor.
352+
input_: The input data. Can be 2 or 3 dimensional, corresponding to an image, volume or timeseries.
353+
save_path: Path to save the embeddings in a zarr container.
354+
lazy_loading: Whether to load all embeddings into memory or return an
355+
object to load them on demand when required. This only has an effect if 'save_path' is given
356+
and if the input is 3 dimensional.
357+
ndim: The dimensionality of the data. If not given will be deduced from the input data.
358+
tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
359+
halo: Overlap of the tiles for tiled prediction.
360+
wrong_file_callback: Function to call when an embedding file with wrong file signature
351361
is passed. If none is given a wrong file signature will cause a warning.
352-
If passed, the callback should have the signature 'def callback(save_path): return str',
353-
where the return value is the (potentially updated) embedding save path (default: None)
362+
The callback must have the signature 'def callback(save_path: str) -> str',
363+
where the return value is the (potentially updated) embedding save path.
354364
"""
355365
ndim = input_.ndim if ndim is None else ndim
356366
if tile_shape is not None:
357367
assert save_path is not None, "Tiled prediction is only supported when the embeddings are saved to file."
358368

359369
if save_path is not None:
360-
data_signature = compute_data_signature(input_)
370+
data_signature = _compute_data_signature(input_)
361371

362372
f = zarr.open(save_path, "a")
363373
if "input_size" in f.attrs: # we have computed the embeddings already
@@ -388,15 +398,18 @@ def precompute_image_embeddings(
388398
return image_embeddings
389399

390400

391-
def set_precomputed(predictor, image_embeddings, i=None):
392-
"""Set the precomputed image embeddings.
401+
def set_precomputed(
402+
predictor: SamPredictor,
403+
image_embeddings: ImageEmbeddings,
404+
i: Optional[int] = None
405+
):
406+
"""Set the precomputed image embeddings for a predictor.
393407
394408
Arguments:
395-
predictor - the SegmentAnything predictor
396-
image_embeddings [dict] - the precomputed image embeddings.
397-
This object is returned by 'precomputed_image_embeddings'.
398-
i [int] - the index for the image embeddings for 3D data.
399-
Only needs to be passed for 3d data. (default: None)
409+
predictor: The SegmentAnything predictor.
410+
image_embeddings: The precomputed image embeddings computed by `precompute_image_embeddings`.
411+
i: Index for the image data. Required if the input data has three spatial dimensions
412+
or a time dimension and two spatial dimensions.
400413
"""
401414
device = "cuda" if torch.cuda.is_available() else "cpu"
402415
features = image_embeddings["features"]
@@ -420,8 +433,15 @@ def set_precomputed(predictor, image_embeddings, i=None):
420433
return predictor
421434

422435

423-
def compute_iou(mask1, mask2):
436+
def compute_iou(mask1: np.ndarray, mask2: np.ndarray) -> float:
424437
"""Compute the intersection over union of two masks.
438+
439+
Args:
440+
mask1: The first mask.
441+
mask2: The second mask.
442+
443+
Returns:
444+
The intersection over union of the two masks.
425445
"""
426446
overlap = np.logical_and(mask1 == 1, mask2 == 1).sum()
427447
union = np.logical_or(mask1 == 1, mask2 == 1).sum()
@@ -437,8 +457,10 @@ def get_bounding_boxes_and_centers(
437457
"""Returns the center coordinates of the foreground instances in the ground-truth.
438458
439459
Args:
440-
segmentation:
441-
mode:
460+
segmentation: The segmentation.
461+
mode: Determines the functionality used for computing the centers.
462+
If 'v', the object's eccentricity centers computed by vigra are used.
463+
If 'p' the object's centroids computed by skimage are used.
442464
443465
Returns:
444466
A dictionary that maps object ids to the corresponding centroid.
@@ -460,7 +482,23 @@ def get_bounding_boxes_and_centers(
460482
return center_coordinates, bbox_coordinates
461483

462484

463-
def load_image_data(path, ndim, key=None, lazy_loading=False):
485+
def load_image_data(
486+
path: str,
487+
ndim: int,
488+
key: Optional[str] = None,
489+
lazy_loading: bool = False
490+
) -> np.ndarray:
491+
"""Helper function to load image data from file.
492+
493+
Args:
494+
path: The filepath to the image data.
495+
ndim: The data dimensionality.
496+
key: The internal filepath for complex data formats like hdf5.
497+
lazy_loading: Whether to lazily load data. Only supported for n5 and zarr data.
498+
499+
Returns:
500+
The image data.
501+
"""
464502
if key is None:
465503
image_data = imageio.imread(path) if ndim == 2 else imageio.volread(path)
466504
else:

0 commit comments

Comments
 (0)