Skip to content

Commit 83de9c5

Browse files
authored
Update automatic tracking for storing tracks in CTC format (#1045)
Enable saving automatic tracking results and update automatic segmentation doc
1 parent d46d769 commit 83de9c5

File tree

6 files changed

+153
-55
lines changed

6 files changed

+153
-55
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ iterative_prompting_results/
195195
*.tif
196196
*.zip
197197
*MACOSX
198+
hela_ctc
198199
clf-test-data
199200

200201
# Related to i2k workshop folders.

examples/README.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
# Examples
22

33
Examples for using the `micro_sam` annotation tools:
4-
- `annotator_2d.py`: run the interactive 2d annotation tool.
5-
- `annotator_3d.py`: run the interactive 3d annotation tool.
6-
- `annotator_tracking.py`: run the interactive tracking annotation tool.
7-
- `image_series_annotator.py`: run the annotation tool for a series of images.
4+
- `annotator_2d.py`: Run the interactive 2d annotation tool.
5+
- `annotator_3d.py`: Run the interactive 3d annotation tool.
6+
- `annotator_tracking.py`: Run the interactive tracking annotation tool.
7+
- `image_series_annotator.py`: Run the annotation tool for a series of images.
8+
9+
And python scripts for automatic segmentation and tracking:
10+
- `automatic_segmentation.py`: Run automatic segmentation on 2d images.
11+
- `automatic_tracking.py`: Run automatic tracking on 2d timeseries images.
812

913
We provide Jupyter Notebooks for using automatic segmentation and finetuning on some example data in the [notebooks](../notebooks/) folder.
1014

examples/automatic_tracking.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import os
2+
3+
from elf.io import open_file
4+
5+
from micro_sam.util import get_cache_directory
6+
from micro_sam.sample_data import fetch_tracking_example_data
7+
from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter
8+
9+
10+
DATA_CACHE = os.path.join(get_cache_directory(), "sample_data")
11+
EMBEDDING_CACHE = os.path.join(get_cache_directory(), "embeddings")
12+
os.makedirs(EMBEDDING_CACHE, exist_ok=True)
13+
14+
15+
def example_automatic_tracking(use_finetuned_model):
16+
"""Run automatic tracking for data from the cell tracking challenge.
17+
"""
18+
# Download the example tracking data.
19+
example_data = fetch_tracking_example_data(DATA_CACHE)
20+
21+
# Load the example data (load the sequence of tif files as timeseries)
22+
with open_file(example_data, mode="r") as f:
23+
timeseries = f["*.tif"]
24+
25+
if use_finetuned_model:
26+
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc-vit_b_lm.zarr")
27+
model_type = "vit_b_lm"
28+
else:
29+
embedding_path = os.path.join(EMBEDDING_CACHE, "embeddings-ctc.zarr")
30+
model_type = "vit_h"
31+
32+
predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, amg=False)
33+
34+
masks_tracked, _ = automatic_tracking(
35+
predictor=predictor,
36+
segmenter=segmenter,
37+
input_path=timeseries[:],
38+
output_path="./hela_ctc",
39+
embedding_path=embedding_path,
40+
)
41+
42+
import napari
43+
v = napari.Viewer()
44+
v.add_image(timeseries)
45+
v.add_labels(masks_tracked)
46+
napari.run()
47+
48+
49+
def main():
50+
# Whether to use the fine-tuned SAM model.
51+
use_finetuned_model = True
52+
example_automatic_tracking(use_finetuned_model)
53+
54+
55+
if __name__ == "__main__":
56+
main()

micro_sam/automatic_segmentation.py

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import os
2-
from functools import partial
2+
import warnings
33
from glob import glob
44
from tqdm import tqdm
55
from pathlib import Path
6+
from functools import partial
67
from typing import Dict, List, Optional, Union, Tuple
78

