|
| 1 | +import os |
| 2 | +import pickle |
| 3 | + |
| 4 | +from glob import glob |
| 5 | +from pathlib import Path |
| 6 | +from typing import Optional, Tuple, Union |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +from segment_anything.predictor import SamPredictor |
| 10 | +from tqdm import tqdm |
| 11 | + |
| 12 | +from . import instance_segmentation, util |
| 13 | + |
| 14 | + |
| 15 | +def cache_amg_state( |
| 16 | + predictor: SamPredictor, |
| 17 | + raw: np.ndarray, |
| 18 | + image_embeddings: util.ImageEmbeddings, |
| 19 | + save_path: Union[str, os.PathLike], |
| 20 | + verbose: bool = True, |
| 21 | + **kwargs, |
| 22 | +) -> instance_segmentation.AMGBase: |
| 23 | + """Compute and cache or load the state for the automatic mask generator. |
| 24 | +
|
| 25 | + Args: |
| 26 | + predictor: The segment anything predictor. |
| 27 | + raw: The image data. |
| 28 | + image_embeddings: The image embeddings. |
| 29 | + save_path: The embedding save path. The AMG state will be stored in <save_path>/amg_state.pickle. |
| 30 | + verbose: Whether to run the computation verbose. |
| 31 | + kwargs: The keyword arguments for the amg class. |
| 32 | +
|
| 33 | + Returns: |
| 34 | + The automatic mask generator class with the cached state. |
| 35 | + """ |
| 36 | + is_tiled = image_embeddings["input_size"] is None |
| 37 | + amg = instance_segmentation.get_amg(predictor, is_tiled, **kwargs) |
| 38 | + |
| 39 | + save_path_amg = os.path.join(save_path, "amg_state.pickle") |
| 40 | + if os.path.exists(save_path_amg): |
| 41 | + if verbose: |
| 42 | + print("Load the AMG state from", save_path_amg) |
| 43 | + with open(save_path_amg, "rb") as f: |
| 44 | + amg_state = pickle.load(f) |
| 45 | + amg.set_state(amg_state) |
| 46 | + return amg |
| 47 | + |
| 48 | + if verbose: |
| 49 | + print("Precomputing the state for instance segmentation.") |
| 50 | + amg.initialize(raw, image_embeddings=image_embeddings, verbose=verbose) |
| 51 | + with open(save_path_amg, "wb") as f: |
| 52 | + pickle.dump(amg.get_state(), f) |
| 53 | + |
| 54 | + return amg |
| 55 | + |
| 56 | + |
| 57 | +def _precompute_state_for_file( |
| 58 | + predictor, input_path, output_path, key, ndim, tile_shape, halo, precompute_amg_state, |
| 59 | +): |
| 60 | + image_data = util.load_image_data(input_path, key) |
| 61 | + output_path = Path(output_path).with_suffix(".zarr") |
| 62 | + embeddings = util.precompute_image_embeddings( |
| 63 | + predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, |
| 64 | + ) |
| 65 | + if precompute_amg_state: |
| 66 | + cache_amg_state(predictor, image_data, embeddings, output_path, verbose=True) |
| 67 | + |
| 68 | + |
| 69 | +def _precompute_state_for_files( |
| 70 | + predictor, input_files, output_path, ndim, tile_shape, halo, precompute_amg_state, |
| 71 | +): |
| 72 | + os.makedirs(output_path, exist_ok=True) |
| 73 | + for file_path in tqdm(input_files, desc="Precompute state for files."): |
| 74 | + out_path = os.path.join(output_path, os.path.basename(file_path)) |
| 75 | + _precompute_state_for_file( |
| 76 | + predictor, file_path, out_path, |
| 77 | + key=None, ndim=ndim, tile_shape=tile_shape, halo=halo, |
| 78 | + precompute_amg_state=precompute_amg_state, |
| 79 | + ) |
| 80 | + |
| 81 | + |
| 82 | +def precompute_state( |
| 83 | + input_path: Union[os.PathLike, str], |
| 84 | + output_path: Union[os.PathLike, str], |
| 85 | + model_type: str = util._DEFAULT_MODEL, |
| 86 | + checkpoint_path: Optional[Union[os.PathLike, str]] = None, |
| 87 | + key: Optional[str] = None, |
| 88 | + ndim: Union[int] = None, |
| 89 | + tile_shape: Optional[Tuple[int, int]] = None, |
| 90 | + halo: Optional[Tuple[int, int]] = None, |
| 91 | + precompute_amg_state: bool = False, |
| 92 | +) -> None: |
| 93 | + """Precompute the image embeddings and other optional state for the input image(s). |
| 94 | +
|
| 95 | + Args: |
| 96 | + input_path: The input image file(s). Can either be a single image file (e.g. tif or png), |
| 97 | + a container file (e.g. hdf5 or zarr) or a folder with images files. |
| 98 | + In case of a container file the argument `key` must be given. In case of a folder |
| 99 | + it can be given to provide a glob pattern to subselect files from the folder. |
| 100 | + output_path: The output path were the embeddings and other state will be saved. |
| 101 | + model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. |
| 102 | + checkpoint_path: Path to a checkpoint for a custom model. |
| 103 | + key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) |
| 104 | + and can be used to provide a glob pattern if the input is a folder with image files. |
| 105 | + ndim: The dimensionality of the data. |
| 106 | + tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. |
| 107 | + halo: Overlap of the tiles for tiled prediction. |
| 108 | + precompute_amg_state: Whether to precompute the state for automatic instance segmentation |
| 109 | + in addition to the image embeddings. |
| 110 | + """ |
| 111 | + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path) |
| 112 | + # check if we precompute the state for a single file or for a folder with image files |
| 113 | + if os.path.isdir(input_path) and Path(input_path).suffix not in (".n5", ".zarr"): |
| 114 | + pattern = "*" if key is None else key |
| 115 | + input_files = glob(os.path.join(input_path, pattern)) |
| 116 | + _precompute_state_for_files( |
| 117 | + predictor, input_files, output_path, |
| 118 | + ndim=ndim, tile_shape=tile_shape, halo=halo, |
| 119 | + precompute_amg_state=precompute_amg_state, |
| 120 | + ) |
| 121 | + else: |
| 122 | + _precompute_state_for_file( |
| 123 | + predictor, input_path, output_path, key, |
| 124 | + ndim=ndim, tile_shape=tile_shape, halo=halo, |
| 125 | + precompute_amg_state=precompute_amg_state, |
| 126 | + ) |
| 127 | + |
| 128 | + |
| 129 | +def main(): |
| 130 | + """@private""" |
| 131 | + import argparse |
| 132 | + |
| 133 | + parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") |
| 134 | + parser.add_argument("-i", "--input_path", required=True) |
| 135 | + parser.add_argument("-o", "--output_path", required=True) |
| 136 | + parser.add_argument("-m", "--model_type", default="vit_h") |
| 137 | + parser.add_argument("-c", "--checkpoint_path", default=None) |
| 138 | + parser.add_argument("-k", "--key") |
| 139 | + parser.add_argument( |
| 140 | + "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None |
| 141 | + ) |
| 142 | + parser.add_argument( |
| 143 | + "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None |
| 144 | + ) |
| 145 | + parser.add_argument("-n", "--ndim") |
| 146 | + parser.add_argument("-p", "--precompute_amg_state") |
| 147 | + |
| 148 | + args = parser.parse_args() |
| 149 | + precompute_state( |
| 150 | + args.input_path, args.output_path, args.model_type, args.checkpoint_path, |
| 151 | + key=args.key, tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim, |
| 152 | + precompute_amg_state=args.precompute_amg_state, |
| 153 | + ) |
| 154 | + |
| 155 | + |
| 156 | +if __name__ == "__main__": |
| 157 | + main() |
0 commit comments