@@ -11,7 +11,8 @@
 from ..visualization import project_embeddings_for_visualization
 from .util import (
     clear_all_prompts, commit_segmentation_widget, create_prompt_menu,
-    prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation, toggle_label, LABEL_COLOR_CYCLE
+    prompt_layer_to_boxes, prompt_layer_to_points, prompt_segmentation, toggle_label, LABEL_COLOR_CYCLE,
+    _initialize_parser,
 )
 
 
@@ -60,35 +61,25 @@ def autosegment_widget(
     v.layers["auto_segmentation"].refresh()
 
 
-def annotator_2d(
-    raw, embedding_path=None, show_embeddings=False, segmentation_result=None,
-    model_type="vit_h", tile_shape=None, halo=None, return_viewer=False,
-):
-    # for access to the predictor and the image embeddings in the widgets
-    global PREDICTOR, IMAGE_EMBEDDINGS, SAM
+def _get_shape(raw):
+    if raw.ndim == 2:
+        shape = raw.shape
+    elif raw.ndim == 3 and raw.shape[-1] == 3:
+        shape = raw.shape[:2]
+    else:
+        raise ValueError(f"Invalid input image of shape {raw.shape}. Expect either 2D grayscale or 3D RGB image.")
+    return shape
 
-    PREDICTOR, SAM = util.get_sam_model(model_type=model_type, return_sam=True)
-    IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
-        PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo
-    )
-    # we set the pre-computed image embeddings if we don't use tiling
-    # (if we use tiling we cannot directly set it because the tile will be chosen dynamically)
-    if tile_shape is None:
-        util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
+
+def _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings):
+    v = Viewer()
 
     #
     # initialize the viewer and add layers
     #
 
-    v = Viewer()
-
     v.add_image(raw)
-    if raw.ndim == 2:
-        shape = raw.shape
-    elif raw.ndim == 3 and raw.shape[-1] == 3:
-        shape = raw.shape[:2]
-    else:
-        raise ValueError(f"Invalid input image of shape {raw.shape}. Expect either 2D grayscale or 3D RGB image.")
+    shape = _get_shape(raw)
 
     v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation")
     if segmentation_result is None:
@@ -157,66 +148,66 @@ def _toggle_label(event=None):
     def clear_prompts(v):
         clear_all_prompts(v)
 
+    return v
+
+
+def _update_viewer(v, raw, show_embeddings, segmentation_result):
+    if show_embeddings or segmentation_result is not None:
+        raise NotImplementedError
+
+    # update the image layer
+    v.layers["raw"].data = raw
+    shape = _get_shape(raw)
+
+    # update the segmentation layers
+    v.layers["auto_segmentation"].data = np.zeros(shape, dtype="uint32")
+    v.layers["committed_objects"].data = np.zeros(shape, dtype="uint32")
+    v.layers["current_object"].data = np.zeros(shape, dtype="uint32")
+
+
+def annotator_2d(
+    raw, embedding_path=None, show_embeddings=False, segmentation_result=None,
+    model_type="vit_h", tile_shape=None, halo=None, return_viewer=False, v=None,
+    predictor=None,
+):
+    # for access to the predictor and the image embeddings in the widgets
+    global PREDICTOR, IMAGE_EMBEDDINGS
+
+    if predictor is None:
+        PREDICTOR = util.get_sam_model(model_type=model_type)
+    else:
+        PREDICTOR = predictor
+    IMAGE_EMBEDDINGS = util.precompute_image_embeddings(
+        PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo
+    )
+
+    # we set the pre-computed image embeddings if we don't use tiling
+    # (if we use tiling we cannot directly set it because the tile will be chosen dynamically)
+    if tile_shape is None:
+        util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
+
+    # viewer is freshly initialized
+    if v is None:
+        v = _initialize_viewer(raw, segmentation_result, tile_shape, show_embeddings)
+    # we use an existing viewer and just update all the layers
+    else:
+        _update_viewer(v, raw, show_embeddings, segmentation_result)
+
     #
     # start the viewer
     #
-
-    # clear the initial points needed for workaround
-    clear_prompts(v)
+    clear_all_prompts(v)
 
     if return_viewer:
         return v
+
    napari.run()
 
 
 def main():
-    import argparse
     import warnings
 
-    parser = argparse.ArgumentParser(
-        description="Run interactive segmentation for an image."
-    )
-    parser.add_argument(
-        "-i", "--input", required=True,
-        help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
-        "or elf.io.open_file (e.g. hdf5, zarr, mrc) For the latter you also need to pass the 'key' parameter."
-    )
-    parser.add_argument(
-        "-k", "--key",
-        help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, "
-        "for a image series it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
-    )
-    parser.add_argument(
-        "-e", "--embedding_path",
-        help="The filepath for saving/loading the pre-computed image embeddings. "
-        "NOTE: It is recommended to pass this argument and store the embeddings, "
-        "otherwise they will be recomputed every time (which can take a long time)."
-    )
-    parser.add_argument(
-        "-s", "--segmentation_result",
-        help="Optional filepath to a precomputed segmentation. If passed this will be used to initialize the "
-        "'committed_objects' layer. This can be useful if you want to correct an existing segmentation or if you "
-        "have saved intermediate results from the annotator and want to continue with your annotations. "
-        "Supports the same file formats as 'input'."
-    )
-    parser.add_argument(
-        "-sk", "--segmentation_key",
-        help="The key for opening the segmentation data. Same rules as for 'key' apply."
-    )
-    parser.add_argument(
-        "--show_embeddings", action="store_true",
-        help="Visualize the embeddings computed by SegmentAnything. This can be helpful for debugging."
-    )
-    parser.add_argument(
-        "--model_type", default="vit_h", help="The segment anything model that will be used, one of vit_h,l,b."
-    )
-    parser.add_argument(
-        "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction", default=None
-    )
-    parser.add_argument(
-        "--halo", nargs="+", type=int, help="The halo for using tiled prediction", default=None
-    )
-
+    parser = _initialize_parser(description="Run interactive segmentation for an image.")
     args = parser.parse_args()
     raw = util.load_image_data(args.input, ndim=2, key=args.key)
 
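
Note on the refactored entry point: annotator_2d now accepts two extra keyword arguments, v (an existing napari viewer) and predictor (an already loaded SAM predictor), so the annotator can be re-opened for a new image without rebuilding the viewer or reloading the model. Below is a minimal usage sketch; the micro_sam import paths, the imageio-based loading, and the file names are assumptions, only the annotator_2d signature and the util.get_sam_model call come from the diff above.

# Minimal usage sketch. Assumptions: the micro_sam.* import paths and the
# imageio-based loading; the file paths are hypothetical.
import imageio.v3 as imageio
import napari

from micro_sam import util                         # assumed module layout
from micro_sam.sam_annotator import annotator_2d   # assumed import path

# load the SAM predictor once, so it is reused instead of re-initialized per image
predictor = util.get_sam_model(model_type="vit_b")

# first image: v=None, so a fresh napari viewer is created internally
raw = imageio.imread("image1.tif")
viewer = annotator_2d(raw, predictor=predictor, return_viewer=True)

# second image: pass the existing viewer, only its layers are updated
raw = imageio.imread("image2.tif")
viewer = annotator_2d(raw, v=viewer, predictor=predictor, return_viewer=True)

napari.run()

In an interactive workflow the second call would typically be triggered after annotating the first image (for example from a key binding) rather than immediately, as in this linear sketch.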