136 changes: 118 additions & 18 deletions micro_sam/automatic_segmentation.py
@@ -1,8 +1,9 @@
import os
from functools import partial
from glob import glob
from tqdm import tqdm
from pathlib import Path
from typing import Optional, Union, Tuple
from typing import Dict, List, Optional, Union, Tuple

import numpy as np
import imageio.v3 as imageio
@@ -14,7 +15,7 @@
get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
)
from .multi_dimensional_segmentation import automatic_3d_segmentation
from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation


def get_predictor_and_segmenter(
@@ -71,12 +72,94 @@ def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str
return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))


def automatic_tracking(
predictor: util.SamPredictor,
segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
input_path: Union[Union[os.PathLike, str], np.ndarray],
output_path: Optional[Union[os.PathLike, str]] = None,
embedding_path: Optional[Union[os.PathLike, str, util.ImageEmbeddings]] = None,
key: Optional[str] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
verbose: bool = True,
return_embeddings: bool = False,
annotate: bool = False,
batch_size: int = 1,
**generate_kwargs
) -> Tuple[np.ndarray, List[Dict]]:
"""Run automatic tracking for the input timeseries.

Args:
predictor: The Segment Anything model.
segmenter: The automatic instance segmentation class.
input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
or a container file (e.g. hdf5 or zarr).
output_path: The output path where the tracking results will be saved.
embedding_path: The path where the embeddings are cached already / will be saved.
This argument also accepts already deserialized embeddings.
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for the latter case.
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
verbose: Verbosity flag.
return_embeddings: Whether to return the precomputed image embeddings.
annotate: Whether to activate the annotator for continuing the annotation process.
batch_size: The batch size to compute image embeddings over tiles / z-planes.
By default, embeddings are computed sequentially, i.e. one tile / z-plane at a time.
generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.

Returns:
The tracked segmentation, with one frame per timepoint.
The lineage information of the tracks.
The image embeddings, if return_embeddings is set to True.
"""
if output_path is not None:
# TODO implement saving tracking results in CTC format and use it to save the result here.
raise NotImplementedError("Saving the tracking result to file is currently not supported.")

# Load the input image file.
if isinstance(input_path, np.ndarray):
image_data = input_path
else:
image_data = util.load_image_data(input_path, key)

# We only perform the additional 'output_mode' post-processing for AMG.
# For AIS, we switch it off by setting 'output_mode' to None.
if isinstance(segmenter, InstanceSegmentationWithDecoder):
generate_kwargs["output_mode"] = None

if image_data.ndim != 3 and not (image_data.ndim == 4 and image_data.shape[-1] == 3):
raise ValueError(f"The input does not match the shape expectations for a timeseries: {image_data.shape}")

gap_closing = generate_kwargs.pop("gap_closing", None)
min_time_extent = generate_kwargs.pop("min_time_extent", None)
segmentation, lineage, image_embeddings = automatic_tracking_implementation(
image_data,
predictor,
segmenter,
embedding_path=embedding_path,
gap_closing=gap_closing,
min_time_extent=min_time_extent,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
batch_size=batch_size,
return_embeddings=True,
**generate_kwargs,
)

if annotate:
# TODO We need to support initialization of the tracking annotator with the tracking result for this.
raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")

if return_embeddings:
return segmentation, lineage, image_embeddings
else:
return segmentation, lineage

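A minimal usage sketch of the new function (the model alias "vit_b_lm" and the file "timeseries.tif" are hypothetical placeholders, not part of this change):

import imageio.v3 as imageio
from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
timeseries = imageio.imread("timeseries.tif")  # Expected shape: (t, y, x).

# Run per-frame segmentation and track the objects across frames.
segmentation, lineage = automatic_tracking(
    predictor=predictor,
    segmenter=segmenter,
    input_path=timeseries,
)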

def automatic_instance_segmentation(
predictor: util.SamPredictor,
segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
input_path: Union[Union[os.PathLike, str], np.ndarray],
output_path: Optional[Union[os.PathLike, str]] = None,
embedding_path: Optional[Union[os.PathLike, str]] = None,
embedding_path: Optional[Union[os.PathLike, str, util.ImageEmbeddings]] = None,
key: Optional[str] = None,
ndim: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
@@ -96,6 +179,7 @@ def automatic_instance_segmentation(
or a container file (e.g. hdf5 or zarr).
output_path: The output path where the instance segmentations will be saved.
embedding_path: The path where the embeddings are cached already / will be saved.
This argument also accepts already deserialized embeddings.
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
or to load several images as a 3d volume. Provide a glob pattern, e.g. "*.tif", for the latter case.
ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
@@ -137,16 +221,19 @@ def automatic_instance_segmentation(
raise ValueError(f"The input does not match the shape expectations for 2d inputs: {image_data.shape}")

# Precompute the image embeddings.
image_embeddings = util.precompute_image_embeddings(
predictor=predictor,
input_=image_data,
save_path=embedding_path,
ndim=ndim,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
batch_size=batch_size,
)
if embedding_path is None or isinstance(embedding_path, (str, os.PathLike)):
image_embeddings = util.precompute_image_embeddings(
predictor=predictor,
input_=image_data,
save_path=embedding_path,
ndim=ndim,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
batch_size=batch_size,
)
else:  # Otherwise, the deserialized embeddings were passed.
image_embeddings = embedding_path
initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)

# If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
@@ -160,14 +247,14 @@
masks = segmenter.generate(**generate_kwargs)

if isinstance(masks, list):
# whether the predictions from 'generate' are list of dict,
# The predictions from 'generate' are a list of dicts,
# which contain additional information required for post-processing, e.g. the area per object.
if len(masks) == 0:
instances = np.zeros(image_data.shape[:2], dtype="uint32")
else:
instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
else:
# if (raw) predictions provided, store them as it is w/o further post-processing.
# If (raw) predictions are provided, store them as-is without further post-processing.
instances = masks

else:
@@ -258,7 +345,13 @@ def main():
available_models = list(util.get_model_names())
available_models = ", ".join(available_models)

parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
parser = argparse.ArgumentParser(
description="Run automatic segmentation for an image using either automatic instance segmentation (AIS) \n"
"or automatic mask generation (AMG). In addition to the arguments explained below,\n"
"you can also passed additional arguments for these two segmentation modes:\n"
"For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`." # noqa
"For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`." # noqa
)
parser.add_argument(
"-i", "--input_path", required=True, type=str, nargs="+",
help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
@@ -314,6 +407,10 @@
help="The batch size for computing image embeddings over tiles or z-plane. "
"By default, computes the image embeddings for one tile / z-plane at a time."
)
parser.add_argument(
"--tracking", action="store_true", help="Run tracking instead of instance segmentation. "
"Only supported for timeseries inputs.."
)
parser.add_argument(
"-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
)
@@ -368,6 +465,10 @@ def _convert_argval(value):
embedding_path = args.embedding_path
has_one_input = len(input_paths) == 1

instance_seg_function = automatic_tracking if args.tracking else partial(
automatic_instance_segmentation, ndim=args.ndim
)

# Run automatic segmentation per image.
for path in tqdm(input_paths, desc="Run automatic segmentation"):
if has_one_input: # if we have one image only.
@@ -393,14 +494,13 @@ def _convert_argval(value):
os.makedirs(output_path, exist_ok=True)
_output_fpath = os.path.join(output_path, Path(os.path.basename(path)).with_suffix(".tif"))

automatic_instance_segmentation(
instance_seg_function(
predictor=predictor,
segmenter=segmenter,
input_path=path,
output_path=_output_fpath,
embedding_path=_embedding_fpath,
key=args.key,
ndim=args.ndim,
tile_shape=args.tile_shape,
halo=args.halo,
annotate=args.annotate,
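With the updated embedding_path annotation, automatic_instance_segmentation also accepts embeddings that were already deserialized, instead of a cache path. A sketch of the intended reuse pattern, assuming a model with an additional segmentation decoder (so AIS is used); the file name and threshold values are placeholders:

import imageio.v3 as imageio
from micro_sam import util
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm")
image = imageio.imread("image.tif")  # Hypothetical 2d input image.

# Compute the embeddings once ...
image_embeddings = util.precompute_image_embeddings(predictor=predictor, input_=image, ndim=2)

# ... and reuse them for several runs with different post-processing settings.
for threshold in (0.3, 0.5):
    instances = automatic_instance_segmentation(
        predictor=predictor,
        segmenter=segmenter,
        input_path=image,
        embedding_path=image_embeddings,  # Deserialized embeddings instead of a file path.
        center_distance_threshold=threshold,  # AIS post-processing parameter.
    )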
45 changes: 29 additions & 16 deletions micro_sam/multi_dimensional_segmentation.py
@@ -375,16 +375,20 @@ def _segment_slices(
assert data.ndim == 3

min_object_size = kwargs.pop("min_object_size", 0)
image_embeddings = util.precompute_image_embeddings(
predictor=predictor,
input_=data,
save_path=embedding_path,
ndim=3,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
batch_size=batch_size,
)
# Check if the embeddings still have to be computed.
if embedding_path is None or isinstance(embedding_path, (str, os.PathLike)):
image_embeddings = util.precompute_image_embeddings(
predictor=predictor,
input_=data,
save_path=embedding_path,
ndim=3,
tile_shape=tile_shape,
halo=halo,
verbose=verbose,
batch_size=batch_size,
)
else: # Otherwise the deserialized embeddings were passed.
image_embeddings = embedding_path

offset = 0
segmentation = np.zeros(data.shape, dtype="uint32")
@@ -417,7 +421,7 @@ def automatic_3d_segmentation(
volume: np.ndarray,
predictor: SamPredictor,
segmentor: AMGBase,
embedding_path: Optional[Union[str, os.PathLike]] = None,
embedding_path: Optional[Union[str, os.PathLike, util.ImageEmbeddings]] = None,
with_background: bool = True,
gap_closing: Optional[int] = None,
min_z_extent: Optional[int] = None,
@@ -438,6 +442,7 @@
predictor: The SAM model.
segmentor: The instance segmentation class.
embedding_path: The path to save pre-computed embeddings.
This argument also accepts already deserialized embeddings.
with_background: Whether the segmentation has background.
gap_closing: If given, gaps in the segmentation are closed with a binary closing
operation. The value is used to determine the number of iterations for the closing.
@@ -622,16 +627,18 @@ def track_across_frames(
return segmentation, lineage


def automatic_tracking(
def automatic_tracking_implementation(
timeseries: np.ndarray,
predictor: SamPredictor,
segmentor: AMGBase,
embedding_path: Optional[Union[str, os.PathLike]] = None,
embedding_path: Optional[Union[str, os.PathLike, util.ImageEmbeddings]] = None,
gap_closing: Optional[int] = None,
min_time_extent: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
verbose: bool = True,
return_embeddings: bool = False,
batch_size: int = 1,
**kwargs,
) -> Tuple[np.ndarray, List[Dict]]:
"""Automatically track objects in a timesries based on per-frame automatic segmentation.
@@ -644,12 +651,15 @@
predictor: The SAM model.
segmentor: The instance segmentation class.
embedding_path: The path to save pre-computed embeddings.
This argument also accepts already deserialized embeddings.
gap_closing: If given, gaps in the segmentation are closed with a binary closing
operation. The value is used to determine the number of iterations for the closing.
min_time_extent: Require a minimal extent in time for the tracked objects.
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction.
verbose: Verbosity flag.
return_embeddings: Whether to return the precomputed image embeddings.
batch_size: The batch size to compute image embeddings over planes.
kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.

Returns:
@@ -662,15 +672,18 @@
raise RuntimeError(
"Automatic tracking requires trackastra. You can install it via 'pip install trackastra'."
)
segmentation, _ = _segment_slices(
segmentation, image_embeddings = _segment_slices(
timeseries, predictor, segmentor, embedding_path, verbose,
tile_shape=tile_shape, halo=halo,
tile_shape=tile_shape, halo=halo, batch_size=batch_size,
**kwargs,
)
segmentation, lineage = track_across_frames(
timeseries, segmentation, gap_closing=gap_closing, min_time_extent=min_time_extent, verbose=verbose,
)
return segmentation, lineage
if return_embeddings:
return segmentation, lineage, image_embeddings
else:
return segmentation, lineage

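Since the implementation can now hand back the embeddings it computed, a second tracking run with different parameters can skip the embedding computation. A sketch, with timeseries, predictor and segmentor set up as in the earlier examples:

# First run: compute the embeddings and return them alongside the tracking result.
segmentation, lineage, image_embeddings = automatic_tracking_implementation(
    timeseries, predictor, segmentor, return_embeddings=True,
)

# Second run with a stricter track-length filter, reusing the embeddings
# by passing them back in through 'embedding_path'.
segmentation2, lineage2 = automatic_tracking_implementation(
    timeseries, predictor, segmentor, embedding_path=image_embeddings, min_time_extent=3,
)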

def get_napari_track_data(