From a048e75c0ba258dbe9efd4cf34164e2c1aefe028 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 28 Mar 2025 20:58:15 +0100 Subject: [PATCH 1/7] Implement reproducibility functionality for commit file WIP --- micro_sam/reproducibility.py | 40 ++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 micro_sam/reproducibility.py diff --git a/micro_sam/reproducibility.py b/micro_sam/reproducibility.py new file mode 100644 index 000000000..c1ba76328 --- /dev/null +++ b/micro_sam/reproducibility.py @@ -0,0 +1,40 @@ +import os +from typing import Union + +import zarr + + +# TODO add a test for this (with a prepared commit file) +def rerun_segmentation_from_commit_file( + commit_file: Union[str, os.PathLike], + input_path: Union[str, os.PathLike], +) -> None: + """ + + Args: + commit_file: + input_path: + """ + f = zarr.open(commit_file, mode="r") + + # TODO + # 1. Load the model according to the model description stored in the commit file. + + # 2. Go through the commit history and redo the action of each commit. + # Actions can be: + # - Committing an automatic segmentation result. + # - Committing an interactive segmentation result. + + +# TODO +def continue_annotation_from_commit_file( + commit_file: Union[str, os.PathLike], + input_path: Union[str, os.PathLike], +) -> None: + """ + """ + + +# TODO CLI for 'continue_annotation_from_commit_file' +def main(): + pass From 9d1925b70ef748cf54df81ca9aa28964a095e532 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 30 Mar 2025 21:51:45 +0200 Subject: [PATCH 2/7] More work on reproducibility scripts --- micro_sam/reproducibility.py | 40 --- micro_sam/sam_annotator/_widgets.py | 40 ++- micro_sam/sam_annotator/reproducibility.py | 249 ++++++++++++++++++ .../test_reproducibility.py | 29 ++ 4 files changed, 309 insertions(+), 49 deletions(-) delete mode 100644 micro_sam/reproducibility.py create mode 100644 micro_sam/sam_annotator/reproducibility.py create mode 100644 test/test_sam_annotator/test_reproducibility.py diff --git a/micro_sam/reproducibility.py b/micro_sam/reproducibility.py deleted file mode 100644 index c1ba76328..000000000 --- a/micro_sam/reproducibility.py +++ /dev/null @@ -1,40 +0,0 @@ -import os -from typing import Union - -import zarr - - -# TODO add a test for this (with a prepared commit file) -def rerun_segmentation_from_commit_file( - commit_file: Union[str, os.PathLike], - input_path: Union[str, os.PathLike], -) -> None: - """ - - Args: - commit_file: - input_path: - """ - f = zarr.open(commit_file, mode="r") - - # TODO - # 1. Load the model according to the model description stored in the commit file. - - # 2. Go through the commit history and redo the action of each commit. - # Actions can be: - # - Committing an automatic segmentation result. - # - Committing an interactive segmentation result. 
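
The replay planned in these TODOs operates on the 'commit_history' attribute that the annotator writes into the commit file (see the '_commit_to_file' changes later in this patch). A minimal sketch of what such a history looks like — the option keys are the ones serialized by the commit widget, while the concrete values here are made up for illustration:

    commit_history = [
        # An automatic segmentation commit (AMG settings shown; AIS commits store
        # 'center_distance_threshold' / 'boundary_distance_threshold' instead).
        {"auto_segmentation": {
            "object_ids": [1, 2, 3],
            "pred_iou_thresh": 0.88,
            "stability_score_thresh": 0.95,
            "preserve_mode": "none",
            "preservation_threshold": 0.75,
        }},
        # An interactive segmentation commit from the 'current_object' layer.
        {"current_object": {
            "object_ids": [4],
            "preserve_mode": "objects",
            "preservation_threshold": 0.75,
        }},
    ]
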
- - -# TODO -def continue_annotation_from_commit_file( - commit_file: Union[str, os.PathLike], - input_path: Union[str, os.PathLike], -) -> None: - """ - """ - - -# TODO CLI for 'continue_annotation_from_commit_file' -def main(): - pass diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index a840d49bb..55301d70c 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -550,8 +550,8 @@ def _get_auto_segmentation_options(state, object_ids): segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]} if widget.with_decoder: - segmentation_options["boundary_distance_thresh"] = widget.boundary_distance_thresh - segmentation_options["center_distance_thresh"] = widget.center_distance_thresh + segmentation_options["boundary_distance_threshold"] = widget.boundary_distance_thresh + segmentation_options["center_distance_threshold"] = widget.center_distance_thresh else: segmentation_options["pred_iou_thresh"] = widget.pred_iou_thresh segmentation_options["stability_score_thresh"] = widget.stability_score_thresh @@ -582,18 +582,28 @@ def _get_promptable_segmentation_options(state, object_ids): return segmentation_options, is_tracking -def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None): +def _get_preservation_settings(state): + widget = state.widgets["commit"] + return { + "preserve_mode": widget.preserve_mode.value, + "preservation_threshold": widget.preservation_threshold.value, + } - # NOTE: zarr-python is quite inefficient and writes empty blocks. - # So we have to use z5py here. - # Deal with issues z5py has with empty folders and require the json. - if os.path.exists(path): - required_json = os.path.join(path, ".zgroup") +# Deal with issues z5py has with empty folders and require the json with zarr group metadata. +def _require_zarr_group_metadata(path, group_name=None): + path_ = path if group_name is None else os.path.join(path, group_name) + if os.path.exists(path_): + required_json = os.path.join(path_, ".zgroup") if not os.path.exists(required_json): with open(required_json, "w") as f: json.dump({"zarr_format": 2}, f) + +def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None): + + # NOTE: zarr-python is quite inefficient and writes empty blocks. So we have to use z5py here. + _require_zarr_group_metadata(path) f = z5py.ZarrFile(path, "a") state = AnnotatorState() @@ -609,6 +619,10 @@ def _save_signature(f, data_signature): ) for key, val in signature.items(): f.attrs[key] = val + # Add the annotator type to the signature. + # TODO need to merge the latest dev / master for this. + f.attrs["annotator_class"] = "Annotator2d" + # f.attrs["annotator_class"] = state. # If the data signature is saved in the file already, # then we check if saved data signature and data signature of our image agree. @@ -650,10 +664,14 @@ def _save_signature(f, data_signature): commit_history = f.attrs.get("commit_history", []) object_ids = np.unique(seg[mask]) + # Get the preservation settings from the commit widget. + preservation_settings = _get_preservation_settings(state) + # We committed an automatic segmentation. if layer == "auto_segmentation": # Save the settings of the segmentation widget. segmentation_options = _get_auto_segmentation_options(state, object_ids) + segmentation_options.update(**preservation_settings) commit_history.append({"auto_segmentation": segmentation_options}) # Write the commit history. 
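
Taken together, '_commit_to_file' stores everything needed to replay an annotation session in a single zarr container: the data signature and model metadata as attributes, the growing 'commit_history', the committed segmentation, and one prompt group per committed object. A small sketch for inspecting such a file (assuming the attribute, dataset and group names used in this patch series; the path is a placeholder):

    import zarr

    f = zarr.open("commit.zarr", mode="r")
    print(f.attrs["data_signature"])     # hash of the annotated input image
    print(f.attrs["commit_history"])     # list of per-commit segmentation options
    print(f["committed_objects"].shape)  # the committed segmentation
    if "prompts" in f:
        for object_id, group in f["prompts"].groups():
            # Per-object prompt data: "prompts" (= box prompts), "point_prompts", "point_labels".
            print(object_id, list(group.keys()))
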
@@ -664,10 +682,14 @@ def _save_signature(f, data_signature): return segmentation_options, is_tracking = _get_promptable_segmentation_options(state, object_ids) + segmentation_options.update(**preservation_settings) commit_history.append({"current_object": segmentation_options}) + # TODO add support for mask prompt (e.g. from polygon layer) def write_prompts(object_id, prompts, point_prompts, point_labels, track_state=None): - g = f.create_group(f"prompts/{object_id}") + group_name = f"prompts/{object_id}" + g = f.create_group(group_name) + _require_zarr_group_metadata(path, group_name) if prompts is not None and len(prompts) > 0: data = np.array(prompts) g.create_dataset("prompts", data=data, chunks=data.shape) diff --git a/micro_sam/sam_annotator/reproducibility.py b/micro_sam/sam_annotator/reproducibility.py new file mode 100644 index 000000000..9085400b7 --- /dev/null +++ b/micro_sam/sam_annotator/reproducibility.py @@ -0,0 +1,249 @@ +import os +import warnings +from typing import Optional, Union + +import numpy as np +from elf.io import open_file +from tqdm import tqdm + +from .. import util +from ..automatic_segmentation import get_predictor_and_segmenter +from ..instance_segmentation import mask_data_to_segmentation + +from ._widgets import _mask_matched_objects +from .util import prompt_segmentation + + +def _load_model_from_commit_file(f): + # Check which segmentation mode is used, by going through the commit history + # and checking if we have committed the 'auto_segmentation' layer. + # If we did, then we derive which mode was used from the serialized parameters. + amg = None + commit_history = f.attrs["commit_history"] + for commit in commit_history: + layer, options = next(iter(commit.items())) + if layer == "auto_segmentation": + amg = not ("boundary_distance_threshold" in options) + + # Get the predictor and segmenter. + predictor, segmenter = get_predictor_and_segmenter( + model_type=f.attrs["model_name"], + # checkpoint="", TODO we need to also serialize this + amg=amg, + is_tiled=f.attrs["tile_shape"] is not None, + ) + return predictor, segmenter + + +def _check_data_hash(f, input_data, input_path): + data_hash = util._compute_data_signature(input_data) + expected_hash = f.attrs["data_signature"] + if data_hash != expected_hash: + raise RuntimeError( + f"The hash of the input data loaded from {input_path} is {data_hash}, " + f"which does not match the expected hash {expected_hash}." + ) + + +def _write_masks_with_preservation(prev_seg, seg, preserve_mode, preservation_threshold, object_ids): + # Make sure the segmented ids and the object ids match (warn if not), + # and apply the id offset to match them. + segmented_ids = np.setdiff1d(np.unique(seg), [0]) + if len(segmented_ids) != len(object_ids): + warnings.warn( + f"The number of objects found by running auto segmentation (={len(segmented_ids)})" + f" does not match the number of expected objects (={len(object_ids)})." + ) + id_offset = int(np.min(object_ids)) - 1 + seg[seg != 0] += id_offset + + # Write the new segmentation to the previous segmentation, taking the preservation rules into account. + mask = seg != 0 + if preserve_mode != "none": + preserve_mask = prev_seg != 0 + if preserve_mask.sum() != 0: + # In the mode 'objects' we preserve committed objects instead, by comparing the overlaps + # of already committed and newly committed objects. 
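            # A tiny illustration of the pixel-wise preservation applied here when the mode
            # is not 'objects' (hypothetical 1d arrays):
            #   prev_seg = [0, 0, 5, 5],  seg (after the id offset) = [7, 7, 7, 0]
            #   mask = seg != 0              ->  [True, True, True, False]
            #   mask[prev_seg != 0] = 0      ->  [True, True, False, False]
            #   prev_seg[mask] = seg[mask]   ->  [7, 7, 5, 5], i.e. the committed object 5 is preserved.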
+ if preserve_mode == "objects": + preserve_mask = _mask_matched_objects(seg, prev_seg, preservation_threshold) + mask[preserve_mask] = 0 + prev_seg[mask] = seg[mask] + return prev_seg + + +def _rerun_interactive_segmentation(segmentation, f, predictor, image_embeddings, annotator_class, options): + object_ids = options.pop("object_ids") + preserve_mode, preservation_threshold = options.pop("preserve_mode"), options.pop("preservation_threshold") + g = f["prompts"] + + if annotator_class == "Annotator2d": + # Load the serialized prompts for all objects in this commit. + boxes, masks = [], [] + points, labels = [], [] + for object_id in object_ids: + prompt_group = g[str(object_id)] + if "point_prompts" in prompt_group: + points.append(prompt_group["point_prompts"][:]) + labels.append(prompt_group["point_labels"][:]) + if "prompts" in prompt_group: + boxes.append(prompt_group["prompts"][:]) + if "mask" in prompt_group: + masks.append(prompt_group["mask"][:]) + + # TODO + if not masks: + masks = None + if points: + points, labels = np.array(points), np.array(labels) + # else: + # points, labels = None, None + + batched = len(object_ids) > 1 + seg = prompt_segmentation( + predictor, points, labels, boxes, masks, segmentation.shape, image_embeddings=image_embeddings, + multiple_box_prompts=True, batched=batched, previous_segmentation=segmentation, + ) + + # TODO implement batched segmentation for these cases. + elif annotator_class == "AnnotatorTracking": + pass + elif annotator_class == "Annotator3d": + pass + else: + raise RuntimeError(f"Invalid annotator class {annotator_class}.") + + return _write_masks_with_preservation(segmentation, seg, preserve_mode, preservation_threshold, object_ids) + + +def _rerun_automatic_segmentation( + image, segmentation, predictor, segmenter, image_embeddings, annotator_class, options +): + object_ids = options.pop("object_ids") + preserve_mode, preservation_threshold = options.pop("preserve_mode"), options.pop("preservation_threshold") + with_background, min_object_size = options.pop("with_background"), options.pop("min_object_size") + + # If there was nothing committed then we don't need to rerun the automatic segmentation. + if len(object_ids) == 0: + return segmentation + + if annotator_class == "Annotator2d": + segmenter.initialize(image=image, image_embeddings=image_embeddings) + seg = segmenter.generate(**options) + seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size) + # TODO implement auto segmentation for these cases. + elif annotator_class == "AnnotatorTracking": + pass + elif annotator_class == "Annotator3d": + pass + else: + raise RuntimeError(f"Invalid annotator class {annotator_class}.") + + return _write_masks_with_preservation(segmentation, seg, preserve_mode, preservation_threshold, object_ids) + + +def rerun_segmentation_from_commit_file( + commit_file: Union[str, os.PathLike], + input_path: Union[str, os.PathLike, np.ndarray], + input_key: Optional[str] = None, + embedding_path: Optional[Union[str, os.PathLike]] = None, +) -> np.ndarray: + """ + + Args: + commit_file: + input_path: + input_key: + embedding_path: + """ + # Load the image data and open the zarr commit file. + input_data = util.load_image_data(input_path, key=input_key) + with open_file(commit_file, mode="r") as f: + + # Get the annotator class. + if "annotator_class" not in f.attrs: + raise RuntimeError( + f"You have saved the {commit_file} in a version that does not yet support rerunning the segmentation." 
+ ) + annotator_class = f.attrs["annotator_class"] + ndim = 2 if annotator_class == "Annotator2d" else 3 + + # Check that the stored data hash and the input data hash match. + _check_data_hash(f, input_data, input_path) + + # Load the model according to the model description stored in the commit file. + predictor, segmenter = _load_model_from_commit_file(f) + + # Compute oder load the image embeddings. + image_embeddings = util.precompute_image_embeddings( + predictor=predictor, + input_=input_data, + save_path=embedding_path, + ndim=ndim, + tile_shape=f.attrs["tile_shape"], + halo=f.attrs["halo"], + ) + + # Go through the commit history and redo the action of each commit. + # Actions can be: + # - Committing an automatic segmentation result. + # - Committing an interactive segmentation result. + commit_history = f.attrs["commit_history"] + + # Rerun the commit history. + # TODO check if this works correctly for 3d data + shape = image_embeddings["original_size"] + segmentation = np.zeros(shape, dtype="uint32") + for commit in tqdm(commit_history, desc="Rerunning commit history"): + layer, options = next(iter(commit.items())) + if layer == "current_object": + segmentation = _rerun_interactive_segmentation( + segmentation, f, predictor, image_embeddings, annotator_class, options + ) + elif layer == "auto_segmentation": + segmentation = _rerun_automatic_segmentation( + input_data, segmentation, predictor, segmenter, image_embeddings, annotator_class, options + ) + else: + raise RuntimeError(f"Invalid layer {layer} in commit_historty.") + + return segmentation + + +def load_committed_objects_from_commit_file(commit_file: Union[str, os.PathLike]) -> np.ndarray: + """ + Args: + commit_file + + Returns: + AAA + """ + with open_file(commit_file, mode="r") as f: + return f["committed_objects"][:] + + +# TODO +def continue_annotation_from_commit_file( + commit_file: Union[str, os.PathLike], + input_path: Union[str, os.PathLike], + input_key: Optional[str] = None, + embedding_path: Optional[Union[str, os.PathLike]] = None, +) -> None: + """ + Args: + commit_file: + input_path: + input_key: + embedding_path: + """ + committed_objects = load_committed_objects_from_commit_file(commit_file) + with open_file(commit_file, mode="r") as f: + if "annotator_class" not in f.attrs: + raise RuntimeError( + f"You have saved the {commit_file} in a version that does not yet support rerunning the segmentation." 
+ ) + annotator_class = f.attrs["annotator_class"] + + +# TODO CLI for 'continue_annotation_from_commit_file' +def main(): + pass diff --git a/test/test_sam_annotator/test_reproducibility.py b/test/test_sam_annotator/test_reproducibility.py new file mode 100644 index 000000000..f4b7f3fea --- /dev/null +++ b/test/test_sam_annotator/test_reproducibility.py @@ -0,0 +1,29 @@ +import os +import unittest + +import micro_sam.util as util +from elf.io import open_file +from micro_sam.sample_data import fetch_hela_2d_example_data + + +@unittest.skipUnless(util.VIT_T_SUPPORT, "Needs vit_t") +class TestReproducibility(unittest.TestCase): + commit_path = "commit-for-test.zarr" + + def test_automatic_mask_generator_2d(self): + from micro_sam.sam_annotator.reproducibility import rerun_segmentation_from_commit_file + + base_data_directory = os.path.join(util.get_cache_directory(), "sample_data") + input_path = fetch_hela_2d_example_data(base_data_directory) + segmentation = rerun_segmentation_from_commit_file(self.commit_path, input_path) + + with open_file(self.commit_path, "r") as f: + expected_segmentation = f["committed_objects"][:] + + breakpoint() + + # TODO check that segmentation and expected_segmentation are equivalent + + +if __name__ == "__main__": + unittest.main() From a6618e62fdae9c603c9717d78cece6afec369ebf Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 21 Apr 2025 13:21:32 +0200 Subject: [PATCH 3/7] Updates to automatic segmentation (#979) Support tracking in auto seg CLI --- micro_sam/automatic_segmentation.py | 105 +++++++++++++++++++- micro_sam/multi_dimensional_segmentation.py | 15 ++- 2 files changed, 111 insertions(+), 9 deletions(-) diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 6d88824ec..a0fe27914 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -1,8 +1,9 @@ import os +from functools import partial from glob import glob from tqdm import tqdm from pathlib import Path -from typing import Optional, Union, Tuple +from typing import Dict, List, Optional, Union, Tuple import numpy as np import imageio.v3 as imageio @@ -14,7 +15,7 @@ get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder, AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator ) -from .multi_dimensional_segmentation import automatic_3d_segmentation +from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation def get_predictor_and_segmenter( @@ -71,6 +72,87 @@ def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}")) +def automatic_tracking( + predictor: util.SamPredictor, + segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], + input_path: Union[Union[os.PathLike, str], np.ndarray], + output_path: Optional[Union[os.PathLike, str]] = None, + embedding_path: Optional[Union[os.PathLike, str]] = None, + key: Optional[str] = None, + tile_shape: Optional[Tuple[int, int]] = None, + halo: Optional[Tuple[int, int]] = None, + verbose: bool = True, + return_embeddings: bool = False, + annotate: bool = False, + batch_size: int = 1, + **generate_kwargs +) -> Tuple[np.ndarray, List[Dict]]: + """Run automatic tracking for the input timeseries. + + Args: + predictor: The Segment Anything model. + segmenter: The automatic instance segmentation class. + input_path: input_path: The input image file(s). Can either be a single image file (e.g. 
tif or png), + or a container file (e.g. hdf5 or zarr). + output_path: The output path where the instance segmentations will be saved. + embedding_path: The path where the embeddings are cached already / will be saved. + key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) + or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. + tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. + halo: Overlap of the tiles for tiled prediction. + verbose: Verbosity flag. + return_embeddings: Whether to return the precomputed image embeddings. + annotate: Whether to activate the annotator for continue annotation process. + batch_size: The batch size to compute image embeddings over tiles / z-planes. + By default, does it sequentially, i.e. one after the other. + generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class. + + Returns: + """ + if output_path is not None: + # TODO implement saving tracking results in CTC format and use it to save the result here. + raise NotImplementedError("Saving the tracking result to file is currently not supported.") + + # Load the input image file. + if isinstance(input_path, np.ndarray): + image_data = input_path + else: + image_data = util.load_image_data(input_path, key) + + # We perform additional post-processing for AMG-only. + # Otherwise, we ignore additional post-processing for AIS. + if isinstance(segmenter, InstanceSegmentationWithDecoder): + generate_kwargs["output_mode"] = None + + if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3): + raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}") + + gap_closing, min_time_extent = generate_kwargs.get("gap_closing"), generate_kwargs.get("min_time_extent") + segmentation, lineage, image_embeddings = automatic_tracking_implementation( + image_data, + predictor, + segmenter, + embedding_path=embedding_path, + gap_closing=gap_closing, + min_time_extent=min_time_extent, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, + batch_size=batch_size, + return_image_embeddings=True, + **generate_kwargs, + ) + + if annotate: + # TODO We need to support initialization of the tracking annotator with the tracking result for this. + raise NotImplementedError("Annotation after running the automated tracking is currently not supported.") + + if return_embeddings: + return segmentation, lineage, image_embeddings + else: + return segmentation, lineage + + def automatic_instance_segmentation( predictor: util.SamPredictor, segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], @@ -258,7 +340,13 @@ def main(): available_models = list(util.get_model_names()) available_models = ", ".join(available_models) - parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.") + parser = argparse.ArgumentParser( + description="Run automatic segmentation for an image using either automatic instance segmentation (AIS) \n" + "or automatic mask generation (AMG). In addition to the arguments explained below,\n" + "you can also passed additional arguments for these two segmentation modes:\n" + "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`." # noqa + "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`." 
# noqa + ) parser.add_argument( "-i", "--input_path", required=True, type=str, nargs="+", help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) " @@ -314,6 +402,10 @@ def main(): help="The batch size for computing image embeddings over tiles or z-plane. " "By default, computes the image embeddings for one tile / z-plane at a time." ) + parser.add_argument( + "--tracking", action="store_true", help="Run tracking instead of instance segmentation. " + "Only supported for timeseries inputs.." + ) parser.add_argument( "-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs." ) @@ -368,6 +460,10 @@ def _convert_argval(value): embedding_path = args.embedding_path has_one_input = len(input_paths) == 1 + instance_seg_function = automatic_tracking if args.tracking else partial( + automatic_instance_segmentation, ndim=args.ndim + ) + # Run automatic segmentation per image. for path in tqdm(input_paths, desc="Run automatic segmentation"): if has_one_input: # if we have one image only. @@ -393,14 +489,13 @@ def _convert_argval(value): os.makedirs(output_path, exist_ok=True) _output_fpath = os.path.join(output_path, Path(os.path.basename(path)).with_suffix(".tif")) - automatic_instance_segmentation( + instance_seg_function( predictor=predictor, segmenter=segmenter, input_path=path, output_path=_output_fpath, embedding_path=_embedding_fpath, key=args.key, - ndim=args.ndim, tile_shape=args.tile_shape, halo=args.halo, annotate=args.annotate, diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index eb98ec47c..7460bc294 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -622,7 +622,7 @@ def track_across_frames( return segmentation, lineage -def automatic_tracking( +def automatic_tracking_implementation( timeseries: np.ndarray, predictor: SamPredictor, segmentor: AMGBase, @@ -632,6 +632,8 @@ def automatic_tracking( tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, + return_embeddings: bool = False, + batch_size: int = 1, **kwargs, ) -> Tuple[np.ndarray, List[Dict]]: """Automatically track objects in a timesries based on per-frame automatic segmentation. @@ -650,6 +652,8 @@ def automatic_tracking( tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. verbose: Verbosity flag. + return_embeddings: Whether to return the precomputed image embeddings. + batch_size: The batch size to compute image embeddings over planes. kwargs: Keyword arguments for the 'generate' method of the 'segmentor'. Returns: @@ -662,15 +666,18 @@ def automatic_tracking( raise RuntimeError( "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'." 
) - segmentation, _ = _segment_slices( + segmentation, image_embeddings = _segment_slices( timeseries, predictor, segmentor, embedding_path, verbose, - tile_shape=tile_shape, halo=halo, + tile_shape=tile_shape, halo=halo, batch_size=batch_size, **kwargs, ) segmentation, lineage = track_across_frames( timeseries, segmentation, gap_closing=gap_closing, min_time_extent=min_time_extent, verbose=verbose, ) - return segmentation, lineage + if return_embeddings: + return segmentation, lineage, image_embeddings + else: + return segmentation, lineage def get_napari_track_data( From ec8f447d8bb8739ac1a9c5917747b32d5cb5afcc Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 21 Apr 2025 19:57:55 +0200 Subject: [PATCH 4/7] Implement instance segmentation training and update training tests (#1008) Implement instance segmentation training and update training tests --- micro_sam/training/__init__.py | 5 +- micro_sam/training/training.py | 327 +++++++++++++++--- .../training/train_instance_segmentation.py | 64 ++++ test/test_training.py | 55 ++- 4 files changed, 390 insertions(+), 61 deletions(-) create mode 100644 scripts/training/train_instance_segmentation.py diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py index 576d72ce8..2fae39f1f 100644 --- a/micro_sam/training/__init__.py +++ b/micro_sam/training/__init__.py @@ -6,4 +6,7 @@ from .joint_sam_trainer import JointSamTrainer, JointSamLogger from .simple_sam_trainer import SimpleSamTrainer, MedSAMTrainer from .semantic_sam_trainer import SemanticSamTrainer, SemanticMapsSamTrainer -from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS +from .training import ( + train_sam, train_sam_for_configuration, train_instance_segmentation, default_sam_loader, default_sam_dataset, + export_instance_segmentation_model, CONFIGURATIONS, +) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 292e874d8..568afe828 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -2,7 +2,7 @@ import time import warnings from glob import glob -from tqdm import tqdm +from collections import OrderedDict from contextlib import contextmanager, nullcontext from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -13,6 +13,7 @@ from torch.utils.data import random_split from torch.utils.data import DataLoader, Dataset from torch.optim.lr_scheduler import _LRScheduler +from tqdm import tqdm import torch_em from torch_em.util import load_data @@ -28,7 +29,7 @@ from . import sam_trainer as trainers from ..instance_segmentation import get_unetr from . 
import joint_sam_trainer as joint_trainers -from ..util import get_device, get_model_names, export_custom_sam_model +from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform @@ -165,6 +166,32 @@ def _count_parameters(model_parameters): print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)") +def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training): + if n_iterations is None: + trainer_fit_params = {"epochs": n_epochs} + else: + trainer_fit_params = {"iterations": n_iterations} + + if save_every_kth_epoch is not None: + trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch + + if pbar_signals is not None: + progress_bar_wrapper = _ProgressBarWrapper(pbar_signals) + trainer_fit_params["progress"] = progress_bar_wrapper + + # Avoid overwriting a trained model, if desired by the user. + trainer_fit_params["overwrite_training"] = overwrite_training + return trainer_fit_params + + +def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs): + optimizer = optimizer_class(model_params, lr=lr) + if scheduler_kwargs is None: + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} + scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) + return optimizer, scheduler + + def train_sam( name: str, model_type: str, @@ -223,8 +250,7 @@ def train_sam( If passed None, the chosen default parameters are used in ReduceLROnPlateau. save_every_kth_epoch: Save checkpoints after every kth epoch separately. pbar_signals: Controls for napari progress bar. - optimizer_class: The optimizer class. - By default, torch.optim.AdamW is used. + optimizer_class: The optimizer class. By default, torch.optim.AdamW is used. peft_kwargs: Keyword arguments for the PEFT wrapper class. ignore_warnings: Whether to ignore raised warnings. verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. @@ -277,12 +303,9 @@ def train_sam( else: model_params = model.parameters() - optimizer = optimizer_class(model_params, lr=lr) - - if scheduler_kwargs is None: - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} - - scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) + optimizer, scheduler = _get_optimizer_and_scheduler( + model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs + ) # The trainer which performs training and validation. 
if with_segmentation_decoder: @@ -330,21 +353,168 @@ def train_sam( save_root=save_root, ) - if n_iterations is None: - trainer_fit_params = {"epochs": n_epochs} - else: - trainer_fit_params = {"iterations": n_iterations} + trainer_fit_params = _get_trainer_fit_params( + n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training + ) + trainer.fit(**trainer_fit_params) + + t_run = time.time() - t_start + hours = int(t_run // 3600) + minutes = int(t_run // 60) + seconds = int(round(t_run % 60, 0)) + print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") + + +def export_instance_segmentation_model( + trained_model_path: Union[str, os.PathLike], + output_path: Union[str, os.PathLike], + model_type: str, + initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None, +) -> None: + """Export a model trained for instance segmentation with `train_instance_segmentation`. + + The exported model will be compatible with the micro_sam functions, CLI and napari plugin. + It should only be used for automatic segmentation and may not work well for interactive segmentation. + + Args: + trained_model_path: The path to the checkpoint of the model trained for instance segmentation. + output_path: The path where the exported model will be saved. + model_type: The model type. + initial_checkpoint_path: The initial checkpoint path the instance segmentation training was based on (optional). + """ + trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"] + + # Get the state of the encoder and instance segmentation decoder from the trained checkpoint. + encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")]) + decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")]) + + # Load the original state of the model that was used as the basis of instance segmentation training. + _, model_state = get_sam_model( + model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu", + ) + # Remove the sam prefix if it's in the model state. + prefix = "sam." + model_state = OrderedDict( + [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] + ) + + # Replace the image encoder state. + model_state = OrderedDict( + [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v) + for k, v in model_state.items()] + ) + + save_state = {"model_state": model_state, "decoder_state": decoder_state} + torch.save(save_state, output_path) + + +def train_instance_segmentation( + name: str, + model_type: str, + train_loader: DataLoader, + val_loader: DataLoader, + n_epochs: int = 100, + early_stopping: Optional[int] = 10, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + freeze: Optional[List[str]] = None, + device: Optional[Union[str, torch.device]] = None, + lr: float = 1e-5, + save_root: Optional[Union[str, os.PathLike]] = None, + n_iterations: Optional[int] = None, + scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + save_every_kth_epoch: Optional[int] = None, + pbar_signals: Optional[QObject] = None, + optimizer_class: Optional[Optimizer] = torch.optim.AdamW, + peft_kwargs: Optional[Dict] = None, + ignore_warnings: bool = True, + overwrite_training: bool = True, + **model_kwargs, +) -> None: + """Train a UNETR for instance segmentation using the SAM encoder as backbone. 
+ + This setting corresponds to training a SAM model with an instance segmentation decoder, + without training the model parts for interactive segmentation, + i.e. without training the prompt encoder and mask decoder. + + The checkpoint of the trained model, which will be saved in 'checkpoints/', + will not be compatible with the micro_sam functionality. + You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it + in a format that is compatible with micro_sam functionality. + Note that the exported model should only be used for automatic segmentation via AIS. + + Args: + name: The name of the model to be trained. The checkpoint and logs will have this name. + model_type: The type of the SAM model. + train_loader: The dataloader for training. + val_loader: The dataloader for validation. + n_epochs: The number of epochs to train for. + early_stopping: Enable early stopping after this number of epochs without improvement. + checkpoint_path: Path to checkpoint for initializing the SAM model. + freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. + By default nothing is frozen and the full model is updated. + device: The device to use for training. + lr: The learning rate. + save_root: Optional root directory for saving the checkpoints and logs. + If not given the current working directory is used. + n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given. + scheduler_class: The learning rate scheduler to update the learning rate. + By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used. + scheduler_kwargs: The learning rate scheduler parameters. + If passed None, the chosen default parameters are used in ReduceLROnPlateau. + save_every_kth_epoch: Save checkpoints after every kth epoch separately. + pbar_signals: Controls for napari progress bar. + optimizer_class: The optimizer class. By default, torch.optim.AdamW is used. + peft_kwargs: Keyword arguments for the PEFT wrapper class. + ignore_warnings: Whether to ignore raised warnings. + overwrite_training: Whether to overwrite the trained model stored at the same location. + By default, overwrites the trained model at each run. + If set to 'False', it will avoid retraining the model if the previous run was completed. + model_kwargs: Additional keyword arguments for the `util.get_sam_model`. + """ - if save_every_kth_epoch is not None: - trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch + with _filter_warnings(ignore_warnings): + t_start = time.time() - if pbar_signals is not None: - progress_bar_wrapper = _ProgressBarWrapper(pbar_signals) - trainer_fit_params["progress"] = progress_bar_wrapper + sam_model, state = get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + return_state=True, + peft_kwargs=peft_kwargs, + freeze=freeze, + **model_kwargs + ) + device = get_device(device) + model = get_unetr( + image_encoder=sam_model.sam.image_encoder, + decoder_state=state.get("decoder_state", None), + device=device, + ) - # Avoid overwriting a trained model, if desired by the user. 
- trainer_fit_params["overwrite_training"] = overwrite_training + loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) + optimizer, scheduler = _get_optimizer_and_scheduler( + model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs + ) + trainer = torch_em.trainer.DefaultTrainer( + name=name, + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + mixed_precision=True, + log_image_interval=50, + compile_model=False, + save_root=save_root, + loss=loss, + metric=loss, + optimizer=optimizer, + lr_scheduler=scheduler, + ) + trainer_fit_params = _get_trainer_fit_params( + n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training + ) trainer.fit(**trainer_fit_params) t_run = time.time() - t_start @@ -395,6 +565,7 @@ def default_sam_dataset( patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, + train_instance_segmentation_only: bool = False, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, @@ -417,6 +588,8 @@ def default_sam_dataset( patch_shape: The shape for training patches. with_segmentation_decoder: Whether to train with additional segmentation decoder. with_channels: Whether the image data has channels. By default, it makes the decision based on inputs. + train_instance_segmentation_only: Set this argument to True in order to + pass the dataset to `train_instance_segmentation`. sampler: A sampler to reject batches according to a given criterion. raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit. @@ -430,6 +603,24 @@ def default_sam_dataset( The segmentation dataset. """ + # Check if this dataset should be used for instance segmentation only training. + # If yes, we set return_instances to False, since the instance channel must not + # be passed for this training mode. + return_instances = True + if train_instance_segmentation_only: + if not with_segmentation_decoder: + raise ValueError( + "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True." + ) + return_instances = False + + # If a sampler is not passed, then we set a MinInstanceSampler, which requires 3 distinct instances per sample. + # This is necessary, because training for interactive segmentation does not work on 'empty' images. + # However, if we train only the automatic instance segmentation decoder, then this sampler is not required + # and we do not set a default sampler. + if sampler is None and not train_instance_segmentation_only: + sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size) + # By default, let the 'default_segmentation_dataset' heuristic decide for itself. is_seg_dataset = kwargs.pop("is_seg_dataset", None) @@ -470,16 +661,12 @@ def default_sam_dataset( boundary_distances=True, directed_distances=False, foreground=True, - instances=True, + instances=return_instances, min_size=min_size, ) else: label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size) - # Set a default sampler if none was passed. - if sampler is None: - sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size) - # Check the patch shape to add a singleton if required. 
patch_shape = _update_patch_shape( patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels, @@ -556,6 +743,8 @@ def default_sam_loader(**kwargs) -> DataLoader: "V100": {"model_type": "vit_b"}, "A100": {"model_type": "vit_h"}, } +"""Best training configurations for given hardware resources. +""" def _find_best_configuration(): @@ -579,10 +768,6 @@ def _find_best_configuration(): return "CPU" -"""Best training configurations for given hardware resources. -""" - - def train_sam_for_configuration( name: str, configuration: str, @@ -590,6 +775,7 @@ def train_sam_for_configuration( val_loader: DataLoader, checkpoint_path: Optional[Union[str, os.PathLike]] = None, with_segmentation_decoder: bool = True, + train_instance_segmentation_only: bool = False, model_type: Optional[str] = None, **kwargs, ) -> None: @@ -607,6 +793,8 @@ def train_sam_for_configuration( checkpoint_path: Path to checkpoint for initializing the SAM model. with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. + train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation + using the training implementation `train_instance_segmentation`. By default, `train_sam` is used. model_type: Over-ride the default model type. This can be used to use one of the micro_sam models as starting point instead of a default sam model. @@ -625,15 +813,26 @@ def train_sam_for_configuration( warnings.warn("You have specified a different model type.") train_kwargs.update(**kwargs) - train_sam( - name=name, - train_loader=train_loader, - val_loader=val_loader, - checkpoint_path=checkpoint_path, - with_segmentation_decoder=with_segmentation_decoder, - model_type=model_type, - **train_kwargs - ) + if train_instance_segmentation_only: + train_instance_segmentation( + name=name, + train_loader=train_loader, + val_loader=val_loader, + checkpoint_path=checkpoint_path, + with_segmentation_decoder=with_segmentation_decoder, + model_type=model_type, + **train_kwargs + ) + else: + train_sam( + name=name, + train_loader=train_loader, + val_loader=val_loader, + checkpoint_path=checkpoint_path, + with_segmentation_decoder=with_segmentation_decoder, + model_type=model_type, + **train_kwargs + ) def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader): @@ -692,6 +891,21 @@ def _export_helper(save_root, checkpoint_name, output_path, model_type, with_seg return final_path +def _parse_segmentation_decoder(segmentation_decoder): + if segmentation_decoder in ("None", "none"): + with_segmentation_decoder, train_instance_segmentation_only = False, False + elif segmentation_decoder == "instances": + with_segmentation_decoder, train_instance_segmentation_only = True, False + elif segmentation_decoder == "instances_only": + with_segmentation_decoder, train_instance_segmentation_only = True, True + else: + raise ValueError( + "The 'segmentation_decoder' argument currently supports the values:\n" + f"'instances', 'instances_only', or 'None'. You have passed {segmentation_decoder}." + ) + return with_segmentation_decoder, train_instance_segmentation_only + + def main(): """@private""" import argparse @@ -753,12 +967,18 @@ def none_or_str(value): return None return value + # This could be extended to train for semantic segmentation or other options. 
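    # The accepted values map onto the two flags returned by '_parse_segmentation_decoder' above,
    # i.e. (with_segmentation_decoder, train_instance_segmentation_only):
    #   "None" / "none"  -> (False, False): interactive training via 'train_sam', no extra decoder.
    #   "instances"      -> (True, False):  'train_sam' plus the additional AIS decoder.
    #   "instances_only" -> (True, True):   decoder-only training via 'train_instance_segmentation'.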
parser.add_argument( "--segmentation_decoder", type=none_or_str, default="instances", - # TODO: in future, we can extend this to semantic seg / or even more advanced stuff. - help="Whether to finetune Segment Anything Model with additional segmentation decoder for desired targets. " - "By default, it uses the 'instances' option, i.e. trains with the additional segmentation decoder for " - "instance segmentation, otherwise pass 'None' for training without the additional segmentation decoder at all." + help="Whether to finetune Segment Anything Model with an additional segmentation decoder. " + "The following options are possible:\n" + "- 'instances' to train with an additional decoder for automatic instance segmentation. " + " This option enables using the automatic instance segmentation (AIS) mode.\n" + "- 'instances_only' to train only the instance segmentation decoder. " + " In this case the parts of SAM that are used for interactive segmentation will not be trained.\n" + "- 'None' to train without an additional segmentation decoder." + " This options trains only the parts of the original SAM.\n" + "By default the option 'instances' is used." ) # Optional advanced settings a user can opt to change the values for. @@ -825,12 +1045,7 @@ def none_or_str(value): device = args.device save_root = args.save_root output_path = args.output_path - - if args.segmentation_decoder and args.segmentation_decoder != "instances": - raise ValueError( - "The 'segmentation_decoder' argument currently supports 'instances' as input argument only." - ) - with_segmentation_decoder = (args.segmentation_decoder is not None) + with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(args.segmentation_decoder) # Get image paths and corresponding keys. train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key @@ -850,6 +1065,7 @@ def none_or_str(value): patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder, raw_transform=_raw_transform, + train_instance_segmentation_only=train_instance_segmentation_only, ) # If val images are not exclusively provided, we create a val split from the training data. @@ -868,6 +1084,7 @@ def none_or_str(value): label_key=val_gt_key, patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder, + train_instance_segmentation_only=train_instance_segmentation_only, raw_transform=_raw_transform, ) @@ -899,11 +1116,17 @@ def none_or_str(value): device=device, save_root=save_root, peft_kwargs=None, # TODO: Allow for PEFT. + train_instance_segmentation_only=train_instance_segmentation_only, ) # 4. Export the model, if desired by the user - final_path = _export_helper( - save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader - ) + if train_instance_segmentation_only and output_path: + trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt") + export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path) + final_path = output_path + else: + final_path = _export_helper( + save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader, + ) print(f"Training has finished. 
The trained model is saved at {final_path}.") diff --git a/scripts/training/train_instance_segmentation.py b/scripts/training/train_instance_segmentation.py new file mode 100644 index 000000000..895cc3ce7 --- /dev/null +++ b/scripts/training/train_instance_segmentation.py @@ -0,0 +1,64 @@ +"""This is an example script for training a model only for automated instance segmentation. +""" +import os + +# This function downloads the DSB dataset, which we use as exaple data for this script. +# You can use any other data with images and associated label masks for training, +# for example images and label masks stored in a .tif format. +from torch_em.data.datasets.light_microscopy.dsb import get_dsb_paths + +# The required functionality for training. +from micro_sam.training import export_instance_segmentation_model, default_sam_loader, train_instance_segmentation + +image_paths, label_paths = get_dsb_paths("./data", source="reduced", split="train", download=True) + +# Use 10% of the data for validation. +# val_len = int(0.1 * len(image_paths)) +# train_images, val_images = image_paths[:-val_len], image_paths[-val_len:] +# train_labels, val_labels = label_paths[:-val_len], label_paths[-val_len:] +train_images, val_images = image_paths[:10], image_paths[-5:] +train_labels, val_labels = label_paths[:10], label_paths[-5:] + +# Run the training. This will train a UNETR with instance segmentation decoder. +# This is equivalent to training the additional instance segmentation decoder +# in the full training logic, WITHOUT training the model for interactive segmentation. + +# Adjust the patch shape to match your images! +patch_shape = (256, 256) + +# First, we create the training and validation loaders. +train_loader = default_sam_loader( + raw_paths=train_images, label_paths=train_labels, + raw_key=None, label_key=None, batch_size=1, + patch_shape=patch_shape, with_segmentation_decoder=True, + train_instance_segmentation_only=True, is_train=True, +) +val_loader = default_sam_loader( + raw_paths=val_images, label_paths=val_labels, + raw_key=None, label_key=None, batch_size=1, + patch_shape=patch_shape, with_segmentation_decoder=True, + train_instance_segmentation_only=True, is_train=False, +) + +# Choose the model type to start the training from. +# We recommend 'vit_b_lm' for any light microscopy images. +model_type = "vit_t_lm" + +# This is the name for the checkpoint that will be trained. +name = "ais-dsb" + +# Then run the training. Check out the docstring of the function for more training options. +train_instance_segmentation( + name=name, + model_type=model_type, + train_loader=train_loader, + val_loader=val_loader +) + +# Finally, we export the trained model to a new format that is compatible with micro_sam: +# This exported model can be used by micro_sam functions, the CLI or in the napari plugin. +# However, it may not work well for interactive segmentation, since it may suffer from 'catastrophic forgetting' +# for this task, because its image encoder was updated without training for interactive segmentation. 
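# Once exported (below), the model can be used like any other micro_sam model for automatic
# instance segmentation. A sketch of such a downstream call, with a hypothetical image "cells.tif":
#
#     from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
#     predictor, segmenter = get_predictor_and_segmenter(model_type, "./ais-dsb-model.pt", amg=False)
#     segmentation = automatic_instance_segmentation(predictor, segmenter, "cells.tif", ndim=2)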
+checkpoint_path = os.path.join("checkpoints", name, "best.pt") +export_path = "./ais-dsb-model.pt" +export_instance_segmentation_model(checkpoint_path, export_path, model_type) diff --git a/test/test_training.py b/test/test_training.py index 3a5f007df..3347f5831 100644 --- a/test/test_training.py +++ b/test/test_training.py @@ -10,8 +10,7 @@ from micro_sam.util import VIT_T_SUPPORT, get_sam_model, SamPredictor -# FIXME this now hangs on github not sure why -@unittest.skip("Test hangs on CI") +@unittest.skip("Not working in CI") @unittest.skipUnless(VIT_T_SUPPORT, "Integration test is only run with vit_t support, otherwise it takes too long.") class TestTraining(unittest.TestCase): """Integration test for training a SAM model. @@ -55,7 +54,7 @@ def tearDown(self): except OSError: pass - def _get_dataloader(self, split, patch_shape, batch_size): + def _get_dataloader(self, split, patch_shape, batch_size, train_instance_segmentation_only=False): import micro_sam.training as sam_training # Create the synthetic training data and get the corresponding folders. @@ -67,9 +66,10 @@ def _get_dataloader(self, split, patch_shape, batch_size): raw_paths=image_root, raw_key=raw_key, label_paths=label_root, label_key=label_key, patch_shape=patch_shape, batch_size=batch_size, - with_segmentation_decoder=False, + with_segmentation_decoder=train_instance_segmentation_only, shuffle=True, num_workers=1, - n_samples=self.n_images_train if split == "train" else self.n_images_val + n_samples=self.n_images_train if split == "train" else self.n_images_val, + train_instance_segmentation_only=train_instance_segmentation_only, ) return loader @@ -95,8 +95,9 @@ def _train_model(self, model_type, device): n_objects_per_batch=n_objects_per_batch, n_sub_iteration=n_sub_iteration, with_segmentation_decoder=False, + freeze=["image_encoder"], device=device, - save_root=self.tmp_folder + save_root=self.tmp_folder, ) def _export_model(self, checkpoint_path, export_path, model_type): @@ -134,8 +135,7 @@ def _run_inference_and_check_results( def test_training(self): import micro_sam.evaluation as evaluation - model_type = "vit_t" - device = "cpu" + model_type, device = "vit_t", "cpu" # Fine-tune the model. self._train_model(model_type=model_type, device=device) @@ -163,6 +163,45 @@ def test_training(self): inference_function=iterative_inference, expected_sa=0.8, ) + def test_train_instance_segmentation(self): + from micro_sam.training.training import train_instance_segmentation, export_instance_segmentation_model + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter + + model_type, device = "vit_t", "cpu" + batch_size, patch_shape = 1, (512, 512) + + # Get the dataloaders. + train_loader = self._get_dataloader("train", patch_shape, batch_size, train_instance_segmentation_only=True) + val_loader = self._get_dataloader("val", patch_shape, batch_size, train_instance_segmentation_only=True) + + # Run the training. + # We freeze the image encoder to speed up the training process. 
+ name = "test-instance-seg-only" + train_instance_segmentation( + name=name, + model_type=model_type, + train_loader=train_loader, + val_loader=val_loader, + n_epochs=1, + device=device, + save_root=self.tmp_folder, + freeze=["image_encoder"], + ) + + checkpoint_path = os.path.join(self.tmp_folder, "checkpoints", name, "best.pt") + self.assertTrue(os.path.exists(checkpoint_path)) + + export_path = os.path.join(self.tmp_folder, "instance_segmentation_model.pt") + export_instance_segmentation_model(checkpoint_path, export_path, model_type) + self.assertTrue(os.path.exists(export_path)) + + # Check that this model works for AIS. + predictor, segmenter = get_predictor_and_segmenter(model_type, export_path, amg=False) + image_path = os.path.join(self.tmp_folder, "synthetic-data", "images", "test", "data-0.tif") + segmentation = automatic_instance_segmentation(predictor, segmenter, image_path) + expected_shape = imageio.imread(image_path).shape + self.assertEqual(segmentation.shape, expected_shape) + if __name__ == "__main__": unittest.main() From b82c4eba35cd1bf8ba641401f7cdbc4f20626859 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Mon, 21 Apr 2025 21:04:20 +0200 Subject: [PATCH 5/7] Ensure valid layers exist for each annotator click menu (#943) Ensure valid layers exist for each annotator click menu --------- Co-authored-by: Constantin Pape --- micro_sam/sam_annotator/_annotator.py | 38 ++- micro_sam/sam_annotator/_state.py | 4 + micro_sam/sam_annotator/_widgets.py | 47 +++- micro_sam/sam_annotator/annotator_2d.py | 13 +- micro_sam/sam_annotator/annotator_3d.py | 13 +- micro_sam/sam_annotator/annotator_tracking.py | 228 +++++++++++++----- test/{test_sam_annotator => }/test_cli.py | 0 7 files changed, 257 insertions(+), 86 deletions(-) rename test/{test_sam_annotator => }/test_cli.py (100%) diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index b4c042bf1..6300e7dd7 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -1,6 +1,8 @@ -import napari +from typing import Optional, List + import numpy as np +import napari from qtpy import QtWidgets from magicgui.widgets import Widget, Container, FunctionGui @@ -16,19 +18,39 @@ class _AnnotatorBase(QtWidgets.QScrollArea): The annotators differ in their data dimensionality and the widgets. """ - def _create_layers(self): + def _require_layers(self, layer_choices: Optional[List[str]] = None): + + # Check whether the image is initialized already. And use the image shape and scale for the layers. + state = AnnotatorState() + shape = self._shape if state.image_shape is None else state.image_shape + # Add the label layers for the current object, the automatic segmentation and the committed segmentation. - dummy_data = np.zeros(self._shape, dtype="uint32") + dummy_data = np.zeros(shape, dtype="uint32") + image_scale = state.image_scale # Before adding new layers, we always check whether a layer with this name already exists or not. if "current_object" not in self._viewer.layers: + if layer_choices and "current_object" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("current_object") self._viewer.add_labels(data=dummy_data, name="current_object") + if image_scale is not None: + self.layers["current_objects"].scale = image_scale + if "auto_segmentation" not in self._viewer.layers: + if layer_choices and "auto_segmentation" in layer_choices: # Check at 'commit' call button. 
+ widgets._validation_window_for_missing_layer("auto_segmentation") self._viewer.add_labels(data=dummy_data, name="auto_segmentation") + if image_scale is not None: + self.layers["auto_segmentation"].scale = image_scale + if "committed_objects" not in self._viewer.layers: + if layer_choices and "committed_objects" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("committed_objects") self._viewer.add_labels(data=dummy_data, name="committed_objects") # Randomize colors so it is easy to see when object committed. self._viewer.layers["committed_objects"].new_colormap() + if image_scale is not None: + self.layers["committed_objects"].scale = image_scale # Add the point layer for point prompts. self._point_labels = ["positive", "negative"] @@ -70,7 +92,7 @@ def _create_widgets(self): # Create the prompt widget. (The same for all plugins.) self._prompt_widget = widgets.create_prompt_menu(self._point_prompt_layer, self._point_labels) - # Create the dictionray for the widgets and get the widgets of the child plugin. + # Create the dictionary for the widgets and get the widgets of the child plugin. self._widgets = {"embeddings": self._embedding_widget, "prompts": self._prompt_widget} self._widgets.update(self._get_widgets()) @@ -131,7 +153,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None: # Initialize with a dummy shape, which is reset to the correct shape once an image is set. self._ndim = ndim self._shape = (256, 256) if ndim == 2 else (16, 256, 256) - self._create_layers() + self._require_layers() # Create all the widgets and add them to the layout. self._create_widgets() @@ -179,6 +201,9 @@ def _update_image(self, segmentation_result=None): ) self._shape = state.image_shape + # Before we reset the layers, we ensure all expected layers exist. + self._require_layers() + # Update the image scale. scale = state.image_scale @@ -187,12 +212,15 @@ def _update_image(self, segmentation_result=None): self._viewer.layers["current_object"].scale = scale self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32") self._viewer.layers["auto_segmentation"].scale = scale + if segmentation_result is None or segmentation_result is False: self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32") else: assert segmentation_result.shape == self._shape self._viewer.layers["committed_objects"].data = segmentation_result self._viewer.layers["committed_objects"].scale = scale + self._viewer.layers["point_prompts"].scale = scale self._viewer.layers["prompts"].scale = scale + vutil.clear_annotations(self._viewer, clear_segmentations=False) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 09f5e9870..13ef8abca 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -14,6 +14,7 @@ import torch.nn as nn +import micro_sam import micro_sam.util as util from micro_sam.instance_segmentation import AMGBase, get_decoder from micro_sam.precompute_state import cache_amg_state, cache_is_state @@ -69,6 +70,9 @@ class AnnotatorState(metaclass=Singleton): # z-range to limit the data being committed in 3d / tracking. 
z_range: Optional[Tuple[int, int]] = None + # annotator_class + annotator: Optional["micro_sam.sam_annotator._annotator._AnnotatorBase"] = None + def initialize_predictor( self, image_data, diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index a840d49bb..c8a87dca2 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -30,9 +30,9 @@ # from napari.qt.threading import thread_worker from napari.utils import progress -from ._state import AnnotatorState from . import util as vutil from ._tooltips import get_tooltip +from ._state import AnnotatorState from .. import instance_segmentation, util from ..multi_dimensional_segmentation import ( segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data @@ -496,8 +496,12 @@ def _mask_matched_objects(seg, prev_seg, preservation_threshold): def _commit_impl(viewer, layer, preserve_mode, preservation_threshold): - # Check if we have a z_range. If yes, use it to set a bounding box. state = AnnotatorState() + + # Check whether all layers exist as expected or create new ones automatically. + state.annotator._require_layers(layer_choices=[layer, "committed_objects"]) + + # Check if we have a z_range. If yes, use it to set a bounding box. if state.z_range is None: bb = np.s_[:] else: @@ -750,6 +754,7 @@ def commit( commit_path: Select a file path where the committed results and prompts will be saved. This feature is still experimental. """ + # Commit the segmentation layer. _, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold) if commit_path is not None: @@ -964,12 +969,27 @@ def _validate_embeddings(viewer: "napari.viewer.Viewer"): # return False -def _validate_prompts(viewer: "napari.viewer.Viewer") -> bool: - if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0: - msg = "No prompts were given. Please provide prompts to run interactive segmentation." - return _generate_message("error", msg) +def _validation_window_for_missing_layer(layer_choice): + if layer_choice == "committed_objects": + msg = "The 'committed_objects' layer to commit masks is missing. Please try to commit again." else: - return False + msg = f"The '{layer_choice}' layer to commit is missing. Please re-annotate and try again." + + return _generate_message(message_type="error", message=msg) + + +def _validate_layers(viewer: "napari.viewer.Viewer", automatic_segmentation: bool = False) -> bool: + # Check whether all layers exist as expected or create new ones automatically. + state = AnnotatorState() + state.annotator._require_layers() + + if not automatic_segmentation: + # Check prompts layer. + if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0: + msg = "No prompts were given. Please provide prompts to run interactive segmentation." 
+ return _generate_message("error", msg) + else: + return False @magic_factory(call_button="Segment Object [S]") @@ -982,7 +1002,7 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None: """ if _validate_embeddings(viewer): return None - if _validate_prompts(viewer): + if _validate_layers(viewer): return None shape = viewer.layers["current_object"].data.shape @@ -1016,7 +1036,7 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None: """ if _validate_embeddings(viewer): return None - if _validate_prompts(viewer): + if _validate_layers(viewer): return None shape = viewer.layers["current_object"].data.shape[1:] @@ -1057,8 +1077,9 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None: """ if _validate_embeddings(viewer): return None - if _validate_prompts(viewer): + if _validate_layers(viewer): return None + state = AnnotatorState() shape = state.image_shape[1:] position = viewer.dims.point @@ -1626,8 +1647,9 @@ def update_segmentation(seg): def __call__(self): if _validate_embeddings(self._viewer): return None - if _validate_prompts(self._viewer): + if _validate_layers(self._viewer): return None + if self.tracking: return self._run_tracking() else: @@ -1889,6 +1911,9 @@ def update_segmentation(seg): self._viewer.layers["auto_segmentation"].data[i] = seg self._viewer.layers["auto_segmentation"].refresh() + # Validate all layers. + _validate_layers(self._viewer, automatic_segmentation=True) + seg = seg_impl() update_segmentation(seg) # worker = seg_impl() diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 5afbc63ef..d3f4b3e55 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -24,9 +24,18 @@ def _get_widgets(self): "clear": widgets.clear(), } - def __init__(self, viewer: "napari.viewer.Viewer") -> None: + def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None: super().__init__(viewer=viewer, ndim=2) + # Set the expected annotator class to the state. + state = AnnotatorState() + + # Reset the state. + if reset_state: + state.reset_state() + + state.annotator = self + def annotator_2d( image: np.ndarray, @@ -85,7 +94,7 @@ def annotator_2d( viewer = napari.Viewer() viewer.add_image(image, name="image") - annotator = Annotator2d(viewer) + annotator = Annotator2d(viewer, reset_state=False) # Trigger layer update of the annotator so that layers have the correct shape. # And initialize the 'committed_objects' with the segmentation result if it was given. diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index 773653b8c..6a8ea79b5 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -24,10 +24,19 @@ def _get_widgets(self): "clear": widgets.clear_volume(), } - def __init__(self, viewer: "napari.viewer.Viewer") -> None: + def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None: self._with_decoder = AnnotatorState().decoder is not None super().__init__(viewer=viewer, ndim=3) + # Set the expected annotator class to the state. + state = AnnotatorState() + + # Reset the state. + if reset_state: + state.reset_state() + + state.annotator = self + def _update_image(self, segmentation_result=None): super()._update_image(segmentation_result=segmentation_result) # Load the amg state from the embedding path. 
@@ -95,7 +104,7 @@ def annotator_3d( viewer = napari.Viewer() viewer.add_image(image, name="image") - annotator = Annotator3d(viewer) + annotator = Annotator3d(viewer, reset_state=False) # Trigger layer update of the annotator so that layers have the correct shape. # And initialize the 'committed_objects' with the segmentation result if it was given. diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index b9bc82252..6388e676d 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import napari import numpy as np @@ -21,26 +21,40 @@ # This solution is a bit hacky, so I won't move it to _widgets.py yet. -def create_tracking_menu(points_layer, box_layer, states, track_ids): +def create_tracking_menu(points_layer, box_layer, states, track_ids, tracking_widget=None): """@private""" state = AnnotatorState() - state_menu = ComboBox(label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state")) - track_id_menu = ComboBox( - label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id") - ) - tracking_widget = Container(widgets=[state_menu, track_id_menu]) + def _get_widget_menu(container, label): + for w in container: + if isinstance(w, ComboBox) and w.label == label: + return w + raise ValueError(f"ComboBox with label '{label}' not found.") + + if tracking_widget is None: + state_menu = ComboBox( + label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state") + ) + track_id_menu = ComboBox( + label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id") + ) + tracking_widget = Container(widgets=[state_menu, track_id_menu]) + else: + state_menu = _get_widget_menu(tracking_widget, "track_state") + track_id_menu = _get_widget_menu(tracking_widget, "track_id") def update_state(event): - new_state = str(points_layer.current_properties["state"][0]) - if new_state != state_menu.value: - state_menu.value = new_state + if "state" in points_layer.current_properties: + new_state = str(points_layer.current_properties["state"][0]) + if new_state != state_menu.value: + state_menu.value = new_state def update_track_id(event): - new_id = str(points_layer.current_properties["track_id"][0]) - if new_id != track_id_menu.value: - track_id_menu.value = new_id - state.current_track_id = int(new_id) + if "track_id" in points_layer.current_properties: + new_id = str(points_layer.current_properties["track_id"][0]) + if new_id != track_id_menu.value: + track_id_menu.value = new_id + state.current_track_id = int(new_id) # def update_state_boxes(event): # new_state = str(box_layer.current_properties["state"][0]) @@ -48,10 +62,11 @@ def update_track_id(event): # state_menu.value = new_state def update_track_id_boxes(event): - new_id = str(box_layer.current_properties["track_id"][0]) - if new_id != track_id_menu.value: - track_id_menu.value = new_id - state.current_track_id = int(new_id) + if "track_id" in box_layer.current_properties: + new_id = str(box_layer.current_properties["track_id"][0]) + if new_id != track_id_menu.value: + track_id_menu.value = new_id + state.current_track_id = int(new_id) points_layer.events.current_properties.connect(update_state) points_layer.events.current_properties.connect(update_track_id) @@ -101,59 +116,131 @@ class 
AnnotatorTracking(_AnnotatorBase): # The tracking annotator needs different settings for the prompt layers # to support the additional tracking state. # That's why we over-ride this function. - def _create_layers(self): - self._point_labels = ["positive", "negative"] - self._track_state_labels = ["track", "division"] + def _require_layers(self, layer_choices: Optional[List[str]] = None): - self._point_prompt_layer = self._viewer.add_points( - name="point_prompts", - property_choices={ - "label": self._point_labels, - "state": self._track_state_labels, - "track_id": ["1"], # we use string to avoid pandas warning - }, - border_color="label", - border_color_cycle=vutil.LABEL_COLOR_CYCLE, - symbol="o", - face_color="state", - face_color_cycle=STATE_COLOR_CYCLE, - border_width=0.4, - size=12, - ndim=self._ndim, - ) - self._point_prompt_layer.border_color_mode = "cycle" - self._point_prompt_layer.face_color_mode = "cycle" - - # Using the box layer to set divisions currently doesn't work. - # That's why some of the code below is commented out. - self._box_prompt_layer = self._viewer.add_shapes( - shape_type="rectangle", - edge_width=4, - ndim=self._ndim, - face_color="transparent", - name="prompts", - edge_color="green", - property_choices={"track_id": ["1"]}, - # property_choces={"track_id": ["1"], "state": self._track_state_labels}, - # edge_color_cycle=STATE_COLOR_CYCLE, - ) - # self._box_prompt_layer.edge_color_mode = "cycle" + # Check whether the image is initialized already. And use the image shape and scale for the layers. + state = AnnotatorState() + shape = self._shape if state.image_shape is None else state.image_shape # Add the label layers for the current object, the automatic segmentation and the committed segmentation. - dummy_data = np.zeros(self._shape, dtype="uint32") - self._viewer.add_labels(data=dummy_data, name="current_object") - self._viewer.add_labels(data=dummy_data, name="auto_segmentation") - self._viewer.add_labels(data=dummy_data, name="committed_objects") - # Randomize colors so it is easy to see when object committed. - self._viewer.layers["committed_objects"].new_colormap() + dummy_data = np.zeros(shape, dtype="uint32") + image_scale = state.image_scale + + # Before adding new layers, we always check whether a layer with this name already exists or not. + if "current_object" not in self._viewer.layers: + if layer_choices and "current_object" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("current_object") + self._viewer.add_labels(data=dummy_data, name="current_object") + if image_scale is not None: + self.layers["current_objects"].scale = image_scale + + if "auto_segmentation" not in self._viewer.layers: + if layer_choices and "auto_segmentation" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("auto_segmentation") + self._viewer.add_labels(data=dummy_data, name="auto_segmentation") + if image_scale is not None: + self.layers["auto_segmentation"].scale = image_scale + + if "committed_objects" not in self._viewer.layers: + if layer_choices and "committed_objects" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("committed_objects") + self._viewer.add_labels(data=dummy_data, name="committed_objects") + # Randomize colors so it is easy to see when object committed. 
+ self._viewer.layers["committed_objects"].new_colormap() + if image_scale is not None: + self.layers["committed_objects"].scale = image_scale + + # Add the point prompts layer. + self._point_labels = ["positive", "negative"] + self._track_state_labels = ["track", "division"] + _point_prompt_property_choices = { + "label": self._point_labels, + "state": self._track_state_labels, + "track_id": ["1"], # we use string to avoid pandas warning + } + + point_layer_mismatch = True + if "point_prompts" in self._viewer.layers: + # Check whether the 'property_choices' match or not. + curr_property_choices = self._viewer.layers["point_prompts"].property_choices + point_layer_mismatch = set(curr_property_choices.keys()) != set(_point_prompt_property_choices.keys()) + + if point_layer_mismatch and "point_prompts" not in self._viewer.layers: + self._point_prompt_layer = self._viewer.add_points( + name="point_prompts", + property_choices=_point_prompt_property_choices, + border_color="label", + border_color_cycle=vutil.LABEL_COLOR_CYCLE, + symbol="o", + face_color="state", + face_color_cycle=STATE_COLOR_CYCLE, + border_width=0.4, + size=12, + ndim=self._ndim, + ) + self._point_prompt_layer.border_color_mode = "cycle" + self._point_prompt_layer.face_color_mode = "cycle" + _new_point_layer = True + else: + self._point_prompt_layer = self._viewer.layers["point_prompts"] + _new_point_layer = False + + # Add the point prompts layer. + _box_prompt_property_choices = {"track_id": ["1"]} + + box_layer_mismatch = True + if "prompts" in self._viewer.layers: + # Check whether the 'property_choices' match or not. + curr_property_choices = self._viewer.layers["prompts"].property_choices + box_layer_mismatch = set(curr_property_choices.keys()) != set(_box_prompt_property_choices.keys()) + + if box_layer_mismatch and "prompts" not in self._viewer.layers: + # Using the box layer to set divisions currently doesn't work. + # That's why some of the code below is commented out. + self._box_prompt_layer = self._viewer.add_shapes( + shape_type="rectangle", + edge_width=4, + ndim=self._ndim, + face_color="transparent", + name="prompts", + edge_color="green", + property_choices=_box_prompt_property_choices, + # property_choices={"track_id": ["1"], "state": self._track_state_labels}, + # edge_color_cycle=STATE_COLOR_CYCLE, + ) + # self._box_prompt_layer.edge_color_mode = "cycle" + _new_box_layer = True + else: + self._box_prompt_layer = self._viewer.layers["prompts"] + _new_box_layer = False + + # Trigger a new connection for the tracking state menu only when a new layer is (re)created. + if _new_point_layer or _new_box_layer: + self._tracking_widget = create_tracking_menu( + points_layer=self._point_prompt_layer, + box_layer=self._box_prompt_layer, + states=self._track_state_labels, + track_ids=list(state.lineage.keys()), + tracking_widget=state.widgets.get("tracking"), + ) + state.widgets["tracking"] = self._tracking_widget def _get_widgets(self): state = AnnotatorState() + self._require_layers() + # Create the tracking state menu. - self._tracking_widget = create_tracking_menu( - self._point_prompt_layer, self._box_prompt_layer, - states=self._track_state_labels, track_ids=list(state.lineage.keys()), - ) + # NOTE: Check whether it exists already from `_require_layers` or needs to be created. 
+ if state.widgets.get("tracking") is None: + self._tracking_widget = create_tracking_menu( + ponts_layer=self._point_prompt_layer, + box_layer=self._box_prompt_layer, + states=self._track_state_labels, + track_ids=list(state.lineage.keys()), + ) + else: + self._tracking_widget = state.widgets.get("tracking") + segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True) autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True) return { @@ -165,7 +252,7 @@ def _get_widgets(self): "clear": widgets.clear_track(), } - def __init__(self, viewer: "napari.viewer.Viewer") -> None: + def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None: # Initialize the state for tracking. self._init_track_state() self._with_decoder = AnnotatorState().decoder is not None @@ -173,6 +260,15 @@ def __init__(self, viewer: "napari.viewer.Viewer") -> None: # Go to t=0. self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:]) + # Set the expected annotator class to the state. + state = AnnotatorState() + + # Reset the state. + if reset_state: + state.reset_state() + + state.annotator = self + def _init_track_state(self): state = AnnotatorState() state.current_track_id = 1 @@ -239,7 +335,7 @@ def annotator_tracking( viewer = napari.Viewer() viewer.add_image(image, name="image") - annotator = AnnotatorTracking(viewer) + annotator = AnnotatorTracking(viewer, reset_state=False) # Trigger layer update of the annotator so that layers have the correct shape. annotator._update_image() diff --git a/test/test_sam_annotator/test_cli.py b/test/test_cli.py similarity index 100% rename from test/test_sam_annotator/test_cli.py rename to test/test_cli.py From 060c0c1197079df0020ab9792a64add4cddc3321 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 22 Apr 2025 15:45:23 +0200 Subject: [PATCH 6/7] Implement CLI for continuing annotation from commit file --- micro_sam/sam_annotator/_widgets.py | 5 +- micro_sam/sam_annotator/reproducibility.py | 96 +++++++++++++----- setup.cfg | 1 + .../commit-for-test.zarr/.zattrs | 1 + .../commit-for-test.zarr/.zgroup | 3 + .../committed_objects/.zarray | 20 ++++ .../committed_objects/0.0 | Bin 0 -> 5404 bytes .../commit-for-test.zarr/prompts/.zgroup | 3 + .../commit-for-test.zarr/prompts/10/.zgroup | 3 + .../prompts/10/point_labels/.zarray | 21 ++++ .../prompts/10/point_labels/0 | Bin 0 -> 40 bytes .../prompts/10/point_prompts/.zarray | 23 +++++ .../prompts/10/point_prompts/0.0 | Bin 0 -> 64 bytes .../commit-for-test.zarr/prompts/4/.zgroup | 3 + .../prompts/4/point_labels/.zarray | 21 ++++ .../prompts/4/point_labels/0 | Bin 0 -> 24 bytes .../prompts/4/point_prompts/.zarray | 23 +++++ .../prompts/4/point_prompts/0.0 | Bin 0 -> 32 bytes .../commit-for-test.zarr/prompts/5/.zgroup | 3 + .../prompts/5/point_labels/.zarray | 21 ++++ .../prompts/5/point_labels/0 | Bin 0 -> 24 bytes .../prompts/5/point_prompts/.zarray | 23 +++++ .../prompts/5/point_prompts/0.0 | Bin 0 -> 32 bytes .../commit-for-test.zarr/prompts/6/.zgroup | 3 + .../prompts/6/prompts/.zarray | 25 +++++ .../prompts/6/prompts/0.0.0 | Bin 0 -> 80 bytes .../commit-for-test.zarr/prompts/7/.zgroup | 3 + .../prompts/7/prompts/.zarray | 25 +++++ .../prompts/7/prompts/0.0.0 | Bin 0 -> 80 bytes .../commit-for-test.zarr/prompts/8/.zgroup | 3 + .../prompts/8/prompts/.zarray | 25 +++++ .../prompts/8/prompts/0.0.0 | Bin 0 -> 80 bytes .../commit-for-test.zarr/prompts/9/.zgroup | 3 + .../prompts/9/point_labels/.zarray | 21 ++++ 
.../prompts/9/point_labels/0 | Bin 0 -> 32 bytes .../prompts/9/point_prompts/.zarray | 23 +++++ .../prompts/9/point_prompts/0.0 | Bin 0 -> 48 bytes .../prompts/9/prompts/.zarray | 25 +++++ .../prompts/9/prompts/0.0.0 | Bin 0 -> 80 bytes .../test_reproducibility.py | 21 ++-- 40 files changed, 410 insertions(+), 37 deletions(-) create mode 100644 test/test_sam_annotator/commit-for-test.zarr/.zattrs create mode 100644 test/test_sam_annotator/commit-for-test.zarr/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/committed_objects/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/committed_objects/0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/10/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/10/point_labels/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/10/point_labels/0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/10/point_prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/10/point_prompts/0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/4/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/4/point_labels/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/4/point_labels/0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/4/point_prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/4/point_prompts/0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/5/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/5/point_labels/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/5/point_labels/0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/5/point_prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/5/point_prompts/0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/6/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/6/prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/6/prompts/0.0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/7/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/7/prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/7/prompts/0.0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/8/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/8/prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/8/prompts/0.0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/9/.zgroup create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/9/point_labels/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/9/point_labels/0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/9/point_prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/9/point_prompts/0.0 create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/9/prompts/.zarray create mode 100644 test/test_sam_annotator/commit-for-test.zarr/prompts/9/prompts/0.0.0 diff --git 
a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 897d0def1..7736e3e1a 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -623,10 +623,9 @@ def _save_signature(f, data_signature): ) for key, val in signature.items(): f.attrs[key] = val + # Add the annotator type to the signature. - # TODO need to merge the latest dev / master for this. - f.attrs["annotator_class"] = "Annotator2d" - # f.attrs["annotator_class"] = state. + f.attrs["annotator_class"] = state.annotator.__class__.__name__ # If the data signature is saved in the file already, # then we check if saved data signature and data signature of our image agree. diff --git a/micro_sam/sam_annotator/reproducibility.py b/micro_sam/sam_annotator/reproducibility.py index 9085400b7..21615bc1e 100644 --- a/micro_sam/sam_annotator/reproducibility.py +++ b/micro_sam/sam_annotator/reproducibility.py @@ -11,6 +11,9 @@ from ..instance_segmentation import mask_data_to_segmentation from ._widgets import _mask_matched_objects +from .annotator_2d import annotator_2d +from .annotator_3d import annotator_3d +# from .annotator_tracking import annotator_tracking from .util import prompt_segmentation @@ -87,22 +90,27 @@ def _rerun_interactive_segmentation(segmentation, f, predictor, image_embeddings labels.append(prompt_group["point_labels"][:]) if "prompts" in prompt_group: boxes.append(prompt_group["prompts"][:]) - if "mask" in prompt_group: - masks.append(prompt_group["mask"][:]) + # We can only have a mask if we also have a box prompt. + if "mask" in prompt_group: + masks.append(prompt_group["mask"][:]) + else: + masks.append(None) - # TODO - if not masks: - masks = None if points: - points, labels = np.array(points), np.array(labels) - # else: - # points, labels = None, None + points = np.concatenate(points, axis=0) + labels = np.concatenate(labels, axis=0) + if boxes: + # Map boxes to the correct input format. + boxes = np.concatenate(boxes, axis=0) + boxes = [ + np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes + ] batched = len(object_ids) > 1 seg = prompt_segmentation( predictor, points, labels, boxes, masks, segmentation.shape, image_embeddings=image_embeddings, multiple_box_prompts=True, batched=batched, previous_segmentation=segmentation, - ) + ).astype("uint32") # TODO implement batched segmentation for these cases. elif annotator_class == "AnnotatorTracking": @@ -147,13 +155,16 @@ def rerun_segmentation_from_commit_file( input_key: Optional[str] = None, embedding_path: Optional[Union[str, os.PathLike]] = None, ) -> np.ndarray: - """ + """Rerun a segmentation from the commit history of a commit file. Args: - commit_file: - input_path: - input_key: - embedding_path: + commit_file: The path to the zarr file storing the commit history. + input_path: The path to the image data for the respective micro_sam commit history. + input_key: The key for the image data, in case it is a zarr, n5, hdf5 file or similar. + embedding_path: The path to precomputed embeddings for this project. + + Returns: + The segmentation recreated from the commit history. """ # Load the image data and open the zarr commit file. input_data = util.load_image_data(input_path, key=input_key) @@ -212,38 +223,71 @@ def rerun_segmentation_from_commit_file( def load_committed_objects_from_commit_file(commit_file: Union[str, os.PathLike]) -> np.ndarray: """ Args: - commit_file + commit_file: The path to the zarr file storing the commit history. 
Returns: - AAA + The committed segmentation. """ with open_file(commit_file, mode="r") as f: return f["committed_objects"][:] -# TODO -def continue_annotation_from_commit_file( +def continue_annotation( commit_file: Union[str, os.PathLike], input_path: Union[str, os.PathLike], input_key: Optional[str] = None, embedding_path: Optional[Union[str, os.PathLike]] = None, ) -> None: - """ + """Start an annotator from a commit file and set it to the commited state. + + This currently does not support files committed with annotator_tracking. + Args: - commit_file: - input_path: - input_key: - embedding_path: + commit_file: The path to the zarr file storing the commit history. + input_path: The path to the image data for the respective micro_sam commit history. + input_key: The key for the image data, in case it is a zarr, n5, hdf5 file or similar. + embedding_path: The path to precomputed embeddings for this project. """ committed_objects = load_committed_objects_from_commit_file(commit_file) with open_file(commit_file, mode="r") as f: if "annotator_class" not in f.attrs: raise RuntimeError( - f"You have saved the {commit_file} in a version that does not yet support rerunning the segmentation." + f"You have saved {commit_file} in a version that does not support continuing the annotation." ) annotator_class = f.attrs["annotator_class"] + model_type = f.attrs["model_name"] + tile_shape = f.attrs["tile_shape"] + halo = f.attrs["halo"] + + input_data = util.load_image_data(input_path, key=input_key) + if annotator_class == "Annotator2d": + annotator_2d( + input_data, embedding_path=embedding_path, segmentation_result=committed_objects, + model_type=model_type, tile_shape=tile_shape, halo=halo, + ) + elif annotator_class == "Annotator3d": + annotator_3d( + input_data, embedding_path=embedding_path, segmentation_result=committed_objects, + model_type=model_type, tile_shape=tile_shape, halo=halo, + ) + # We need to implement initialization of the tracking annotator with a segmentation + tracking state for this. + elif annotator_class == "AnnotatorTracking": + raise NotImplementedError("'continue_annotation_from_commit_file' is not yet supported for AnnotatorTracking.") + else: + raise RuntimeError(f"Unsupported annotator class {annotator_class}.") -# TODO CLI for 'continue_annotation_from_commit_file' def main(): - pass + import argparse + + parser = argparse.ArgumentParser( + description="Start an annotator from a commit file and set it to the commited state." + ) + parser.add_argument("-c", "--commit_file", required=True, help="The zarr file with the commit history.") + parser.add_argument("-i", "--input_path", required=True, help="The file path of the image data.") + parser.add_argument( + "-k", "--input_key", help="The input key for the image data. Required for zarr, n5 or hdf5 files." 
+ ) + parser.add_argument("-e", "--embedding_path", help="Optional file path for precomputed embeddings.") + args = parser.parse_args() + continue_annotation(args.commit_file, args.input_path, args.input_key, args.embedding_path) diff --git a/setup.cfg b/setup.cfg index 24bbfec82..f8eb4300a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,6 +53,7 @@ console_scripts = micro_sam.evaluate = micro_sam.evaluation.evaluation:main micro_sam.info = micro_sam.util:micro_sam_info micro_sam.benchmark_sam = micro_sam.evaluation.benchmark_datasets:main + micro_sam.continue_annotation = micro_sam.sam_annotator.reproducibility:main # make sure it gets included in your package diff --git a/test/test_sam_annotator/commit-for-test.zarr/.zattrs b/test/test_sam_annotator/commit-for-test.zarr/.zattrs new file mode 100644 index 000000000..8428e7a4f --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/.zattrs @@ -0,0 +1 @@ +{"annotator_class":"Annotator2d","commit_history":[{"auto_segmentation":{"boundary_distance_threshold":0.5,"center_distance_threshold":0.5,"min_object_size":100,"object_ids":[1,2,3],"preservation_threshold":0.75,"preserve_mode":"objects","with_background":true}},{"current_object":{"object_ids":[4,5],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[6,7],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[8],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[9],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[10],"preservation_threshold":0.75,"preserve_mode":"objects"}}],"data_signature":"11200d9b5539224a26ad0aae8dd974b0b97ca075","halo":null,"micro_sam_version":"1.4.0","model_hash":"xxh128:72ec5074774761a6e5c05a08942f981e","model_name":"vit_t_lm","model_type":"vit_t","tile_shape":null} diff --git a/test/test_sam_annotator/commit-for-test.zarr/.zgroup b/test/test_sam_annotator/commit-for-test.zarr/.zgroup new file mode 100644 index 000000000..3f3fad2d1 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} diff --git a/test/test_sam_annotator/commit-for-test.zarr/committed_objects/.zarray b/test/test_sam_annotator/commit-for-test.zarr/committed_objects/.zarray new file mode 100644 index 000000000..15d56101d --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/committed_objects/.zarray @@ -0,0 +1,20 @@ +{ + "chunks": [ + 512, + 512 + ], + "compressor": { + "id": "zlib", + "level": 5 + }, + "dimension_separator": ".", + "dtype": "W>}Ulp|FOZ?|Z zQ|rsQzdczBKWmfwGtq}IK!cC4a%lx<22is>;{mxRXF~H*4;qZ z^*Mt)Z))Ngt=pT}+V^L0?_Fy0CtCM)BCKE?_}=uz1(16R+0$ePj+F$oeFVA%#IW%# zFtDHey~zelD`vp|$M5p*A^Em|WL9KxLH8}fXj}(c6Uh-z8fd{b$k&6X2g$x7o7{(( zeElwPF@n5%$tJf5+i;ly)p%?R5~(?nDD#YA{Y{Y<1ysAITx^A77(^*2PBvXh(yM01HW#j&P~4tcvwVKB}i(*JsMYo zW3qsK{x-RC;%dWw&>Bvzd9WNP*%h_BUj+^2fY!g|RH+)ppp{EH&TRzii;VkV0(cBV z7`b+C!@tMjSbkJBt$Q2M*U%3p8u0g~{lJU8q|=-e$nq5CFT4TzzBh(Rm;L+A{V?s| zBCUH3vGjNhOn4$+TR#}X_jGInH{;38Z4MxwDM7`wP!3Wil7i!B@yLzgxRpXw#jVVh5-Oc^4N9%|2S4B^7Zg>P@UjO`vO4R7Q7tU?UwwoV*R}4# z%VYN4V&o0Mbr~+$?Fckj#N{hZn3gz0YmFp-D(Z-bZ?zWmV2jsBWki5(0lYX&5-EHS z3a)AQ8A54{^Fq!K1%u`yrKvBTX|9vlrkK@F$-0add6$5ibIkXop2@VEgoS#C$0KCG zRqA7NJm!RIm++SQ7CPD$l|vh0N|eeo$Xfq;*QM$u8BFlbiNlD3P$9h0?bP{&m4QA~ zNweWE^L`_2P3@9x0jm?0OgbY6W429n;eOM8eKO_VHk&byOPy-+BoZFuH_?XEr$Y)j z<}2m2L!{$_a}<%?kQQ2M+}pO9qq#!TX5d(SyG*%f&NiFfdF8!N`Q(cH)lccvH+RpE z6LHD6lf;PIj1=gd+%sZ=$KvXjKr^;tYO!o~f@w=PGP{zP|=-yS@cYPxq)+IV;g*UytB 
zff%FtV-;`@Sj{&@Jy=HmW-4!Cd0=FRfrzhTb;6y}SEj#>4<|#QsLD-6UAD{pYw;5|OBMD*CE~qO(=p0viscf>=?*0B zHuRS!Fs2S88^c$wi?o88iN)LXY4>S}cIwCBWbb#h=Ma!pwUG!{GBcy|@~y{rcGL?} z`yVHg6HS?|hR|XqQ)rtTM{L!RfAoO}aJ(KrxO*FT%p&PR!(S+;f%=>eqzt!)EmZjz zp%y}%xMtQ<$a7$zChe*Ndq|z*hmnLz8}(29qRxa_jxIb8jX)}R_8A;r)+xY+{~5>6 zPSKalYB@FAIbyUMh$hDqzACM1U%d6ZzNO2ke}1Q!vey@j+uLitX_JnVP5t&6Nki2{T!HfzM1wO)dx*9CD4 zcPJ$abfXP34^-8I#r?~s*qXpSPrZe*=6q*DT>W6r!Z8T6s|72vS!Pr`cs3&_O91iu z_@j$F(yh&!q!C`4ko?E+jYRU!r6UOtRaIx(1c%daDUd>5Pc~vJDBOk7qs=!*z7X7Z zZO=PGESADDn@r3IhuOIx_w?ELi6bg#BaxbhhM7C(5V#J3u9q@%Yyebr{c;Db9`1mERv@Gh=i%4 zr*Ov;Ol$ot>*`Uo!->|^N0I0ol^ zwVyoq{bliBK|C0d-PhY9NlpF)*vb++PX>-rZRO~IdsYCrKMvq{b{d3KGU3GR8^;< z72a*rf0=)N36jKw&ym6*mA_=X{}d5Xs(*3}u^esr0pFLR@q)HpON--GnAvM?;FWJI z^i^gMGnc2v)7cr^X$&s01YNA?Y9w5*@keL0k2E>?L#~K9hD&Z~%a@?NCZ|w0CZ~6x96z&g$RAD?K(d3e=X)3hG#4sa<_J8 zlBH2$qDYKQ$GYzKH!_1z^KY5RNc)AfYcV%4I92m^cu2?2W|>%;sw-fx;jugG4UMpw zn_-u|Zz_duatxPP-5AiVNcl2ZSH+e=n)ipBwHfymmcob11&C?3^G{al6PQRojpUeO z8iy3{3H)V9jwtc+cA@DIR^b_|I|OU6`70|6>?;$S?-6{NCw@U3>5ncBN>8Zp5kZ;0 zG3E*eF^jv8e5b8&2JH&&6mc<62FJWp5{p_lLKr#S#LkFq98lmne$Z``)vZ$1F~#dJ zK50x?+@N1DHyq_4_5yh^9-iVd!ZMH6_R#rDS)Uu0TzLs)t_)64;E^?0M0CYX$l5#E zCUI&@wxQG_&RjIC?cb5eK^z24wKykb^$06yRFc{rxOmAclNog7Jzhz`iK~sQ3RHZe z5>F+Ef1I4c*E+Xg_&LAW2)K0TQ?1O|)ur|2q38I8fB6PLUxXYxVw=DiSx?9Km#?G{ zTd(mCZxfdHc?iF~i+Bhpe}ZkECHyTOX&3ERquPN$I##*<#dkv`;31M_;Uup#T=LcyX)bGsvEy4@U=Tfm=;*6YSzCzNm-u2iZT1ng*46}xp z;~N{b<#^=-mk83c(uK0JcsjOtfEq?2$4iD=Wz@AwJ>`rviL!s;GvN$!mP>rwDy$36 zupfqudqIY!;1Z>AsUoC(_6%MkW{=FW$yE<18GcPx{vo+t>PW!{&NFzDQt@7Y$u=tG zo1;ZWPETDWez4;z>>b%GUTvm3k|=?j`A5+QBr6C@%`{HWD{!yIK}Y0>MCIKvOb4EC3u!nLKHwM=>|FLZ0pPUGJIR(gD9k~<*(ktim@JsbqO3&`-hEg!n_4) zHPT`->N6CPKIsS_nnVIz?ITdx;ng6<@K>6cDF%=R{ zsvb_dA}5C;4JgzMm6WzBBt$VqVJk0t=*yr+kfHLZ9R3}P!sJE$w&H{u_n&SN{IbXl zstfKUOLHdc<#=IT*hZ_;sG^(PAYXWrqEsg$+rz(8Swk5-O`$c$@k#auB-{&P7*t2a z(6arLk^tpM!zL(4W-6_e__VuFm>v~QI?p}yvTCQN`^nLuh4`#UhIM4HPWEB3@x>F@ z1V+6I2DHeDM-A0Ug}Ao*?s_MilU$4RNUDNe$ z2^7bwX8st}!F$3Y(he%X^>ZT$sV*e}??NTB65jQjZug_K)l`s}5uk z%?)9oHI{rE>Odp2v(}l`9S22(7hn>l4GG|T6jlv*$6A`S?jWm-0H?Q6e@(Fi26RJI zFPNx=#pVRIAe@-(|3e7mL3`(Ha<>T2O)mp1@~sDNfMVReaqa&a@JhM<_nDH;!`fT} zjukppVzgFgW!Q&Jnx*X3|+YMPu#ej&XOpA5IHDTSg)O((K*-U4%|Fv{YNs0-~v`3P?) 
z$AGR3aRzNrQLEjNAE8wW6D;7N8=!R#GKFekk@#gY7%CHoNntm7R!2LK2+86|9K?OQ z?sdc%EEpy)dkeU0mBMY|oBV~*BcT2_LQc(<m=<-J-^*c1!3fYs29(;n?z-<;zvrN(-}sClhb#s+uo|$CoU(491P{ zaAcGAzKqbfNbN4_o`)`EIdc6%^H_7m4#*InA@T`R|NAue);3OL&XNtS<<=CR>7D)0&T144$v@+x^ zFv*l)CsW;p^MhdB9!Q9IUayuakT|Nl9nT8N-U03w1#qT@KO&z++9a(<6ruz=y3zty zuvpBQ8u|z}Z=PaK3TEA@(!Oio;Nx`LnsnHd61Z$hRydfYL$i*JoQH=I&P29O(OfKk z(F_{ggcl1kzJWX1E|d-t`MMXnQsoDx;TgqZC7m*8X$KCHew5*3B**Y(ZaGnkq({d(|oJ&QwtB_VU5ul9AQ_~8-itGWe-@#B^grFHN1^s*ZnJD z>>a&Cj2YhsCz(NkJT9|B!+{xwkj(cGraRDlTks;cVKzA)WSg5v94&}WMa~+QgDQ-= zKE@|p+gE^LbftW>lJIF?4ic>G(G``d-nYU6+BWG|~q@j-1?L<5pd$^pa jGQg8D4zCM)h)LT`ASP*s1~d-;_h*EY^ErX{*P8ze-+Q^L literal 0 HcmV?d00001 diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/.zgroup b/test/test_sam_annotator/commit-for-test.zarr/prompts/.zgroup new file mode 100644 index 000000000..3f3fad2d1 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/10/.zgroup b/test/test_sam_annotator/commit-for-test.zarr/prompts/10/.zgroup new file mode 100644 index 000000000..3f3fad2d1 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/10/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/10/point_labels/.zarray b/test/test_sam_annotator/commit-for-test.zarr/prompts/10/point_labels/.zarray new file mode 100644 index 000000000..57d571518 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/10/point_labels/.zarray @@ -0,0 +1,21 @@ +{ + "chunks": [ + 3 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dimension_separator": ".", + "dtype": ";tL%c$D^KzsxFj003`Z7uEm( literal 0 HcmV?d00001 diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/4/.zgroup b/test/test_sam_annotator/commit-for-test.zarr/prompts/4/.zgroup new file mode 100644 index 000000000..3f3fad2d1 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/4/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/4/point_labels/.zarray b/test/test_sam_annotator/commit-for-test.zarr/prompts/4/point_labels/.zarray new file mode 100644 index 000000000..28283849b --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/4/point_labels/.zarray @@ -0,0 +1,21 @@ +{ + "chunks": [ + 1 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dimension_separator": ".", + "dtype": "=OTx;fNRM{CyE>ZSGfqd literal 0 HcmV?d00001 diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/6/.zgroup b/test/test_sam_annotator/commit-for-test.zarr/prompts/6/.zgroup new file mode 100644 index 000000000..3f3fad2d1 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/6/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/6/prompts/.zarray b/test/test_sam_annotator/commit-for-test.zarr/prompts/6/prompts/.zarray new file mode 100644 index 000000000..98aab8e36 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/6/prompts/.zarray @@ -0,0 +1,25 @@ +{ + "chunks": [ + 1, + 4, + 2 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dimension_separator": ".", + "dtype": 
"TVeRAr6CS5{tpWqg9@nvhs4{P%AUtI)xNY0HF8zy^G=;03XB= Ag8%>k literal 0 HcmV?d00001 diff --git a/test/test_sam_annotator/commit-for-test.zarr/prompts/9/prompts/.zarray b/test/test_sam_annotator/commit-for-test.zarr/prompts/9/prompts/.zarray new file mode 100644 index 000000000..98aab8e36 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/prompts/9/prompts/.zarray @@ -0,0 +1,25 @@ +{ + "chunks": [ + 1, + 4, + 2 + ], + "compressor": { + "blocksize": 0, + "clevel": 5, + "cname": "lz4", + "id": "blosc", + "shuffle": 1 + }, + "dimension_separator": ".", + "dtype": "-qE2{y+y9|AeCKiZ!nb9G1 Date: Sun, 4 May 2025 19:52:50 +0200 Subject: [PATCH 7/7] Add support for passing already compute image embeddings WIP --- micro_sam/automatic_segmentation.py | 33 +++--- micro_sam/multi_dimensional_segmentation.py | 30 +++--- micro_sam/sam_annotator/reproducibility.py | 105 ++++++++++++-------- 3 files changed, 102 insertions(+), 66 deletions(-) diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index a0fe27914..e5c340d83 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -77,7 +77,7 @@ def automatic_tracking( segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], input_path: Union[Union[os.PathLike, str], np.ndarray], output_path: Optional[Union[os.PathLike, str]] = None, - embedding_path: Optional[Union[os.PathLike, str]] = None, + embedding_path: Optional[Union[os.PathLike, str, util.ImageEmbeddings]] = None, key: Optional[str] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, @@ -96,6 +96,7 @@ def automatic_tracking( or a container file (e.g. hdf5 or zarr). output_path: The output path where the instance segmentations will be saved. embedding_path: The path where the embeddings are cached already / will be saved. + This argument also accepts already deserialized embeddings. key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. @@ -158,7 +159,7 @@ def automatic_instance_segmentation( segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], input_path: Union[Union[os.PathLike, str], np.ndarray], output_path: Optional[Union[os.PathLike, str]] = None, - embedding_path: Optional[Union[os.PathLike, str]] = None, + embedding_path: Optional[Union[os.PathLike, str, util.ImageEmbeddings]] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, @@ -178,6 +179,7 @@ def automatic_instance_segmentation( or a container file (e.g. hdf5 or zarr). output_path: The output path where the instance segmentations will be saved. embedding_path: The path where the embeddings are cached already / will be saved. + This argument also accepts already deserialized embeddings. key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. ndim: The dimensionality of the data. By default the dimensionality of the data will be used. @@ -219,16 +221,19 @@ def automatic_instance_segmentation( raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}") # Precompute the image embeddings. 
- image_embeddings = util.precompute_image_embeddings( - predictor=predictor, - input_=image_data, - save_path=embedding_path, - ndim=ndim, - tile_shape=tile_shape, - halo=halo, - verbose=verbose, - batch_size=batch_size, - ) + if embedding_path is None or isinstance(embedding_path, (str, os.PathLike)): + image_embeddings = util.precompute_image_embeddings( + predictor=predictor, + input_=image_data, + save_path=embedding_path, + ndim=ndim, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, + batch_size=batch_size, + ) + else: + image_embeddings = embedding_path initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose) # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing. @@ -242,14 +247,14 @@ def automatic_instance_segmentation( masks = segmenter.generate(**generate_kwargs) if isinstance(masks, list): - # whether the predictions from 'generate' are list of dict, + # Whether the predictions from 'generate' are list of dict, # which contains additional info req. for post-processing, eg. area per object. if len(masks) == 0: instances = np.zeros(image_data.shape[:2], dtype="uint32") else: instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0) else: - # if (raw) predictions provided, store them as it is w/o further post-processing. + # If (raw) predictions provided, store them as it is w/o further post-processing. instances = masks else: diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index 7460bc294..2f9e7036e 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -375,16 +375,20 @@ def _segment_slices( assert data.ndim == 3 min_object_size = kwargs.pop("min_object_size", 0) - image_embeddings = util.precompute_image_embeddings( - predictor=predictor, - input_=data, - save_path=embedding_path, - ndim=3, - tile_shape=tile_shape, - halo=halo, - verbose=verbose, - batch_size=batch_size, - ) + # Check if the embeddings still have to be computed. + if embedding_path is None or isinstance(embedding_path, (str, os.PathLike)): + image_embeddings = util.precompute_image_embeddings( + predictor=predictor, + input_=data, + save_path=embedding_path, + ndim=3, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, + batch_size=batch_size, + ) + else: # Otherwise the deserialized embeddings were passed. + image_embeddings = embedding_path offset = 0 segmentation = np.zeros(data.shape, dtype="uint32") @@ -417,7 +421,7 @@ def automatic_3d_segmentation( volume: np.ndarray, predictor: SamPredictor, segmentor: AMGBase, - embedding_path: Optional[Union[str, os.PathLike]] = None, + embedding_path: Optional[Union[str, os.PathLike, util.ImageEmbeddings]] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, @@ -438,6 +442,7 @@ def automatic_3d_segmentation( predictor: The SAM model. segmentor: The instance segmentation class. embedding_path: The path to save pre-computed embeddings. + This argument also accepts already deserialized embeddings. with_background: Whether the segmentation has background. gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing. 
@@ -626,7 +631,7 @@ def automatic_tracking_implementation( timeseries: np.ndarray, predictor: SamPredictor, segmentor: AMGBase, - embedding_path: Optional[Union[str, os.PathLike]] = None, + embedding_path: Optional[Union[str, os.PathLike, util.ImageEmbeddings]] = None, gap_closing: Optional[int] = None, min_time_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, @@ -646,6 +651,7 @@ def automatic_tracking_implementation( predictor: The SAM model. segmentor: The instance segmentation class. embedding_path: The path to save pre-computed embeddings. + This argument also accepts already deserialized embeddings. gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing. min_time_extent: Require a minimal extent in time for the tracked objects. diff --git a/micro_sam/sam_annotator/reproducibility.py b/micro_sam/sam_annotator/reproducibility.py index 21615bc1e..73e537c9f 100644 --- a/micro_sam/sam_annotator/reproducibility.py +++ b/micro_sam/sam_annotator/reproducibility.py @@ -7,14 +7,14 @@ from tqdm import tqdm from .. import util -from ..automatic_segmentation import get_predictor_and_segmenter -from ..instance_segmentation import mask_data_to_segmentation +from ..automatic_segmentation import automatic_tracking, automatic_instance_segmentation, get_predictor_and_segmenter +from ..multi_dimensional_segmentation import segment_mask_in_volume from ._widgets import _mask_matched_objects from .annotator_2d import annotator_2d from .annotator_3d import annotator_3d # from .annotator_tracking import annotator_tracking -from .util import prompt_segmentation +from .util import prompt_segmentation, segment_slices_with_prompts def _load_model_from_commit_file(f): @@ -79,26 +79,28 @@ def _rerun_interactive_segmentation(segmentation, f, predictor, image_embeddings preserve_mode, preservation_threshold = options.pop("preserve_mode"), options.pop("preservation_threshold") g = f["prompts"] + # Load the serialized prompts for all objects in this commit. + boxes, masks = [], [] + points, labels = [], [] + for object_id in object_ids: + prompt_group = g[str(object_id)] + if "point_prompts" in prompt_group: + points.append(prompt_group["point_prompts"][:]) + labels.append(prompt_group["point_labels"][:]) + if "prompts" in prompt_group: + boxes.append(prompt_group["prompts"][:]) + # We can only have a mask if we also have a box prompt. + if "mask" in prompt_group: + masks.append(prompt_group["mask"][:]) + else: + masks.append(None) + + if points: + points = np.concatenate(points, axis=0) + labels = np.concatenate(labels, axis=0) + if annotator_class == "Annotator2d": - # Load the serialized prompts for all objects in this commit. - boxes, masks = [], [] - points, labels = [], [] - for object_id in object_ids: - prompt_group = g[str(object_id)] - if "point_prompts" in prompt_group: - points.append(prompt_group["point_prompts"][:]) - labels.append(prompt_group["point_labels"][:]) - if "prompts" in prompt_group: - boxes.append(prompt_group["prompts"][:]) - # We can only have a mask if we also have a box prompt. - if "mask" in prompt_group: - masks.append(prompt_group["mask"][:]) - else: - masks.append(None) - - if points: - points = np.concatenate(points, axis=0) - labels = np.concatenate(labels, axis=0) + if boxes: # Map boxes to the correct input format. 
boxes = np.concatenate(boxes, axis=0) @@ -112,11 +114,24 @@ def _rerun_interactive_segmentation(segmentation, f, predictor, image_embeddings multiple_box_prompts=True, batched=batched, previous_segmentation=segmentation, ).astype("uint32") - # TODO implement batched segmentation for these cases. - elif annotator_class == "AnnotatorTracking": - pass elif annotator_class == "Annotator3d": pass + # TODO this needs to be updated + # seg, slices, stop_lower, stop_upper = segment_slices_with_prompts( + # state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"], + # image_embeddings, shape, + # ) + # seg, (z_min, z_max) = segment_mask_in_volume( + # seg, predictor, image_embeddings, slices, + # stop_lower, stop_upper, + # iou_threshold=self.iou_threshold, projection=self.projection, + # box_extension=self.box_extension, + # update_progress=lambda update: pbar_signals.pbar_update.emit(update), + # ) + + elif annotator_class == "AnnotatorTracking": + raise NotImplementedError("Not yet implemented for AnnotatorTracking.") + else: raise RuntimeError(f"Invalid annotator class {annotator_class}.") @@ -124,25 +139,31 @@ def _rerun_interactive_segmentation(segmentation, f, predictor, image_embeddings def _rerun_automatic_segmentation( - image, segmentation, predictor, segmenter, image_embeddings, annotator_class, options + image, segmentation, predictor, segmenter, image_embeddings, annotator_class, tile_shape, halo, options ): object_ids = options.pop("object_ids") preserve_mode, preservation_threshold = options.pop("preserve_mode"), options.pop("preservation_threshold") - with_background, min_object_size = options.pop("with_background"), options.pop("min_object_size") # If there was nothing committed then we don't need to rerun the automatic segmentation. if len(object_ids) == 0: return segmentation - if annotator_class == "Annotator2d": - segmenter.initialize(image=image, image_embeddings=image_embeddings) - seg = segmenter.generate(**options) - seg = mask_data_to_segmentation(seg, with_background=with_background, min_object_size=min_object_size) - # TODO implement auto segmentation for these cases. + if annotator_class in ("Annotator2d", "Annotator3d"): + ndim = 2 if annotator_class == "Annotator2d" else 3 + seg = automatic_instance_segmentation( + predictor, segmenter, image, + embedding_path=image_embeddings, + tile_shape=tile_shape, halo=halo, ndim=ndim, + **options + ) elif annotator_class == "AnnotatorTracking": - pass - elif annotator_class == "Annotator3d": - pass + seg, lineages = automatic_tracking( + predictor, segmenter, image, + embedding_path=image_embeddings, + tile_shape=tile_shape, halo=halo, ndim=ndim, + **options + ) + else: raise RuntimeError(f"Invalid annotator class {annotator_class}.") @@ -178,6 +199,9 @@ def rerun_segmentation_from_commit_file( annotator_class = f.attrs["annotator_class"] ndim = 2 if annotator_class == "Annotator2d" else 3 + # Get the tile shape and halo from the attributes. + tile_shape, halo = f.attrs["tile_shape"], f.attrs["halo"] + # Check that the stored data hash and the input data hash match. _check_data_hash(f, input_data, input_path) @@ -190,8 +214,8 @@ def rerun_segmentation_from_commit_file( input_=input_data, save_path=embedding_path, ndim=ndim, - tile_shape=f.attrs["tile_shape"], - halo=f.attrs["halo"], + tile_shape=tile_shape, + halo=halo, ) # Go through the commit history and redo the action of each commit. 
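
[Usage note, not part of the patch] The hunks above rely on the new behaviour from this patch: `embedding_path` may now be either a save path or already deserialized image embeddings (`util.ImageEmbeddings`), so embeddings computed once can be reused across calls. A minimal sketch of that pattern at user level, assuming this patch is applied; the model name "vit_b_lm" and the image path are placeholders chosen for illustration:

import imageio.v3 as imageio

from micro_sam import util
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter

# Load an example image (placeholder path).
image = imageio.imread("image.tif")

# Get the SAM predictor and the AIS segmenter.
predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", amg=False)

# Compute the embeddings once, here without caching them to disk ...
image_embeddings = util.precompute_image_embeddings(predictor=predictor, input_=image, ndim=2)

# ... and pass the deserialized embeddings instead of a save path.
segmentation = automatic_instance_segmentation(
    predictor, segmenter, image, embedding_path=image_embeddings, ndim=2
)
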
@@ -208,11 +232,12 @@ def rerun_segmentation_from_commit_file( layer, options = next(iter(commit.items())) if layer == "current_object": segmentation = _rerun_interactive_segmentation( - segmentation, f, predictor, image_embeddings, annotator_class, options + segmentation, f, predictor, image_embeddings, annotator_class, tile_shape, halo, options ) elif layer == "auto_segmentation": segmentation = _rerun_automatic_segmentation( - input_data, segmentation, predictor, segmenter, image_embeddings, annotator_class, options + input_data, segmentation, predictor, segmenter, image_embeddings, + annotator_class, tile_shape, halo, options ) else: raise RuntimeError(f"Invalid layer {layer} in commit_historty.") @@ -274,7 +299,7 @@ def continue_annotation( elif annotator_class == "AnnotatorTracking": raise NotImplementedError("'continue_annotation_from_commit_file' is not yet supported for AnnotatorTracking.") else: - raise RuntimeError(f"Unsupported annotator class {annotator_class}.") + raise RuntimeError(f"Invalid annotator class {annotator_class}.") def main():
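
[Usage note, not part of the patch] For reference, a short usage sketch of the reproducibility entry points introduced in this series. The function names and CLI flags follow the patches above; the file paths are placeholders, and since the series is marked WIP the interface may still change:

from micro_sam.sam_annotator.reproducibility import (
    load_committed_objects_from_commit_file,
    rerun_segmentation_from_commit_file,
)

# Recreate the segmentation by replaying the commit history (paths are placeholders).
segmentation = rerun_segmentation_from_commit_file(
    commit_file="commit.zarr", input_path="image.tif", embedding_path="embeddings.zarr",
)

# The result should match the committed objects stored in the commit file.
committed_objects = load_committed_objects_from_commit_file("commit.zarr")
assert segmentation.shape == committed_objects.shape

The same state can be reopened in the annotator via the console script registered in setup.cfg:

micro_sam.continue_annotation -c commit.zarr -i image.tif -e embeddings.zarr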