89
import numpy as np
@@ -95,7 +96,7 @@ def automatic_tracking(
9596
segmenter: The automatic instance segmentation class.
9697
input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
9798
or a container file (e.g. hdf5 or zarr).
98-
output_path: The output path where the instance segmentations will be saved.
99+
output_path: The folder where the tracking outputs will be saved in CTC format.
99100
embedding_path: The path where the embeddings are cached already / will be saved.
100101
key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
101102
or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
@@ -111,11 +112,9 @@ def automatic_tracking(
111112
generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
112113
113114
Returns:
115+
The tracking result as a timeseries, where each object is labeled by its track id.
116+
The lineages representing cell divisions, stored as a dictionary.
114117
"""
115-
if output_path is not None:
116-
# TODO implement saving tracking results in CTC format and use it to save the result here.
117-
raise NotImplementedError("Saving the tracking result to file is currently not supported.")
118-
119118
# Load the input image file.
120119
if isinstance(input_path, np.ndarray):
121120
image_data = input_path
@@ -142,7 +141,8 @@ def automatic_tracking(
142141
halo=halo,
143142
verbose=verbose,
144143
batch_size=batch_size,
145-
return_image_embeddings=True,
144+
return_embeddings=True,
145+
output_folder=output_path,
146146
**generate_kwargs,
147147
)
148148

@@ -335,12 +335,6 @@ def _get_inputs_from_paths(paths, pattern):
335335
return fpaths
336336

337337

338-
def _has_extension(fpath: Union[os.PathLike, str]) -> bool:
339-
"Returns whether the provided path has an extension or not."
340-
breakpoint()
341-
return bool(os.path.splitext(fpath)[1])
342-
343-
344338
def main():
345339
"""@private"""
346340
import argparse
@@ -349,23 +343,34 @@ def main():
349343
available_models = ", ".join(available_models)
350344

351345
parser = argparse.ArgumentParser(
352-
description="Run automatic segmentation for an image using either automatic instance segmentation (AIS) \n"
353-
"or automatic mask generation (AMG). In addition to the arguments explained below,\n"
346+
description="Run automatic segmentation or tracking for 2d, 3d or timeseries data.\n"
347+
"Either a single input file or multiple input files are supported. You can specify multiple files "
348+
"by either providing multiple filepaths to the '--i/--input_paths' argument, or by providing an argument "
349+
"to '--pattern' to use a wildcard pattern ('*') for selecting multiple files.\n"
350+
"NOTE: for automatic 3d segmentation or tracking the data has to be stored as volume / timeseries, "
351+
"stacking individual tif images is not supported.\n"
352+
"Segmentation is performed using one of the two modes supported by micro_sam: \n"
353+
"automatic instance segmentation (AIS) or automatic mask generation (AMG).\n"
354+
"In addition to the options listed below, "
354355
"you can also passed additional arguments for these two segmentation modes:\n"
355356
"For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`." # noqa
356357
"For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`." # noqa
357358
)
358359
parser.add_argument(
359360
"-i", "--input_path", required=True, type=str, nargs="+",
360-
help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
361+
help="The filepath(s) to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) " # noqa
361362
"or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter."
362363
)
363364
parser.add_argument(
364365
"-o", "--output_path", required=True, type=str,
365-
help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file."
366+
help="The filepath to store the results. If multiple inputs are provied, "
367+
"this should be a folder. For a single image, you should provide the path to a tif file for the output segmentation." # noqa
368+
"NOTE: Segmentation results are stored as tif files, tracking results in the CTC fil format ."
366369
)
367370
parser.add_argument(
368-
"-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved."
371+
"-e", "--embedding_path", default=None, type=str,
372+
help="An optional path where the embeddings will be saved. If multiple inputs are provided, "
373+
"this should be a folder. Otherwise you can store embeddings in single zarr file."
369374
)
370375
parser.add_argument(
371376
"--pattern", type=str, help="Pattern / wildcard for selecting files in a folder. To select all files use '*'."
@@ -411,8 +416,8 @@ def main():
411416
"By default, computes the image embeddings for one tile / z-plane at a time."
412417
)
413418
parser.add_argument(
414-
"--tracking", action="store_true", help="Run tracking instead of instance segmentation. "
415-
"Only supported for timeseries inputs.."
419+
"--tracking", action="store_true", help="Run automatic tracking instead of instance segmentation. "
420+
"NOTE: It is only supported for timeseries inputs."
416421
)
417422
parser.add_argument(
418423
"-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
@@ -473,34 +478,51 @@ def _convert_argval(value):
473478
)
474479

475480
# Run automatic segmentation per image.
476-
for path in tqdm(input_paths, desc="Run automatic segmentation"):
477-
if has_one_input: # if we have one image only.
478-
_output_fpath = str(Path(output_path).with_suffix(".tif"))
479-
_embedding_fpath = embedding_path
480-
481-
else: # if we have multiple image, we need to make the other target filepaths compatible.
482-
# Let's check for 'embedding_path'.
483-
_embedding_fpath = embedding_path
484-
if embedding_path:
485-
if _has_extension(embedding_path): # in this case, use filename as addl. suffix to provided path.
486-
_embedding_fpath = str(Path(embedding_path).with_suffix(".zarr"))
487-
_embedding_fpath = _embedding_fpath.replace(".zarr", f"_{Path(path).stem}.zarr")
488-
else: # otherwise, for directory, use image filename for multiple images.
489-
os.makedirs(embedding_path, exist_ok=True)
490-
_embedding_fpath = os.path.join(embedding_path, Path(os.path.basename(path)).with_suffix(".zarr"))
491-
492-
# Next, let's check for output file to store segmentation.
493-
if _has_extension(output_path): # in this case, use filename as addl. suffix to provided path.
494-
_output_fpath = str(Path(output_path).with_suffix(".tif"))
495-
_output_fpath = _output_fpath.replace(".tif", f"_{Path(path).stem}.tif")
496-
else: # otherwise, for directory, use image filename for multiple images.
497-
os.makedirs(output_path, exist_ok=True)
498-
_output_fpath = os.path.join(output_path, Path(os.path.basename(path)).with_suffix(".tif"))
481+
for input_path in tqdm(input_paths, desc="Run automatic " + ("tracking" if args.tracking else "segmentation")):
482+
if has_one_input: # When we have only one image / volume.
483+
_embedding_fpath = embedding_path # Either folder or zarr file, would work for both.
484+
485+
output_fdir = os.path.splitext(output_path)[0]
486+
os.makedirs(output_fdir, exist_ok=True)
487+
488+
# For tracking, we ensure that the output path is a folder,
489+
# i.e. does not have an extension. We throw a warning if the user provided an extension.
490+
if args.tracking:
491+
if os.path.splitext(output_path)[-1]:
492+
warnings.warn(
493+
f"The output folder has an extension '{os.path.splitext(output_path)[-1]}'. "
494+
"We remove it and treat it as a folder to store tracking outputs in CTC format."
495+
)
496+
_output_fpath = output_fdir
497+
else: # Otherwise, we can store outputs for user directly in the provided filepath, ensuring extension .tif
498+
_output_fpath = f"{output_fdir}.tif"
499+
500+
else: # When we have multiple images.
501+
# Get the input filename, without the extension.
502+
input_name = str(Path(input_path).stem)
503+
504+
# Let's check the 'embedding_path'.
505+
if embedding_path is None: # For computing embeddings on-the-fly, we don't care about the path logic.
506+
_embedding_fpath = embedding_path
507+
else: # Otherwise, store each embeddings inside a folder.
508+
embedding_folder = os.path.splitext(embedding_path)[0] # Treat the provided embedding path as folder.
509+
os.makedirs(embedding_folder, exist_ok=True)
510+
_embedding_fpath = os.path.join(embedding_folder, f"{input_name}.zarr") # Create each embedding file.
511+
512+
# Get the output folder name.
513+
output_folder = os.path.splitext(output_path)[0]
514+
os.makedirs(output_folder, exist_ok=True)
515+
516+
# Next, let's check for output file to store segmentation (or tracks).
517+
if args.tracking: # For tracking, we store CTC outputs in subfolders, with input_name as folder.
518+
_output_fpath = os.path.join(output_folder, input_name)
519+
else: # Otherwise, store each result inside a folder.
520+
_output_fpath = os.path.join(output_folder, f"{input_name}.tif")
499521

500522
instance_seg_function(
501523
predictor=predictor,
502524
segmenter=segmenter,
503-
input_path=path,
525+
input_path=input_path,
504526
output_path=_output_fpath,
505527
embedding_path=_embedding_fpath,
506528
key=args.key,

micro_sam/evaluation/evaluation.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from elf.evaluation import mean_segmentation_accuracy
1818

1919
from ..util import load_image_data
20-
from ..automatic_segmentation import _has_extension
2120

2221

2322
def _run_evaluation(gt_paths, prediction_paths, verbose=True, thresholds=None):
@@ -206,7 +205,7 @@ def main():
206205
def _get_inputs_from_paths(paths, key):
207206
fpaths = []
208207
for path in paths:
209-
if _has_extension(path): # it is just one filepath and we check whether we can access it via 'elf'.
208+
if os.path.isfile(path): # it is just one filepath and we check whether we can access it via 'elf'.
210209
fpaths.append(path if key is None else load_image_data(path=path, key=key))
211210
else: # otherwise, path is a directory, fetch all inputs provided with a pattern.
212211
assert key is not None, \
@@ -222,7 +221,7 @@ def _get_inputs_from_paths(paths, key):
222221
# Check whether output path is a csv or not, if passed.
223222
output_path = args.output_path
224223
if output_path is not None:
225-
if not _has_extension(output_path): # If it is a directory, store this in "<OUTPUT_PATH>/results.csv"
224+
if not os.path.isfile(output_path): # If it is a directory, store this in "<OUTPUT_PATH>/results.csv"
226225
os.makedirs(output_path, exist_ok=True)
227226
output_path = os.path.join(output_path, "results.csv")
228227

micro_sam/multi_dimensional_segmentation.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
try:
3131
from trackastra.model import Trackastra
32-
from trackastra.tracking import graph_to_napari_tracks
32+
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
3333
except ImportError:
3434
Trackastra = None
3535

@@ -570,14 +570,17 @@ def _filter_lineages(lineages, tracking_result):
570570
return filtered_lineages
571571

572572

573-
def _tracking_impl(timeseries, segmentation, mode, min_time_extent):
573+
def _tracking_impl(timeseries, segmentation, mode, min_time_extent, output_folder=None):
574574
device = "cuda" if torch.cuda.is_available() else "cpu"
575575
model = Trackastra.from_pretrained("general_2d", device=device)
576576
lineage_graph = model.track(timeseries, segmentation, mode=mode)
577577
track_data, parent_graph, _ = graph_to_napari_tracks(lineage_graph)
578578
node_to_track, lineages = _extract_tracks_and_lineages(segmentation, track_data, parent_graph)
579579
tracking_result = recolor_segmentation(segmentation, node_to_track)
580580

581+
if output_folder is not None: # Store tracking results in CTC format.
582+
graph_to_ctc(lineage_graph, segmentation, outdir=output_folder)
583+
581584
# TODO
582585
# We should check if trackastra supports this already.
583586
# Filter out short tracks / lineages.
@@ -599,6 +602,7 @@ def track_across_frames(
599602
verbose: bool = True,
600603
pbar_init: Optional[callable] = None,
601604
pbar_update: Optional[callable] = None,
605+
output_folder: Optional[Union[os.PathLike, str]] = None,
602606
) -> Tuple[np.ndarray, List[Dict]]:
603607
"""Track segmented objects over time.
604608
@@ -615,6 +619,7 @@ def track_across_frames(
615619
verbose: Verbosity flag. By default, set to 'True'.
616620
pbar_init: Function to initialize the progress bar.
617621
pbar_update: Function to update the progress bar.
622+
output_folder: The folder where the tracking results are stored in CTC format.
618623
619624
Returns:
620625
The tracking result. Each object is colored by its track id.
@@ -628,7 +633,11 @@ def track_across_frames(
628633
segmentation = _preprocess_closing(segmentation, gap_closing, pbar_update)
629634

630635
segmentation, lineage = _tracking_impl(
631-
np.asarray(timeseries), segmentation, mode="greedy", min_time_extent=min_time_extent
636+
timeseries=np.asarray(timeseries),
637+
segmentation=segmentation,
638+
mode="greedy",
639+
min_time_extent=min_time_extent,
640+
output_folder=output_folder,
632641
)
633642
return segmentation, lineage
634643

@@ -645,6 +654,7 @@ def automatic_tracking_implementation(
645654
verbose: bool = True,
646655
return_embeddings: bool = False,
647656
batch_size: int = 1,
657+
output_folder: Optional[Union[os.PathLike, str]] = None,
648658
**kwargs,
649659
) -> Tuple[np.ndarray, List[Dict]]:
650660
"""Automatically track objects in a timesries based on per-frame automatic segmentation.
@@ -665,6 +675,7 @@ def automatic_tracking_implementation(
665675
verbose: Verbosity flag. By default, set to 'True'.
666676
return_embeddings: Whether to return the precomputed image embeddings. By default, set to 'False'.
667677
batch_size: The batch size to compute image embeddings over planes. By default, set to '1'.
678+
output_folder: The folder where the tracking results are stored in CTC format.
668679
kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
669680
670681
Returns:
@@ -685,7 +696,12 @@ def automatic_tracking_implementation(
685696
)
686697

687698
segmentation, lineage = track_across_frames(
688-
timeseries, segmentation, gap_closing=gap_closing, min_time_extent=min_time_extent, verbose=verbose,
699+
timeseries=timeseries,
700+
segmentation=segmentation,
701+
gap_closing=gap_closing,
702+
min_time_extent=min_time_extent,
703+
verbose=verbose,
704+
output_folder=output_folder,
689705
)
690706

691707
if return_embeddings:

0 commit comments

Comments
 (0)