diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py
index 6d88824ec..e5c340d83 100644
--- a/micro_sam/automatic_segmentation.py
+++ b/micro_sam/automatic_segmentation.py
@@ -1,8 +1,9 @@
 import os
+from functools import partial
 from glob import glob
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, Union, Tuple
+from typing import Dict, List, Optional, Union, Tuple
 
 import numpy as np
 import imageio.v3 as imageio
@@ -14,7 +15,7 @@
     get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder,
     AMGBase, AutomaticMaskGenerator, TiledAutomaticMaskGenerator
 )
-from .multi_dimensional_segmentation import automatic_3d_segmentation
+from .multi_dimensional_segmentation import automatic_3d_segmentation, automatic_tracking_implementation
 
 
 def get_predictor_and_segmenter(
@@ -71,12 +72,94 @@ def _add_suffix_to_output_path(output_path: Union[str, os.PathLike], suffix: str
     return str(fpath.with_name(f"{fpath.stem}{suffix}{fext}"))
 
 
+def automatic_tracking(
+    predictor: util.SamPredictor,
+    segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
+    input_path: Union[Union[os.PathLike, str], np.ndarray],
+    output_path: Optional[Union[os.PathLike, str]] = None,
+    embedding_path: Optional[Union[os.PathLike, str, util.ImageEmbeddings]] = None,
+    key: Optional[str] = None,
+    tile_shape: Optional[Tuple[int, int]] = None,
+    halo: Optional[Tuple[int, int]] = None,
+    verbose: bool = True,
+    return_embeddings: bool = False,
+    annotate: bool = False,
+    batch_size: int = 1,
+    **generate_kwargs
+) -> Tuple[np.ndarray, List[Dict]]:
+    """Run automatic tracking for the input timeseries.
+
+    Args:
+        predictor: The Segment Anything model.
+        segmenter: The automatic instance segmentation class.
+        input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
+            or a container file (e.g. hdf5 or zarr).
+        output_path: The output path where the instance segmentations will be saved.
+        embedding_path: The path where the embeddings are cached already / will be saved.
+            This argument also accepts already deserialized embeddings.
+        key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
+            or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
+        tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
+        halo: Overlap of the tiles for tiled prediction.
+        verbose: Verbosity flag.
+        return_embeddings: Whether to return the precomputed image embeddings.
+        annotate: Whether to activate the annotator for continuing the annotation process.
+        batch_size: The batch size to compute image embeddings over tiles / timeframes.
+            By default, the embeddings are computed sequentially, i.e. one after the other.
+        generate_kwargs: Optional keyword arguments for the generate function of the AMG or AIS class.
+
+    Returns:
+        The tracking result as a segmentation.
+        The lineages of the tracked objects.
+    """
+    if output_path is not None:
+        # TODO implement saving tracking results in CTC format and use it to save the result here.
+        raise NotImplementedError("Saving the tracking result to file is currently not supported.")
+
+    # Load the input image file.
+    if isinstance(input_path, np.ndarray):
+        image_data = input_path
+    else:
+        image_data = util.load_image_data(input_path, key)
+
+    # Additional post-processing is only performed for AMG.
+    # For AIS we disable it by setting the output mode to None.
+    if isinstance(segmenter, InstanceSegmentationWithDecoder):
+        generate_kwargs["output_mode"] = None
+
+    if (image_data.ndim != 3) and (image_data.ndim != 4 or image_data.shape[-1] != 3):
+        raise ValueError(f"The input does not match the shape expectation of 3d inputs: {image_data.shape}")
+
+    # Pop the tracking-specific arguments, so that they are not also passed to the generate function.
+    gap_closing = generate_kwargs.pop("gap_closing", None)
+    min_time_extent = generate_kwargs.pop("min_time_extent", None)
+    segmentation, lineage, image_embeddings = automatic_tracking_implementation(
+        image_data,
+        predictor,
+        segmenter,
+        embedding_path=embedding_path,
+        gap_closing=gap_closing,
+        min_time_extent=min_time_extent,
+        tile_shape=tile_shape,
+        halo=halo,
+        verbose=verbose,
+        batch_size=batch_size,
+        return_embeddings=True,
+        **generate_kwargs,
+    )
+
+    if annotate:
+        # TODO We need to support initialization of the tracking annotator with the tracking result for this.
+        raise NotImplementedError("Annotation after running the automated tracking is currently not supported.")
+
+    if return_embeddings:
+        return segmentation, lineage, image_embeddings
+    else:
+        return segmentation, lineage
+
+
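A minimal usage sketch for the new `automatic_tracking` convenience function (illustrative only: the model name, input file and embedding path are placeholders, and the cross-frame linking requires the optional `trackastra` dependency):

    from micro_sam.automatic_segmentation import automatic_tracking, get_predictor_and_segmenter

    # Get the model and the matching instance segmentation class (AIS, since amg=False).
    predictor, segmenter = get_predictor_and_segmenter(model_type="vit_b_lm", amg=False)

    # Segment every frame automatically and link the objects across time.
    # `embedding_path` caches the embeddings; already deserialized embeddings can be passed instead.
    segmentation, lineage = automatic_tracking(
        predictor=predictor,
        segmenter=segmenter,
        input_path="timeseries.tif",   # 2d timeseries with shape (T, Y, X)
        embedding_path="embeddings.zarr",
    )
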
 def automatic_instance_segmentation(
     predictor: util.SamPredictor,
     segmenter: Union[AMGBase, InstanceSegmentationWithDecoder],
     input_path: Union[Union[os.PathLike, str], np.ndarray],
     output_path: Optional[Union[os.PathLike, str]] = None,
-    embedding_path: Optional[Union[os.PathLike, str]] = None,
+    embedding_path: Optional[Union[os.PathLike, str, util.ImageEmbeddings]] = None,
     key: Optional[str] = None,
     ndim: Optional[int] = None,
     tile_shape: Optional[Tuple[int, int]] = None,
@@ -96,6 +179,7 @@ def automatic_instance_segmentation(
             or a container file (e.g. hdf5 or zarr).
         output_path: The output path where the instance segmentations will be saved.
         embedding_path: The path where the embeddings are cached already / will be saved.
+            This argument also accepts already deserialized embeddings.
         key: The key to the input file. This is needed for container files (eg. hdf5 or zarr)
             or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case.
         ndim: The dimensionality of the data. By default the dimensionality of the data will be used.
@@ -137,16 +221,19 @@ def automatic_instance_segmentation(
         raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}")
 
     # Precompute the image embeddings.
-    image_embeddings = util.precompute_image_embeddings(
-        predictor=predictor,
-        input_=image_data,
-        save_path=embedding_path,
-        ndim=ndim,
-        tile_shape=tile_shape,
-        halo=halo,
-        verbose=verbose,
-        batch_size=batch_size,
-    )
+    if embedding_path is None or isinstance(embedding_path, (str, os.PathLike)):
+        image_embeddings = util.precompute_image_embeddings(
+            predictor=predictor,
+            input_=image_data,
+            save_path=embedding_path,
+            ndim=ndim,
+            tile_shape=tile_shape,
+            halo=halo,
+            verbose=verbose,
+            batch_size=batch_size,
+        )
+    else:
+        image_embeddings = embedding_path
     initialize_kwargs = dict(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
 
     # If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
@@ -160,14 +247,14 @@ def automatic_instance_segmentation(
         masks = segmenter.generate(**generate_kwargs)
 
         if isinstance(masks, list):
-            # whether the predictions from 'generate' are list of dict,
+            # Whether the predictions from 'generate' are list of dict,
             # which contains additional info req. for post-processing, eg. area per object.
             if len(masks) == 0:
                 instances = np.zeros(image_data.shape[:2], dtype="uint32")
             else:
                 instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0)
         else:
-            # if (raw) predictions provided, store them as it is w/o further post-processing.
+            # If (raw) predictions are provided, store them as they are w/o further post-processing.
             instances = masks
 
     else:
@@ -258,7 +345,13 @@ def main():
     available_models = list(util.get_model_names())
     available_models = ", ".join(available_models)
 
-    parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.")
+    parser = argparse.ArgumentParser(
+        description="Run automatic segmentation for an image using either automatic instance segmentation (AIS)\n"
+        "or automatic mask generation (AMG). In addition to the arguments explained below,\n"
+        "you can also pass additional arguments for these two segmentation modes:\n"
+        "For AIS: '--center_distance_threshold', '--boundary_distance_threshold' and other arguments of `InstanceSegmentationWithDecoder.generate`.\n"  # noqa
+        "For AMG: '--pred_iou_thresh', '--stability_score_thresh' and other arguments of `AutomaticMaskGenerator.generate`."  # noqa
+    )
     parser.add_argument(
         "-i", "--input_path", required=True, type=str, nargs="+",
         help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
@@ -314,6 +407,10 @@ def main():
         help="The batch size for computing image embeddings over tiles or z-plane. "
         "By default, computes the image embeddings for one tile / z-plane at a time."
     )
+    parser.add_argument(
+        "--tracking", action="store_true", help="Run tracking instead of instance segmentation. "
+        "Only supported for timeseries inputs."
+    )
     parser.add_argument(
         "-v", "--verbose", action="store_true", help="Whether to allow verbosity of outputs."
    )
@@ -368,6 +465,10 @@ def _convert_argval(value):
     embedding_path = args.embedding_path
     has_one_input = len(input_paths) == 1
 
+    instance_seg_function = automatic_tracking if args.tracking else partial(
+        automatic_instance_segmentation, ndim=args.ndim
+    )
+
     # Run automatic segmentation per image.
     for path in tqdm(input_paths, desc="Run automatic segmentation"):
         if has_one_input:  # if we have one image only.
@@ -393,14 +494,13 @@ def _convert_argval(value):
             os.makedirs(output_path, exist_ok=True)
             _output_fpath = os.path.join(output_path, Path(os.path.basename(path)).with_suffix(".tif"))
 
-        automatic_instance_segmentation(
+        instance_seg_function(
             predictor=predictor,
             segmenter=segmenter,
             input_path=path,
             output_path=_output_fpath,
             embedding_path=_embedding_fpath,
             key=args.key,
-            ndim=args.ndim,
             tile_shape=args.tile_shape,
             halo=args.halo,
             annotate=args.annotate,
diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py
index eb98ec47c..2f9e7036e 100644
--- a/micro_sam/multi_dimensional_segmentation.py
+++ b/micro_sam/multi_dimensional_segmentation.py
@@ -375,16 +375,20 @@ def _segment_slices(
     assert data.ndim == 3
     min_object_size = kwargs.pop("min_object_size", 0)
 
-    image_embeddings = util.precompute_image_embeddings(
-        predictor=predictor,
-        input_=data,
-        save_path=embedding_path,
-        ndim=3,
-        tile_shape=tile_shape,
-        halo=halo,
-        verbose=verbose,
-        batch_size=batch_size,
-    )
+    # Check if the embeddings still have to be computed.
+ if embedding_path is None or isinstance(embedding_path, (str, os.PathLike)): + image_embeddings = util.precompute_image_embeddings( + predictor=predictor, + input_=data, + save_path=embedding_path, + ndim=3, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, + batch_size=batch_size, + ) + else: # Otherwise the deserialized embeddings were passed. + image_embeddings = embedding_path offset = 0 segmentation = np.zeros(data.shape, dtype="uint32") @@ -417,7 +421,7 @@ def automatic_3d_segmentation( volume: np.ndarray, predictor: SamPredictor, segmentor: AMGBase, - embedding_path: Optional[Union[str, os.PathLike]] = None, + embedding_path: Optional[Union[str, os.PathLike, util.ImageEmbeddings]] = None, with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, @@ -438,6 +442,7 @@ def automatic_3d_segmentation( predictor: The SAM model. segmentor: The instance segmentation class. embedding_path: The path to save pre-computed embeddings. + This argument also accepts already deserialized embeddings. with_background: Whether the segmentation has background. gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing. @@ -622,16 +627,18 @@ def track_across_frames( return segmentation, lineage -def automatic_tracking( +def automatic_tracking_implementation( timeseries: np.ndarray, predictor: SamPredictor, segmentor: AMGBase, - embedding_path: Optional[Union[str, os.PathLike]] = None, + embedding_path: Optional[Union[str, os.PathLike, util.ImageEmbeddings]] = None, gap_closing: Optional[int] = None, min_time_extent: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, verbose: bool = True, + return_embeddings: bool = False, + batch_size: int = 1, **kwargs, ) -> Tuple[np.ndarray, List[Dict]]: """Automatically track objects in a timesries based on per-frame automatic segmentation. @@ -644,12 +651,15 @@ def automatic_tracking( predictor: The SAM model. segmentor: The instance segmentation class. embedding_path: The path to save pre-computed embeddings. + This argument also accepts already deserialized embeddings. gap_closing: If given, gaps in the segmentation are closed with a binary closing operation. The value is used to determine the number of iterations for the closing. min_time_extent: Require a minimal extent in time for the tracked objects. tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. verbose: Verbosity flag. + return_embeddings: Whether to return the precomputed image embeddings. + batch_size: The batch size to compute image embeddings over planes. kwargs: Keyword arguments for the 'generate' method of the 'segmentor'. Returns: @@ -662,15 +672,18 @@ def automatic_tracking( raise RuntimeError( "Automatic tracking requires trackastra. You can install it via 'pip install trackastra'." 
         )
-    segmentation, _ = _segment_slices(
-        timeseries, predictor, segmentor, embedding_path, verbose,
-        tile_shape=tile_shape, halo=halo,
+    segmentation, image_embeddings = _segment_slices(
+        timeseries, predictor, segmentor, embedding_path, verbose,
+        tile_shape=tile_shape, halo=halo, batch_size=batch_size,
         **kwargs,
     )
     segmentation, lineage = track_across_frames(
         timeseries, segmentation, gap_closing=gap_closing, min_time_extent=min_time_extent, verbose=verbose,
     )
-    return segmentation, lineage
+    if return_embeddings:
+        return segmentation, lineage, image_embeddings
+    else:
+        return segmentation, lineage
 
 
 def get_napari_track_data(
diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py
index b4c042bf1..6300e7dd7 100644
--- a/micro_sam/sam_annotator/_annotator.py
+++ b/micro_sam/sam_annotator/_annotator.py
@@ -1,6 +1,8 @@
-import napari
+from typing import Optional, List
 
 import numpy as np
+import napari
 
 from qtpy import QtWidgets
 from magicgui.widgets import Widget, Container, FunctionGui
@@ -16,19 +18,39 @@ class _AnnotatorBase(QtWidgets.QScrollArea):
     The annotators differ in their data dimensionality and the widgets.
     """
 
-    def _create_layers(self):
+    def _require_layers(self, layer_choices: Optional[List[str]] = None):
+
+        # Check whether the image is initialized already, and use the image shape and scale for the layers.
+        state = AnnotatorState()
+        shape = self._shape if state.image_shape is None else state.image_shape
+
         # Add the label layers for the current object, the automatic segmentation and the committed segmentation.
-        dummy_data = np.zeros(self._shape, dtype="uint32")
+        dummy_data = np.zeros(shape, dtype="uint32")
+        image_scale = state.image_scale
 
         # Before adding new layers, we always check whether a layer with this name already exists or not.
         if "current_object" not in self._viewer.layers:
+            if layer_choices and "current_object" in layer_choices:  # Check at 'commit' call button.
+                widgets._validation_window_for_missing_layer("current_object")
             self._viewer.add_labels(data=dummy_data, name="current_object")
+            if image_scale is not None:
+                self._viewer.layers["current_object"].scale = image_scale
+
         if "auto_segmentation" not in self._viewer.layers:
+            if layer_choices and "auto_segmentation" in layer_choices:  # Check at 'commit' call button.
+                widgets._validation_window_for_missing_layer("auto_segmentation")
             self._viewer.add_labels(data=dummy_data, name="auto_segmentation")
+            if image_scale is not None:
+                self._viewer.layers["auto_segmentation"].scale = image_scale
+
         if "committed_objects" not in self._viewer.layers:
+            if layer_choices and "committed_objects" in layer_choices:  # Check at 'commit' call button.
+                widgets._validation_window_for_missing_layer("committed_objects")
             self._viewer.add_labels(data=dummy_data, name="committed_objects")
             # Randomize colors so it is easy to see when object committed.
             self._viewer.layers["committed_objects"].new_colormap()
+            if image_scale is not None:
+                self._viewer.layers["committed_objects"].scale = image_scale
 
         # Add the point layer for point prompts.
         self._point_labels = ["positive", "negative"]
@@ -70,7 +92,7 @@ def _create_widgets(self):
         # Create the prompt widget. (The same for all plugins.)
         self._prompt_widget = widgets.create_prompt_menu(self._point_prompt_layer, self._point_labels)
 
-        # Create the dictionray for the widgets and get the widgets of the child plugin.
+        # Create the dictionary for the widgets and get the widgets of the child plugin.
self._widgets = {"embeddings": self._embedding_widget, "prompts": self._prompt_widget} self._widgets.update(self._get_widgets()) @@ -131,7 +153,7 @@ def __init__(self, viewer: "napari.viewer.Viewer", ndim: int) -> None: # Initialize with a dummy shape, which is reset to the correct shape once an image is set. self._ndim = ndim self._shape = (256, 256) if ndim == 2 else (16, 256, 256) - self._create_layers() + self._require_layers() # Create all the widgets and add them to the layout. self._create_widgets() @@ -179,6 +201,9 @@ def _update_image(self, segmentation_result=None): ) self._shape = state.image_shape + # Before we reset the layers, we ensure all expected layers exist. + self._require_layers() + # Update the image scale. scale = state.image_scale @@ -187,12 +212,15 @@ def _update_image(self, segmentation_result=None): self._viewer.layers["current_object"].scale = scale self._viewer.layers["auto_segmentation"].data = np.zeros(self._shape, dtype="uint32") self._viewer.layers["auto_segmentation"].scale = scale + if segmentation_result is None or segmentation_result is False: self._viewer.layers["committed_objects"].data = np.zeros(self._shape, dtype="uint32") else: assert segmentation_result.shape == self._shape self._viewer.layers["committed_objects"].data = segmentation_result self._viewer.layers["committed_objects"].scale = scale + self._viewer.layers["point_prompts"].scale = scale self._viewer.layers["prompts"].scale = scale + vutil.clear_annotations(self._viewer, clear_segmentations=False) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 09f5e9870..13ef8abca 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -14,6 +14,7 @@ import torch.nn as nn +import micro_sam import micro_sam.util as util from micro_sam.instance_segmentation import AMGBase, get_decoder from micro_sam.precompute_state import cache_amg_state, cache_is_state @@ -69,6 +70,9 @@ class AnnotatorState(metaclass=Singleton): # z-range to limit the data being committed in 3d / tracking. z_range: Optional[Tuple[int, int]] = None + # annotator_class + annotator: Optional["micro_sam.sam_annotator._annotator._AnnotatorBase"] = None + def initialize_predictor( self, image_data, diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index a840d49bb..7736e3e1a 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -30,9 +30,9 @@ # from napari.qt.threading import thread_worker from napari.utils import progress -from ._state import AnnotatorState from . import util as vutil from ._tooltips import get_tooltip +from ._state import AnnotatorState from .. import instance_segmentation, util from ..multi_dimensional_segmentation import ( segment_mask_in_volume, merge_instance_segmentation_3d, track_across_frames, PROJECTION_MODES, get_napari_track_data @@ -496,8 +496,12 @@ def _mask_matched_objects(seg, prev_seg, preservation_threshold): def _commit_impl(viewer, layer, preserve_mode, preservation_threshold): - # Check if we have a z_range. If yes, use it to set a bounding box. state = AnnotatorState() + + # Check whether all layers exist as expected or create new ones automatically. + state.annotator._require_layers(layer_choices=[layer, "committed_objects"]) + + # Check if we have a z_range. If yes, use it to set a bounding box. 
if state.z_range is None: bb = np.s_[:] else: @@ -550,8 +554,8 @@ def _get_auto_segmentation_options(state, object_ids): segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]} if widget.with_decoder: - segmentation_options["boundary_distance_thresh"] = widget.boundary_distance_thresh - segmentation_options["center_distance_thresh"] = widget.center_distance_thresh + segmentation_options["boundary_distance_threshold"] = widget.boundary_distance_thresh + segmentation_options["center_distance_threshold"] = widget.center_distance_thresh else: segmentation_options["pred_iou_thresh"] = widget.pred_iou_thresh segmentation_options["stability_score_thresh"] = widget.stability_score_thresh @@ -582,18 +586,28 @@ def _get_promptable_segmentation_options(state, object_ids): return segmentation_options, is_tracking -def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None): +def _get_preservation_settings(state): + widget = state.widgets["commit"] + return { + "preserve_mode": widget.preserve_mode.value, + "preservation_threshold": widget.preservation_threshold.value, + } - # NOTE: zarr-python is quite inefficient and writes empty blocks. - # So we have to use z5py here. - # Deal with issues z5py has with empty folders and require the json. - if os.path.exists(path): - required_json = os.path.join(path, ".zgroup") +# Deal with issues z5py has with empty folders and require the json with zarr group metadata. +def _require_zarr_group_metadata(path, group_name=None): + path_ = path if group_name is None else os.path.join(path, group_name) + if os.path.exists(path_): + required_json = os.path.join(path_, ".zgroup") if not os.path.exists(required_json): with open(required_json, "w") as f: json.dump({"zarr_format": 2}, f) + +def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None): + + # NOTE: zarr-python is quite inefficient and writes empty blocks. So we have to use z5py here. + _require_zarr_group_metadata(path) f = z5py.ZarrFile(path, "a") state = AnnotatorState() @@ -610,6 +624,9 @@ def _save_signature(f, data_signature): for key, val in signature.items(): f.attrs[key] = val + # Add the annotator type to the signature. + f.attrs["annotator_class"] = state.annotator.__class__.__name__ + # If the data signature is saved in the file already, # then we check if saved data signature and data signature of our image agree. # If not, this file was used for committing objects from another file. @@ -650,10 +667,14 @@ def _save_signature(f, data_signature): commit_history = f.attrs.get("commit_history", []) object_ids = np.unique(seg[mask]) + # Get the preservation settings from the commit widget. + preservation_settings = _get_preservation_settings(state) + # We committed an automatic segmentation. if layer == "auto_segmentation": # Save the settings of the segmentation widget. segmentation_options = _get_auto_segmentation_options(state, object_ids) + segmentation_options.update(**preservation_settings) commit_history.append({"auto_segmentation": segmentation_options}) # Write the commit history. @@ -664,10 +685,14 @@ def _save_signature(f, data_signature): return segmentation_options, is_tracking = _get_promptable_segmentation_options(state, object_ids) + segmentation_options.update(**preservation_settings) commit_history.append({"current_object": segmentation_options}) + # TODO add support for mask prompt (e.g. 
from polygon layer) def write_prompts(object_id, prompts, point_prompts, point_labels, track_state=None): - g = f.create_group(f"prompts/{object_id}") + group_name = f"prompts/{object_id}" + g = f.create_group(group_name) + _require_zarr_group_metadata(path, group_name) if prompts is not None and len(prompts) > 0: data = np.array(prompts) g.create_dataset("prompts", data=data, chunks=data.shape) @@ -750,6 +775,7 @@ def commit( commit_path: Select a file path where the committed results and prompts will be saved. This feature is still experimental. """ + # Commit the segmentation layer. _, seg, mask, bb = _commit_impl(viewer, layer, preserve_mode, preservation_threshold) if commit_path is not None: @@ -964,12 +990,27 @@ def _validate_embeddings(viewer: "napari.viewer.Viewer"): # return False -def _validate_prompts(viewer: "napari.viewer.Viewer") -> bool: - if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0: - msg = "No prompts were given. Please provide prompts to run interactive segmentation." - return _generate_message("error", msg) +def _validation_window_for_missing_layer(layer_choice): + if layer_choice == "committed_objects": + msg = "The 'committed_objects' layer to commit masks is missing. Please try to commit again." else: - return False + msg = f"The '{layer_choice}' layer to commit is missing. Please re-annotate and try again." + + return _generate_message(message_type="error", message=msg) + + +def _validate_layers(viewer: "napari.viewer.Viewer", automatic_segmentation: bool = False) -> bool: + # Check whether all layers exist as expected or create new ones automatically. + state = AnnotatorState() + state.annotator._require_layers() + + if not automatic_segmentation: + # Check prompts layer. + if len(viewer.layers["prompts"].data) == 0 and len(viewer.layers["point_prompts"].data) == 0: + msg = "No prompts were given. Please provide prompts to run interactive segmentation." + return _generate_message("error", msg) + else: + return False @magic_factory(call_button="Segment Object [S]") @@ -982,7 +1023,7 @@ def segment(viewer: "napari.viewer.Viewer", batched: bool = False) -> None: """ if _validate_embeddings(viewer): return None - if _validate_prompts(viewer): + if _validate_layers(viewer): return None shape = viewer.layers["current_object"].data.shape @@ -1016,7 +1057,7 @@ def segment_slice(viewer: "napari.viewer.Viewer") -> None: """ if _validate_embeddings(viewer): return None - if _validate_prompts(viewer): + if _validate_layers(viewer): return None shape = viewer.layers["current_object"].data.shape[1:] @@ -1057,8 +1098,9 @@ def segment_frame(viewer: "napari.viewer.Viewer") -> None: """ if _validate_embeddings(viewer): return None - if _validate_prompts(viewer): + if _validate_layers(viewer): return None + state = AnnotatorState() shape = state.image_shape[1:] position = viewer.dims.point @@ -1626,8 +1668,9 @@ def update_segmentation(seg): def __call__(self): if _validate_embeddings(self._viewer): return None - if _validate_prompts(self._viewer): + if _validate_layers(self._viewer): return None + if self.tracking: return self._run_tracking() else: @@ -1889,6 +1932,9 @@ def update_segmentation(seg): self._viewer.layers["auto_segmentation"].data[i] = seg self._viewer.layers["auto_segmentation"].refresh() + # Validate all layers. 
+ _validate_layers(self._viewer, automatic_segmentation=True) + seg = seg_impl() update_segmentation(seg) # worker = seg_impl() diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 5afbc63ef..d3f4b3e55 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -24,9 +24,18 @@ def _get_widgets(self): "clear": widgets.clear(), } - def __init__(self, viewer: "napari.viewer.Viewer") -> None: + def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None: super().__init__(viewer=viewer, ndim=2) + # Set the expected annotator class to the state. + state = AnnotatorState() + + # Reset the state. + if reset_state: + state.reset_state() + + state.annotator = self + def annotator_2d( image: np.ndarray, @@ -85,7 +94,7 @@ def annotator_2d( viewer = napari.Viewer() viewer.add_image(image, name="image") - annotator = Annotator2d(viewer) + annotator = Annotator2d(viewer, reset_state=False) # Trigger layer update of the annotator so that layers have the correct shape. # And initialize the 'committed_objects' with the segmentation result if it was given. diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index 773653b8c..6a8ea79b5 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -24,10 +24,19 @@ def _get_widgets(self): "clear": widgets.clear_volume(), } - def __init__(self, viewer: "napari.viewer.Viewer") -> None: + def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None: self._with_decoder = AnnotatorState().decoder is not None super().__init__(viewer=viewer, ndim=3) + # Set the expected annotator class to the state. + state = AnnotatorState() + + # Reset the state. + if reset_state: + state.reset_state() + + state.annotator = self + def _update_image(self, segmentation_result=None): super()._update_image(segmentation_result=segmentation_result) # Load the amg state from the embedding path. @@ -95,7 +104,7 @@ def annotator_3d( viewer = napari.Viewer() viewer.add_image(image, name="image") - annotator = Annotator3d(viewer) + annotator = Annotator3d(viewer, reset_state=False) # Trigger layer update of the annotator so that layers have the correct shape. # And initialize the 'committed_objects' with the segmentation result if it was given. diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index b9bc82252..6388e676d 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -1,4 +1,4 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import napari import numpy as np @@ -21,26 +21,40 @@ # This solution is a bit hacky, so I won't move it to _widgets.py yet. 
-def create_tracking_menu(points_layer, box_layer, states, track_ids): +def create_tracking_menu(points_layer, box_layer, states, track_ids, tracking_widget=None): """@private""" state = AnnotatorState() - state_menu = ComboBox(label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state")) - track_id_menu = ComboBox( - label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id") - ) - tracking_widget = Container(widgets=[state_menu, track_id_menu]) + def _get_widget_menu(container, label): + for w in container: + if isinstance(w, ComboBox) and w.label == label: + return w + raise ValueError(f"ComboBox with label '{label}' not found.") + + if tracking_widget is None: + state_menu = ComboBox( + label="track_state", choices=states, tooltip=get_tooltip("annotator_tracking", "track_state") + ) + track_id_menu = ComboBox( + label="track_id", choices=list(map(str, track_ids)), tooltip=get_tooltip("annotator_tracking", "track_id") + ) + tracking_widget = Container(widgets=[state_menu, track_id_menu]) + else: + state_menu = _get_widget_menu(tracking_widget, "track_state") + track_id_menu = _get_widget_menu(tracking_widget, "track_id") def update_state(event): - new_state = str(points_layer.current_properties["state"][0]) - if new_state != state_menu.value: - state_menu.value = new_state + if "state" in points_layer.current_properties: + new_state = str(points_layer.current_properties["state"][0]) + if new_state != state_menu.value: + state_menu.value = new_state def update_track_id(event): - new_id = str(points_layer.current_properties["track_id"][0]) - if new_id != track_id_menu.value: - track_id_menu.value = new_id - state.current_track_id = int(new_id) + if "track_id" in points_layer.current_properties: + new_id = str(points_layer.current_properties["track_id"][0]) + if new_id != track_id_menu.value: + track_id_menu.value = new_id + state.current_track_id = int(new_id) # def update_state_boxes(event): # new_state = str(box_layer.current_properties["state"][0]) @@ -48,10 +62,11 @@ def update_track_id(event): # state_menu.value = new_state def update_track_id_boxes(event): - new_id = str(box_layer.current_properties["track_id"][0]) - if new_id != track_id_menu.value: - track_id_menu.value = new_id - state.current_track_id = int(new_id) + if "track_id" in box_layer.current_properties: + new_id = str(box_layer.current_properties["track_id"][0]) + if new_id != track_id_menu.value: + track_id_menu.value = new_id + state.current_track_id = int(new_id) points_layer.events.current_properties.connect(update_state) points_layer.events.current_properties.connect(update_track_id) @@ -101,59 +116,131 @@ class AnnotatorTracking(_AnnotatorBase): # The tracking annotator needs different settings for the prompt layers # to support the additional tracking state. # That's why we over-ride this function. 
- def _create_layers(self): - self._point_labels = ["positive", "negative"] - self._track_state_labels = ["track", "division"] + def _require_layers(self, layer_choices: Optional[List[str]] = None): - self._point_prompt_layer = self._viewer.add_points( - name="point_prompts", - property_choices={ - "label": self._point_labels, - "state": self._track_state_labels, - "track_id": ["1"], # we use string to avoid pandas warning - }, - border_color="label", - border_color_cycle=vutil.LABEL_COLOR_CYCLE, - symbol="o", - face_color="state", - face_color_cycle=STATE_COLOR_CYCLE, - border_width=0.4, - size=12, - ndim=self._ndim, - ) - self._point_prompt_layer.border_color_mode = "cycle" - self._point_prompt_layer.face_color_mode = "cycle" - - # Using the box layer to set divisions currently doesn't work. - # That's why some of the code below is commented out. - self._box_prompt_layer = self._viewer.add_shapes( - shape_type="rectangle", - edge_width=4, - ndim=self._ndim, - face_color="transparent", - name="prompts", - edge_color="green", - property_choices={"track_id": ["1"]}, - # property_choces={"track_id": ["1"], "state": self._track_state_labels}, - # edge_color_cycle=STATE_COLOR_CYCLE, - ) - # self._box_prompt_layer.edge_color_mode = "cycle" + # Check whether the image is initialized already. And use the image shape and scale for the layers. + state = AnnotatorState() + shape = self._shape if state.image_shape is None else state.image_shape # Add the label layers for the current object, the automatic segmentation and the committed segmentation. - dummy_data = np.zeros(self._shape, dtype="uint32") - self._viewer.add_labels(data=dummy_data, name="current_object") - self._viewer.add_labels(data=dummy_data, name="auto_segmentation") - self._viewer.add_labels(data=dummy_data, name="committed_objects") - # Randomize colors so it is easy to see when object committed. - self._viewer.layers["committed_objects"].new_colormap() + dummy_data = np.zeros(shape, dtype="uint32") + image_scale = state.image_scale + + # Before adding new layers, we always check whether a layer with this name already exists or not. + if "current_object" not in self._viewer.layers: + if layer_choices and "current_object" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("current_object") + self._viewer.add_labels(data=dummy_data, name="current_object") + if image_scale is not None: + self.layers["current_objects"].scale = image_scale + + if "auto_segmentation" not in self._viewer.layers: + if layer_choices and "auto_segmentation" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("auto_segmentation") + self._viewer.add_labels(data=dummy_data, name="auto_segmentation") + if image_scale is not None: + self.layers["auto_segmentation"].scale = image_scale + + if "committed_objects" not in self._viewer.layers: + if layer_choices and "committed_objects" in layer_choices: # Check at 'commit' call button. + widgets._validation_window_for_missing_layer("committed_objects") + self._viewer.add_labels(data=dummy_data, name="committed_objects") + # Randomize colors so it is easy to see when object committed. + self._viewer.layers["committed_objects"].new_colormap() + if image_scale is not None: + self.layers["committed_objects"].scale = image_scale + + # Add the point prompts layer. 
+ self._point_labels = ["positive", "negative"] + self._track_state_labels = ["track", "division"] + _point_prompt_property_choices = { + "label": self._point_labels, + "state": self._track_state_labels, + "track_id": ["1"], # we use string to avoid pandas warning + } + + point_layer_mismatch = True + if "point_prompts" in self._viewer.layers: + # Check whether the 'property_choices' match or not. + curr_property_choices = self._viewer.layers["point_prompts"].property_choices + point_layer_mismatch = set(curr_property_choices.keys()) != set(_point_prompt_property_choices.keys()) + + if point_layer_mismatch and "point_prompts" not in self._viewer.layers: + self._point_prompt_layer = self._viewer.add_points( + name="point_prompts", + property_choices=_point_prompt_property_choices, + border_color="label", + border_color_cycle=vutil.LABEL_COLOR_CYCLE, + symbol="o", + face_color="state", + face_color_cycle=STATE_COLOR_CYCLE, + border_width=0.4, + size=12, + ndim=self._ndim, + ) + self._point_prompt_layer.border_color_mode = "cycle" + self._point_prompt_layer.face_color_mode = "cycle" + _new_point_layer = True + else: + self._point_prompt_layer = self._viewer.layers["point_prompts"] + _new_point_layer = False + + # Add the point prompts layer. + _box_prompt_property_choices = {"track_id": ["1"]} + + box_layer_mismatch = True + if "prompts" in self._viewer.layers: + # Check whether the 'property_choices' match or not. + curr_property_choices = self._viewer.layers["prompts"].property_choices + box_layer_mismatch = set(curr_property_choices.keys()) != set(_box_prompt_property_choices.keys()) + + if box_layer_mismatch and "prompts" not in self._viewer.layers: + # Using the box layer to set divisions currently doesn't work. + # That's why some of the code below is commented out. + self._box_prompt_layer = self._viewer.add_shapes( + shape_type="rectangle", + edge_width=4, + ndim=self._ndim, + face_color="transparent", + name="prompts", + edge_color="green", + property_choices=_box_prompt_property_choices, + # property_choices={"track_id": ["1"], "state": self._track_state_labels}, + # edge_color_cycle=STATE_COLOR_CYCLE, + ) + # self._box_prompt_layer.edge_color_mode = "cycle" + _new_box_layer = True + else: + self._box_prompt_layer = self._viewer.layers["prompts"] + _new_box_layer = False + + # Trigger a new connection for the tracking state menu only when a new layer is (re)created. + if _new_point_layer or _new_box_layer: + self._tracking_widget = create_tracking_menu( + points_layer=self._point_prompt_layer, + box_layer=self._box_prompt_layer, + states=self._track_state_labels, + track_ids=list(state.lineage.keys()), + tracking_widget=state.widgets.get("tracking"), + ) + state.widgets["tracking"] = self._tracking_widget def _get_widgets(self): state = AnnotatorState() + self._require_layers() + # Create the tracking state menu. - self._tracking_widget = create_tracking_menu( - self._point_prompt_layer, self._box_prompt_layer, - states=self._track_state_labels, track_ids=list(state.lineage.keys()), - ) + # NOTE: Check whether it exists already from `_require_layers` or needs to be created. 
+        if state.widgets.get("tracking") is None:
+            self._tracking_widget = create_tracking_menu(
+                points_layer=self._point_prompt_layer,
+                box_layer=self._box_prompt_layer,
+                states=self._track_state_labels,
+                track_ids=list(state.lineage.keys()),
+            )
+        else:
+            self._tracking_widget = state.widgets.get("tracking")
+
         segment_nd = widgets.SegmentNDWidget(self._viewer, tracking=True)
         autotrack = widgets.AutoTrackWidget(self._viewer, with_decoder=self._with_decoder, volumetric=True)
         return {
@@ -165,7 +252,7 @@ def _get_widgets(self):
             "clear": widgets.clear_track(),
         }
 
-    def __init__(self, viewer: "napari.viewer.Viewer") -> None:
+    def __init__(self, viewer: "napari.viewer.Viewer", reset_state: bool = True) -> None:
         # Initialize the state for tracking.
         self._init_track_state()
         self._with_decoder = AnnotatorState().decoder is not None
@@ -173,6 +260,15 @@ def __init__(self, viewer: "napari.viewer.Viewer") -> None:
         # Go to t=0.
         self._viewer.dims.current_step = (0, 0, 0) + tuple(sh // 2 for sh in self._shape[1:])
 
+        # Set the expected annotator class to the state.
+        state = AnnotatorState()
+
+        # Reset the state.
+        if reset_state:
+            state.reset_state()
+
+        state.annotator = self
+
     def _init_track_state(self):
         state = AnnotatorState()
         state.current_track_id = 1
@@ -239,7 +335,7 @@ def annotator_tracking(
     viewer = napari.Viewer()
     viewer.add_image(image, name="image")
 
-    annotator = AnnotatorTracking(viewer)
+    annotator = AnnotatorTracking(viewer, reset_state=False)
 
     # Trigger layer update of the annotator so that layers have the correct shape.
     annotator._update_image()
diff --git a/micro_sam/sam_annotator/reproducibility.py b/micro_sam/sam_annotator/reproducibility.py
new file mode 100644
index 000000000..73e537c9f
--- /dev/null
+++ b/micro_sam/sam_annotator/reproducibility.py
@@ -0,0 +1,318 @@
+import os
+import warnings
+from typing import Optional, Union
+
+import numpy as np
+from elf.io import open_file
+from tqdm import tqdm
+
+from .. import util
+from ..automatic_segmentation import automatic_tracking, automatic_instance_segmentation, get_predictor_and_segmenter
+from ..multi_dimensional_segmentation import segment_mask_in_volume
+
+from ._widgets import _mask_matched_objects
+from .annotator_2d import annotator_2d
+from .annotator_3d import annotator_3d
+# from .annotator_tracking import annotator_tracking
+from .util import prompt_segmentation, segment_slices_with_prompts
+
+
+def _load_model_from_commit_file(f):
+    # Check which segmentation mode is used, by going through the commit history
+    # and checking if we have committed the 'auto_segmentation' layer.
+    # If we did, then we derive which mode was used from the serialized parameters.
+    amg = None
+    commit_history = f.attrs["commit_history"]
+    for commit in commit_history:
+        layer, options = next(iter(commit.items()))
+        if layer == "auto_segmentation":
+            amg = "boundary_distance_threshold" not in options
+
+    # Get the predictor and segmenter.
+    predictor, segmenter = get_predictor_and_segmenter(
+        model_type=f.attrs["model_name"],
+        # checkpoint="", TODO we need to also serialize this
+        amg=amg,
+        is_tiled=f.attrs["tile_shape"] is not None,
+    )
+    return predictor, segmenter
+
+
+def _check_data_hash(f, input_data, input_path):
+    data_hash = util._compute_data_signature(input_data)
+    expected_hash = f.attrs["data_signature"]
+    if data_hash != expected_hash:
+        raise RuntimeError(
+            f"The hash of the input data loaded from {input_path} is {data_hash}, "
+            f"which does not match the expected hash {expected_hash}."
+ ) + + +def _write_masks_with_preservation(prev_seg, seg, preserve_mode, preservation_threshold, object_ids): + # Make sure the segmented ids and the object ids match (warn if not), + # and apply the id offset to match them. + segmented_ids = np.setdiff1d(np.unique(seg), [0]) + if len(segmented_ids) != len(object_ids): + warnings.warn( + f"The number of objects found by running auto segmentation (={len(segmented_ids)})" + f" does not match the number of expected objects (={len(object_ids)})." + ) + id_offset = int(np.min(object_ids)) - 1 + seg[seg != 0] += id_offset + + # Write the new segmentation to the previous segmentation, taking the preservation rules into account. + mask = seg != 0 + if preserve_mode != "none": + preserve_mask = prev_seg != 0 + if preserve_mask.sum() != 0: + # In the mode 'objects' we preserve committed objects instead, by comparing the overlaps + # of already committed and newly committed objects. + if preserve_mode == "objects": + preserve_mask = _mask_matched_objects(seg, prev_seg, preservation_threshold) + mask[preserve_mask] = 0 + prev_seg[mask] = seg[mask] + return prev_seg + + +def _rerun_interactive_segmentation(segmentation, f, predictor, image_embeddings, annotator_class, options): + object_ids = options.pop("object_ids") + preserve_mode, preservation_threshold = options.pop("preserve_mode"), options.pop("preservation_threshold") + g = f["prompts"] + + # Load the serialized prompts for all objects in this commit. + boxes, masks = [], [] + points, labels = [], [] + for object_id in object_ids: + prompt_group = g[str(object_id)] + if "point_prompts" in prompt_group: + points.append(prompt_group["point_prompts"][:]) + labels.append(prompt_group["point_labels"][:]) + if "prompts" in prompt_group: + boxes.append(prompt_group["prompts"][:]) + # We can only have a mask if we also have a box prompt. + if "mask" in prompt_group: + masks.append(prompt_group["mask"][:]) + else: + masks.append(None) + + if points: + points = np.concatenate(points, axis=0) + labels = np.concatenate(labels, axis=0) + + if annotator_class == "Annotator2d": + + if boxes: + # Map boxes to the correct input format. 
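+            # NOTE: each serialized box prompt is the (4, 2) corner array of a napari rectangle;
+            # the comprehension below reduces it to a flat [y_min, x_min, y_max, x_max] box.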
+            boxes = np.concatenate(boxes, axis=0)
+            boxes = [
+                np.array([box[:, 0].min(), box[:, 1].min(), box[:, 0].max(), box[:, 1].max()]) for box in boxes
+            ]
+
+        batched = len(object_ids) > 1
+        seg = prompt_segmentation(
+            predictor, points, labels, boxes, masks, segmentation.shape, image_embeddings=image_embeddings,
+            multiple_box_prompts=True, batched=batched, previous_segmentation=segmentation,
+        ).astype("uint32")
+
+    elif annotator_class == "Annotator3d":
+        # TODO this needs to be updated to rerun the interactive 3d segmentation.
+        raise NotImplementedError("Rerunning interactive segmentation is not yet implemented for Annotator3d.")
+        # seg, slices, stop_lower, stop_upper = segment_slices_with_prompts(
+        #     state.predictor, self._viewer.layers["point_prompts"], self._viewer.layers["prompts"],
+        #     image_embeddings, shape,
+        # )
+        # seg, (z_min, z_max) = segment_mask_in_volume(
+        #     seg, predictor, image_embeddings, slices,
+        #     stop_lower, stop_upper,
+        #     iou_threshold=self.iou_threshold, projection=self.projection,
+        #     box_extension=self.box_extension,
+        #     update_progress=lambda update: pbar_signals.pbar_update.emit(update),
+        # )
+
+    elif annotator_class == "AnnotatorTracking":
+        raise NotImplementedError("Not yet implemented for AnnotatorTracking.")
+
+    else:
+        raise RuntimeError(f"Invalid annotator class {annotator_class}.")
+
+    return _write_masks_with_preservation(segmentation, seg, preserve_mode, preservation_threshold, object_ids)
+
+
+def _rerun_automatic_segmentation(
+    image, segmentation, predictor, segmenter, image_embeddings, annotator_class, tile_shape, halo, options
+):
+    object_ids = options.pop("object_ids")
+    preserve_mode, preservation_threshold = options.pop("preserve_mode"), options.pop("preservation_threshold")
+
+    # If there was nothing committed then we don't need to rerun the automatic segmentation.
+    if len(object_ids) == 0:
+        return segmentation
+
+    if annotator_class in ("Annotator2d", "Annotator3d"):
+        ndim = 2 if annotator_class == "Annotator2d" else 3
+        seg = automatic_instance_segmentation(
+            predictor, segmenter, image,
+            embedding_path=image_embeddings,
+            tile_shape=tile_shape, halo=halo, ndim=ndim,
+            **options
+        )
+    elif annotator_class == "AnnotatorTracking":
+        seg, lineages = automatic_tracking(
+            predictor, segmenter, image,
+            embedding_path=image_embeddings,
+            tile_shape=tile_shape, halo=halo,
+            **options
+        )
+
+    else:
+        raise RuntimeError(f"Invalid annotator class {annotator_class}.")
+
+    return _write_masks_with_preservation(segmentation, seg, preserve_mode, preservation_threshold, object_ids)
+
+
+def rerun_segmentation_from_commit_file(
+    commit_file: Union[str, os.PathLike],
+    input_path: Union[str, os.PathLike, np.ndarray],
+    input_key: Optional[str] = None,
+    embedding_path: Optional[Union[str, os.PathLike]] = None,
+) -> np.ndarray:
+    """Rerun a segmentation from the commit history of a commit file.
+
+    Args:
+        commit_file: The path to the zarr file storing the commit history.
+        input_path: The path to the image data for the respective micro_sam commit history.
+        input_key: The key for the image data, in case it is a zarr, n5, hdf5 file or similar.
+        embedding_path: The path to precomputed embeddings for this project.
+
+    Returns:
+        The segmentation recreated from the commit history.
+    """
+    # Load the image data and open the zarr commit file.
+    input_data = util.load_image_data(input_path, key=input_key)
+    with open_file(commit_file, mode="r") as f:
+
+        # Get the annotator class.
+        if "annotator_class" not in f.attrs:
+            raise RuntimeError(
+                f"You have saved {commit_file} with a version that does not yet support rerunning the segmentation."
+            )
+        annotator_class = f.attrs["annotator_class"]
+        ndim = 2 if annotator_class == "Annotator2d" else 3
+
+        # Get the tile shape and halo from the attributes.
+        tile_shape, halo = f.attrs["tile_shape"], f.attrs["halo"]
+
+        # Check that the stored data hash and the input data hash match.
+        _check_data_hash(f, input_data, input_path)
+
+        # Load the model according to the model description stored in the commit file.
+        predictor, segmenter = _load_model_from_commit_file(f)
+
+        # Compute or load the image embeddings.
+        image_embeddings = util.precompute_image_embeddings(
+            predictor=predictor,
+            input_=input_data,
+            save_path=embedding_path,
+            ndim=ndim,
+            tile_shape=tile_shape,
+            halo=halo,
+        )
+
+        # Go through the commit history and redo the action of each commit.
+        # Actions can be:
+        # - Committing an automatic segmentation result.
+        # - Committing an interactive segmentation result.
+        commit_history = f.attrs["commit_history"]
+
+        # Rerun the commit history.
+        # TODO check if this works correctly for 3d data
+        shape = image_embeddings["original_size"]
+        segmentation = np.zeros(shape, dtype="uint32")
+        for commit in tqdm(commit_history, desc="Rerunning commit history"):
+            layer, options = next(iter(commit.items()))
+            if layer == "current_object":
+                segmentation = _rerun_interactive_segmentation(
+                    segmentation, f, predictor, image_embeddings, annotator_class, options
+                )
+            elif layer == "auto_segmentation":
+                segmentation = _rerun_automatic_segmentation(
+                    input_data, segmentation, predictor, segmenter, image_embeddings,
+                    annotator_class, tile_shape, halo, options
+                )
+            else:
+                raise RuntimeError(f"Invalid layer {layer} in commit_history.")
+
+    return segmentation
+
+
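A short usage sketch for `rerun_segmentation_from_commit_file` together with the loader defined just below (illustrative only; all file paths are placeholders):

    from micro_sam.sam_annotator.reproducibility import (
        load_committed_objects_from_commit_file, rerun_segmentation_from_commit_file
    )

    # Recompute the segmentation from the serialized commit history ...
    recomputed = rerun_segmentation_from_commit_file(
        commit_file="commit.zarr", input_path="image.tif", embedding_path="embeddings.zarr",
    )
    # ... and compare it to the committed result stored in the same file.
    committed = load_committed_objects_from_commit_file("commit.zarr")
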
+def load_committed_objects_from_commit_file(commit_file: Union[str, os.PathLike]) -> np.ndarray:
+    """Load the committed objects from a commit file.
+
+    Args:
+        commit_file: The path to the zarr file storing the commit history.
+
+    Returns:
+        The committed segmentation.
+    """
+    with open_file(commit_file, mode="r") as f:
+        return f["committed_objects"][:]
+
+
+def continue_annotation(
+    commit_file: Union[str, os.PathLike],
+    input_path: Union[str, os.PathLike],
+    input_key: Optional[str] = None,
+    embedding_path: Optional[Union[str, os.PathLike]] = None,
+) -> None:
+    """Start an annotator from a commit file and set it to the committed state.
+
+    This currently does not support files committed with annotator_tracking.
+
+    Args:
+        commit_file: The path to the zarr file storing the commit history.
+        input_path: The path to the image data for the respective micro_sam commit history.
+        input_key: The key for the image data, in case it is a zarr, n5, hdf5 file or similar.
+        embedding_path: The path to precomputed embeddings for this project.
+    """
+    committed_objects = load_committed_objects_from_commit_file(commit_file)
+    with open_file(commit_file, mode="r") as f:
+        if "annotator_class" not in f.attrs:
+            raise RuntimeError(
+                f"You have saved {commit_file} in a version that does not support continuing the annotation."
+ ) + annotator_class = f.attrs["annotator_class"] + model_type = f.attrs["model_name"] + tile_shape = f.attrs["tile_shape"] + halo = f.attrs["halo"] + + input_data = util.load_image_data(input_path, key=input_key) + if annotator_class == "Annotator2d": + annotator_2d( + input_data, embedding_path=embedding_path, segmentation_result=committed_objects, + model_type=model_type, tile_shape=tile_shape, halo=halo, + ) + elif annotator_class == "Annotator3d": + annotator_3d( + input_data, embedding_path=embedding_path, segmentation_result=committed_objects, + model_type=model_type, tile_shape=tile_shape, halo=halo, + ) + # We need to implement initialization of the tracking annotator with a segmentation + tracking state for this. + elif annotator_class == "AnnotatorTracking": + raise NotImplementedError("'continue_annotation_from_commit_file' is not yet supported for AnnotatorTracking.") + else: + raise RuntimeError(f"Invalid annotator class {annotator_class}.") + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Start an annotator from a commit file and set it to the commited state." + ) + parser.add_argument("-c", "--commit_file", required=True, help="The zarr file with the commit history.") + parser.add_argument("-i", "--input_path", required=True, help="The file path of the image data.") + parser.add_argument( + "-k", "--input_key", help="The input key for the image data. Required for zarr, n5 or hdf5 files." + ) + parser.add_argument("-e", "--embedding_path", help="Optional file path for precomputed embeddings.") + args = parser.parse_args() + continue_annotation(args.commit_file, args.input_path, args.input_key, args.embedding_path) diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py index 576d72ce8..2fae39f1f 100644 --- a/micro_sam/training/__init__.py +++ b/micro_sam/training/__init__.py @@ -6,4 +6,7 @@ from .joint_sam_trainer import JointSamTrainer, JointSamLogger from .simple_sam_trainer import SimpleSamTrainer, MedSAMTrainer from .semantic_sam_trainer import SemanticSamTrainer, SemanticMapsSamTrainer -from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS +from .training import ( + train_sam, train_sam_for_configuration, train_instance_segmentation, default_sam_loader, default_sam_dataset, + export_instance_segmentation_model, CONFIGURATIONS, +) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 292e874d8..568afe828 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -2,7 +2,7 @@ import time import warnings from glob import glob -from tqdm import tqdm +from collections import OrderedDict from contextlib import contextmanager, nullcontext from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -13,6 +13,7 @@ from torch.utils.data import random_split from torch.utils.data import DataLoader, Dataset from torch.optim.lr_scheduler import _LRScheduler +from tqdm import tqdm import torch_em from torch_em.util import load_data @@ -28,7 +29,7 @@ from . import sam_trainer as trainers from ..instance_segmentation import get_unetr from . 
import joint_sam_trainer as joint_trainers -from ..util import get_device, get_model_names, export_custom_sam_model +from ..util import get_device, get_model_names, export_custom_sam_model, get_sam_model from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit, get_raw_transform @@ -165,6 +166,32 @@ def _count_parameters(model_parameters): print(f"The number of trainable parameters for the provided model is {params} (~{round(params, 2)}M)") +def _get_trainer_fit_params(n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training): + if n_iterations is None: + trainer_fit_params = {"epochs": n_epochs} + else: + trainer_fit_params = {"iterations": n_iterations} + + if save_every_kth_epoch is not None: + trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch + + if pbar_signals is not None: + progress_bar_wrapper = _ProgressBarWrapper(pbar_signals) + trainer_fit_params["progress"] = progress_bar_wrapper + + # Avoid overwriting a trained model, if desired by the user. + trainer_fit_params["overwrite_training"] = overwrite_training + return trainer_fit_params + + +def _get_optimizer_and_scheduler(model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs): + optimizer = optimizer_class(model_params, lr=lr) + if scheduler_kwargs is None: + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} + scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) + return optimizer, scheduler + + def train_sam( name: str, model_type: str, @@ -223,8 +250,7 @@ def train_sam( If passed None, the chosen default parameters are used in ReduceLROnPlateau. save_every_kth_epoch: Save checkpoints after every kth epoch separately. pbar_signals: Controls for napari progress bar. - optimizer_class: The optimizer class. - By default, torch.optim.AdamW is used. + optimizer_class: The optimizer class. By default, torch.optim.AdamW is used. peft_kwargs: Keyword arguments for the PEFT wrapper class. ignore_warnings: Whether to ignore raised warnings. verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. @@ -277,12 +303,9 @@ def train_sam( else: model_params = model.parameters() - optimizer = optimizer_class(model_params, lr=lr) - - if scheduler_kwargs is None: - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} - - scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) + optimizer, scheduler = _get_optimizer_and_scheduler( + model_params, lr, optimizer_class, scheduler_class, scheduler_kwargs + ) # The trainer which performs training and validation. 
if with_segmentation_decoder: @@ -330,21 +353,168 @@ def train_sam( save_root=save_root, ) - if n_iterations is None: - trainer_fit_params = {"epochs": n_epochs} - else: - trainer_fit_params = {"iterations": n_iterations} + trainer_fit_params = _get_trainer_fit_params( + n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training + ) + trainer.fit(**trainer_fit_params) + + t_run = time.time() - t_start + hours = int(t_run // 3600) + minutes = int((t_run % 3600) // 60) + seconds = int(round(t_run % 60, 0)) + print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") + + +def export_instance_segmentation_model( + trained_model_path: Union[str, os.PathLike], + output_path: Union[str, os.PathLike], + model_type: str, + initial_checkpoint_path: Optional[Union[str, os.PathLike]] = None, +) -> None: + """Export a model trained for instance segmentation with `train_instance_segmentation`. + + The exported model will be compatible with the micro_sam functions, CLI and napari plugin. + It should only be used for automatic segmentation and may not work well for interactive segmentation. + + Args: + trained_model_path: The path to the checkpoint of the model trained for instance segmentation. + output_path: The path where the exported model will be saved. + model_type: The model type. + initial_checkpoint_path: The path to the initial checkpoint the instance segmentation training was based on (optional). + """ + trained_state = torch.load(trained_model_path, weights_only=False, map_location="cpu")["model_state"] + + # Get the state of the encoder and instance segmentation decoder from the trained checkpoint. + encoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if k.startswith("encoder")]) + decoder_state = OrderedDict([(k, v) for k, v in trained_state.items() if not k.startswith("encoder")]) + + # Load the original state of the model that was used as the basis of instance segmentation training. + _, model_state = get_sam_model( + model_type=model_type, checkpoint_path=initial_checkpoint_path, return_state=True, device="cpu", + ) + # Remove the sam prefix if it's in the model state. + prefix = "sam." + model_state = OrderedDict( + [(k[len(prefix):] if k.startswith(prefix) else k, v) for k, v in model_state.items()] + ) + + # Replace the image encoder state. + model_state = OrderedDict( + [(k, encoder_state[k[6:]] if k.startswith("image_encoder") else v) + for k, v in model_state.items()] + ) + + save_state = {"model_state": model_state, "decoder_state": decoder_state} + torch.save(save_state, output_path) + + +def train_instance_segmentation( + name: str, + model_type: str, + train_loader: DataLoader, + val_loader: DataLoader, + n_epochs: int = 100, + early_stopping: Optional[int] = 10, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + freeze: Optional[List[str]] = None, + device: Optional[Union[str, torch.device]] = None, + lr: float = 1e-5, + save_root: Optional[Union[str, os.PathLike]] = None, + n_iterations: Optional[int] = None, + scheduler_class: Optional[_LRScheduler] = torch.optim.lr_scheduler.ReduceLROnPlateau, + scheduler_kwargs: Optional[Dict[str, Any]] = None, + save_every_kth_epoch: Optional[int] = None, + pbar_signals: Optional[QObject] = None, + optimizer_class: Optional[Optimizer] = torch.optim.AdamW, + peft_kwargs: Optional[Dict] = None, + ignore_warnings: bool = True, + overwrite_training: bool = True, + **model_kwargs, +) -> None: + """Train a UNETR for instance segmentation using the SAM encoder as backbone.
+ + This setting corresponds to training a SAM model with an instance segmentation decoder, + without training the model parts for interactive segmentation, + i.e. without training the prompt encoder and mask decoder. + + The checkpoint of the trained model, which will be saved in 'checkpoints/', + will not be compatible with the micro_sam functionality. + You can call the function `export_instance_segmentation_model` with the path to the checkpoint to export it + in a format that is compatible with micro_sam functionality. + Note that the exported model should only be used for automatic segmentation via AIS. + + Args: + name: The name of the model to be trained. The checkpoint and logs will have this name. + model_type: The type of the SAM model. + train_loader: The dataloader for training. + val_loader: The dataloader for validation. + n_epochs: The number of epochs to train for. + early_stopping: Enable early stopping after this number of epochs without improvement. + checkpoint_path: Path to checkpoint for initializing the SAM model. + freeze: Specify parts of the model that should be frozen. Here, only the image_encoder can be frozen. + By default nothing is frozen and the full model is updated. + device: The device to use for training. + lr: The learning rate. + save_root: Optional root directory for saving the checkpoints and logs. + If not given the current working directory is used. + n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given. + scheduler_class: The learning rate scheduler to update the learning rate. + By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used. + scheduler_kwargs: The learning rate scheduler parameters. + If passed None, the chosen default parameters are used in ReduceLROnPlateau. + save_every_kth_epoch: Save checkpoints after every kth epoch separately. + pbar_signals: Controls for napari progress bar. + optimizer_class: The optimizer class. By default, torch.optim.AdamW is used. + peft_kwargs: Keyword arguments for the PEFT wrapper class. + ignore_warnings: Whether to ignore raised warnings. + overwrite_training: Whether to overwrite the trained model stored at the same location. + By default, overwrites the trained model at each run. + If set to 'False', it will avoid retraining the model if the previous run was completed. + model_kwargs: Additional keyword arguments for the `util.get_sam_model`. + """ - if save_every_kth_epoch is not None: - trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch + with _filter_warnings(ignore_warnings): + t_start = time.time() - if pbar_signals is not None: - progress_bar_wrapper = _ProgressBarWrapper(pbar_signals) - trainer_fit_params["progress"] = progress_bar_wrapper + sam_model, state = get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + return_state=True, + peft_kwargs=peft_kwargs, + freeze=freeze, + **model_kwargs + ) + device = get_device(device) + model = get_unetr( + image_encoder=sam_model.sam.image_encoder, + decoder_state=state.get("decoder_state", None), + device=device, + ) - # Avoid overwriting a trained model, if desired by the user. 
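+ # Note: `get_unetr` combines the trainable SAM image encoder with a UNETR decoder for instance segmentation; + # if the initial checkpoint already provides a 'decoder_state', the decoder is initialized from it.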
- trainer_fit_params["overwrite_training"] = overwrite_training + loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) + optimizer, scheduler = _get_optimizer_and_scheduler( + model.parameters(), lr, optimizer_class, scheduler_class, scheduler_kwargs + ) + trainer = torch_em.trainer.DefaultTrainer( + name=name, + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + mixed_precision=True, + log_image_interval=50, + compile_model=False, + save_root=save_root, + loss=loss, + metric=loss, + optimizer=optimizer, + lr_scheduler=scheduler, + early_stopping=early_stopping, + ) + trainer_fit_params = _get_trainer_fit_params( + n_epochs, n_iterations, save_every_kth_epoch, pbar_signals, overwrite_training + ) trainer.fit(**trainer_fit_params) t_run = time.time() - t_start @@ -395,6 +565,7 @@ def default_sam_dataset( patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: Optional[bool] = None, + train_instance_segmentation_only: bool = False, sampler: Optional[Callable] = None, raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, @@ -417,6 +588,8 @@ patch_shape: The shape for training patches. with_segmentation_decoder: Whether to train with additional segmentation decoder. with_channels: Whether the image data has channels. By default, it makes the decision based on inputs. + train_instance_segmentation_only: Set this argument to True in order to + pass the dataset to `train_instance_segmentation`. sampler: A sampler to reject batches according to a given criterion. raw_transform: Transformation applied to the image data. If not given the data will be cast to 8bit. @@ -430,6 +603,24 @@ The segmentation dataset. """ + # Check if this dataset should be used for instance segmentation only training. + # If yes, we set return_instances to False, since the instance channel must not + # be passed for this training mode. + return_instances = True + if train_instance_segmentation_only: + if not with_segmentation_decoder: + raise ValueError( + "If 'train_instance_segmentation_only' is True, then 'with_segmentation_decoder' must also be True." + ) + return_instances = False + + # If a sampler is not passed, then we set a MinInstanceSampler, which requires a minimum of 2 distinct instances per sample. + # This is necessary, because training for interactive segmentation does not work on 'empty' images. + # However, if we train only the automatic instance segmentation decoder, then this sampler is not required + # and we do not set a default sampler. + if sampler is None and not train_instance_segmentation_only: + sampler = torch_em.data.sampler.MinInstanceSampler(2, min_size=min_size) + # By default, let the 'default_segmentation_dataset' heuristic decide for itself. is_seg_dataset = kwargs.pop("is_seg_dataset", None) @@ -470,16 +661,12 @@ boundary_distances=True, directed_distances=False, foreground=True, - instances=True, + instances=return_instances, min_size=min_size, ) else: label_transform = torch_em.transform.label.MinSizeLabelTransform(min_size=min_size) - # Set a default sampler if none was passed. - if sampler is None: - sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size) - # Check the patch shape to add a singleton if required.
patch_shape = _update_patch_shape( patch_shape=patch_shape, raw_paths=raw_paths, raw_key=raw_key, with_channels=with_channels, @@ -556,6 +743,8 @@ def default_sam_loader(**kwargs) -> DataLoader: "V100": {"model_type": "vit_b"}, "A100": {"model_type": "vit_h"}, } +"""Best training configurations for given hardware resources. +""" def _find_best_configuration(): @@ -579,10 +768,6 @@ def _find_best_configuration(): return "CPU" -"""Best training configurations for given hardware resources. -""" - - def train_sam_for_configuration( name: str, configuration: str, @@ -590,6 +775,7 @@ def train_sam_for_configuration( val_loader: DataLoader, checkpoint_path: Optional[Union[str, os.PathLike]] = None, with_segmentation_decoder: bool = True, + train_instance_segmentation_only: bool = False, model_type: Optional[str] = None, **kwargs, ) -> None: @@ -607,6 +793,8 @@ def train_sam_for_configuration( checkpoint_path: Path to checkpoint for initializing the SAM model. with_segmentation_decoder: Whether to train additional UNETR decoder for automatic instance segmentation. + train_instance_segmentation_only: Whether to train a model only for automatic instance segmentation + using the training implementation `train_instance_segmentation`. By default, `train_sam` is used. model_type: Over-ride the default model type. This can be used to use one of the micro_sam models as starting point instead of a default sam model. @@ -625,15 +813,26 @@ def train_sam_for_configuration( warnings.warn("You have specified a different model type.") train_kwargs.update(**kwargs) - train_sam( - name=name, - train_loader=train_loader, - val_loader=val_loader, - checkpoint_path=checkpoint_path, - with_segmentation_decoder=with_segmentation_decoder, - model_type=model_type, - **train_kwargs - ) + if train_instance_segmentation_only: + train_instance_segmentation( + name=name, + train_loader=train_loader, + val_loader=val_loader, + checkpoint_path=checkpoint_path, + with_segmentation_decoder=with_segmentation_decoder, + model_type=model_type, + **train_kwargs + ) + else: + train_sam( + name=name, + train_loader=train_loader, + val_loader=val_loader, + checkpoint_path=checkpoint_path, + with_segmentation_decoder=with_segmentation_decoder, + model_type=model_type, + **train_kwargs + ) def _export_helper(save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader): @@ -692,6 +891,21 @@ def _export_helper(save_root, checkpoint_name, output_path, model_type, with_seg return final_path +def _parse_segmentation_decoder(segmentation_decoder): + if segmentation_decoder in ("None", "none"): + with_segmentation_decoder, train_instance_segmentation_only = False, False + elif segmentation_decoder == "instances": + with_segmentation_decoder, train_instance_segmentation_only = True, False + elif segmentation_decoder == "instances_only": + with_segmentation_decoder, train_instance_segmentation_only = True, True + else: + raise ValueError( + "The 'segmentation_decoder' argument currently supports the values:\n" + f"'instances', 'instances_only', or 'None'. You have passed {segmentation_decoder}." + ) + return with_segmentation_decoder, train_instance_segmentation_only + + def main(): """@private""" import argparse @@ -753,12 +967,18 @@ def none_or_str(value): return None return value + # This could be extended to train for semantic segmentation or other options. 
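+ # For reference, `_parse_segmentation_decoder` (defined above) maps these choices as follows: + # 'instances' -> with_segmentation_decoder=True, train_instance_segmentation_only=False + # 'instances_only' -> with_segmentation_decoder=True, train_instance_segmentation_only=True + # 'None' / 'none' -> with_segmentation_decoder=False, train_instance_segmentation_only=False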
parser.add_argument( + "--segmentation_decoder", type=none_or_str, default="instances", - # TODO: in future, we can extend this to semantic seg / or even more advanced stuff. - help="Whether to finetune Segment Anything Model with additional segmentation decoder for desired targets. " - "By default, it uses the 'instances' option, i.e. trains with the additional segmentation decoder for " - "instance segmentation, otherwise pass 'None' for training without the additional segmentation decoder at all." + help="Whether to finetune the Segment Anything Model with an additional segmentation decoder. " + "The following options are possible:\n" + "- 'instances' to train with an additional decoder for automatic instance segmentation. " + " This option enables using the automatic instance segmentation (AIS) mode.\n" + "- 'instances_only' to train only the instance segmentation decoder. " + " In this case the parts of SAM that are used for interactive segmentation will not be trained.\n" + "- 'None' to train without an additional segmentation decoder." + " This option trains only the parts of the original SAM.\n" + "By default the option 'instances' is used." ) # Optional advanced settings a user can opt to change the values for. @@ -825,12 +1045,7 @@ def none_or_str(value): device = args.device save_root = args.save_root output_path = args.output_path - - if args.segmentation_decoder and args.segmentation_decoder != "instances": - raise ValueError( - "The 'segmentation_decoder' argument currently supports 'instances' as input argument only." - ) - with_segmentation_decoder = (args.segmentation_decoder is not None) + with_segmentation_decoder, train_instance_segmentation_only = _parse_segmentation_decoder(args.segmentation_decoder) # Get image paths and corresponding keys. train_images, train_gt, train_image_key, train_gt_key = args.images, args.labels, args.image_key, args.label_key @@ -850,6 +1065,7 @@ def none_or_str(value): patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder, raw_transform=_raw_transform, + train_instance_segmentation_only=train_instance_segmentation_only, ) # If val images are not exclusively provided, we create a val split from the training data. @@ -868,6 +1084,7 @@ def none_or_str(value): label_key=val_gt_key, patch_shape=patch_shape, with_segmentation_decoder=with_segmentation_decoder, + train_instance_segmentation_only=train_instance_segmentation_only, raw_transform=_raw_transform, ) @@ -899,11 +1116,17 @@ def none_or_str(value): device=device, save_root=save_root, peft_kwargs=None, # TODO: Allow for PEFT. + train_instance_segmentation_only=train_instance_segmentation_only, ) # 4. Export the model, if desired by the user - final_path = _export_helper( - save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader - ) + if train_instance_segmentation_only and output_path: + trained_path = os.path.join("" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt") + export_instance_segmentation_model(trained_path, output_path, model_type, checkpoint_path) + final_path = output_path + else: + final_path = _export_helper( + save_root, checkpoint_name, output_path, model_type, with_segmentation_decoder, val_loader, + ) print(f"Training has finished. The trained model is saved at {final_path}.")
diff --git a/scripts/training/train_instance_segmentation.py b/scripts/training/train_instance_segmentation.py new file mode 100644 index 000000000..895cc3ce7 --- /dev/null +++ b/scripts/training/train_instance_segmentation.py @@ -0,0 +1,64 @@ +"""This is an example script for training a model only for automated instance segmentation. +""" +import os + +# This function downloads the DSB dataset, which we use as example data for this script. +# You can use any other data with images and associated label masks for training, +# for example images and label masks stored in a .tif format. +from torch_em.data.datasets.light_microscopy.dsb import get_dsb_paths + +# The required functionality for training. +from micro_sam.training import export_instance_segmentation_model, default_sam_loader, train_instance_segmentation + +image_paths, label_paths = get_dsb_paths("./data", source="reduced", split="train", download=True) + +# Create a validation split. For a real run you could use e.g. 10% of the data (see the commented-out lines below); here we only take a few images to keep the example fast. + # val_len = int(0.1 * len(image_paths)) + # train_images, val_images = image_paths[:-val_len], image_paths[-val_len:] + # train_labels, val_labels = label_paths[:-val_len], label_paths[-val_len:] +train_images, val_images = image_paths[:10], image_paths[-5:] +train_labels, val_labels = label_paths[:10], label_paths[-5:] + +# Run the training. This will train a UNETR with instance segmentation decoder. +# This is equivalent to training the additional instance segmentation decoder +# in the full training logic, WITHOUT training the model for interactive segmentation. + +# Adjust the patch shape to match your images! +patch_shape = (256, 256) + +# First, we create the training and validation loaders. +train_loader = default_sam_loader( + raw_paths=train_images, label_paths=train_labels, + raw_key=None, label_key=None, batch_size=1, + patch_shape=patch_shape, with_segmentation_decoder=True, + train_instance_segmentation_only=True, is_train=True, +) +val_loader = default_sam_loader( + raw_paths=val_images, label_paths=val_labels, + raw_key=None, label_key=None, batch_size=1, + patch_shape=patch_shape, with_segmentation_decoder=True, + train_instance_segmentation_only=True, is_train=False, +) + +# Choose the model type to start the training from. +# We recommend 'vit_b_lm' for any light microscopy images; here we use the smaller 'vit_t_lm' to keep the example lightweight. +model_type = "vit_t_lm" + +# This is the name for the checkpoint that will be trained. +name = "ais-dsb" + +# Then run the training. Check out the docstring of the function for more training options. +train_instance_segmentation( + name=name, + model_type=model_type, + train_loader=train_loader, + val_loader=val_loader +) + +# Finally, we export the trained model to a new format that is compatible with micro_sam: +# This exported model can be used by micro_sam functions, the CLI or in the napari plugin. +# However, it may not work well for interactive segmentation, since it may suffer from 'catastrophic forgetting' +# for this task, because its image encoder was updated without training for interactive segmentation.
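+# The exported file can then be used like any other micro_sam model for automatic instance segmentation (AIS), +# i.e. with the micro_sam functions, CLI or napari plugin (see the docstring of `export_instance_segmentation_model`).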
+checkpoint_path = os.path.join("checkpoints", name, "best.pt") +export_path = "./ais-dsb-model.pt" +export_instance_segmentation_model(checkpoint_path, export_path, model_type) diff --git a/setup.cfg b/setup.cfg index 24bbfec82..f8eb4300a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,6 +53,7 @@ console_scripts = micro_sam.evaluate = micro_sam.evaluation.evaluation:main micro_sam.info = micro_sam.util:micro_sam_info micro_sam.benchmark_sam = micro_sam.evaluation.benchmark_datasets:main + micro_sam.continue_annotation = micro_sam.sam_annotator.reproducibility:main # make sure it gets included in your package diff --git a/test/test_sam_annotator/test_cli.py b/test/test_cli.py similarity index 100% rename from test/test_sam_annotator/test_cli.py rename to test/test_cli.py diff --git a/test/test_sam_annotator/commit-for-test.zarr/.zattrs b/test/test_sam_annotator/commit-for-test.zarr/.zattrs new file mode 100644 index 000000000..8428e7a4f --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/.zattrs @@ -0,0 +1 @@ +{"annotator_class":"Annotator2d","commit_history":[{"auto_segmentation":{"boundary_distance_threshold":0.5,"center_distance_threshold":0.5,"min_object_size":100,"object_ids":[1,2,3],"preservation_threshold":0.75,"preserve_mode":"objects","with_background":true}},{"current_object":{"object_ids":[4,5],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[6,7],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[8],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[9],"preservation_threshold":0.75,"preserve_mode":"objects"}},{"current_object":{"object_ids":[10],"preservation_threshold":0.75,"preserve_mode":"objects"}}],"data_signature":"11200d9b5539224a26ad0aae8dd974b0b97ca075","halo":null,"micro_sam_version":"1.4.0","model_hash":"xxh128:72ec5074774761a6e5c05a08942f981e","model_name":"vit_t_lm","model_type":"vit_t","tile_shape":null} diff --git a/test/test_sam_annotator/commit-for-test.zarr/.zgroup b/test/test_sam_annotator/commit-for-test.zarr/.zgroup new file mode 100644 index 000000000..3f3fad2d1 --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} diff --git a/test/test_sam_annotator/commit-for-test.zarr/committed_objects/.zarray b/test/test_sam_annotator/commit-for-test.zarr/committed_objects/.zarray new file mode 100644 index 000000000..15d56101d --- /dev/null +++ b/test/test_sam_annotator/commit-for-test.zarr/committed_objects/.zarray @@ -0,0 +1,20 @@ +{ + "chunks": [ + 512, + 512 + ], + "compressor": { + "id": "zlib", + "level": 5 + }, + "dimension_separator": ".", + "dtype": "