|
| 1 | +import napari |
| 2 | +import numpy as np |
| 3 | +from napari import Viewer |
| 4 | +from magicgui import magicgui |
| 5 | + |
| 6 | +from .annotator_2d import _get_shape |
| 7 | +from .util import _initialize_parser |
| 8 | +from ..import util |
| 9 | +from ..import segment_instances |
| 10 | +from ..visualization import project_embeddings_for_visualization |
| 11 | + |
| 12 | + |
| 13 | +@magicgui(call_button="Automatic Segmentation [S]") |
| 14 | +def autosegment_widget( |
| 15 | + v: Viewer, |
| 16 | + pred_iou_thresh: float = 0.88, |
| 17 | + stability_score_thresh: float = 0.95, |
| 18 | + min_initial_size: int = 10, |
| 19 | + box_extension: float = 0.1, |
| 20 | + with_background: bool = True, |
| 21 | + use_box: bool = True, |
| 22 | + use_mask: bool = True, |
| 23 | + use_points: bool = False, |
| 24 | +): |
| 25 | + is_tiled = IMAGE_EMBEDDINGS["input_size"] is None |
| 26 | + if is_tiled: |
| 27 | + seg, initial_seg = segment_instances.segment_instances_from_embeddings_with_tiling( |
| 28 | + PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background, |
| 29 | + box_extension=box_extension, pred_iou_thresh=pred_iou_thresh, |
| 30 | + stability_score_thresh=stability_score_thresh, |
| 31 | + min_initial_size=min_initial_size, return_initial_segmentation=True, verbose=2, |
| 32 | + use_box=use_box, use_mask=use_mask, use_points=use_points, |
| 33 | + ) |
| 34 | + else: |
| 35 | + seg, initial_seg = segment_instances.segment_instances_from_embeddings( |
| 36 | + PREDICTOR, IMAGE_EMBEDDINGS, with_background=with_background, |
| 37 | + box_extension=box_extension, pred_iou_thresh=pred_iou_thresh, |
| 38 | + stability_score_thresh=stability_score_thresh, |
| 39 | + min_initial_size=min_initial_size, return_initial_segmentation=True, verbose=2, |
| 40 | + use_box=use_box, use_mask=use_mask, use_points=use_points, |
| 41 | + ) |
| 42 | + |
| 43 | + v.layers["auto_segmentation"].data = seg |
| 44 | + v.layers["auto_segmentation"].refresh() |
| 45 | + |
| 46 | + v.layers["initial_segmentation"].data = initial_seg |
| 47 | + v.layers["initial_segmentation"].refresh() |
| 48 | + |
| 49 | + |
| 50 | +def interactive_instance_segmentation( |
| 51 | + raw, embedding_path=None, model_type="vit_h", tile_shape=None, halo=None, checkpoint=None, |
| 52 | +): |
| 53 | + """Visualizing and debugging automatic instance segmentation. |
| 54 | + """ |
| 55 | + global PREDICTOR, IMAGE_EMBEDDINGS |
| 56 | + |
| 57 | + PREDICTOR = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint) |
| 58 | + IMAGE_EMBEDDINGS = util.precompute_image_embeddings( |
| 59 | + PREDICTOR, raw, save_path=embedding_path, ndim=2, tile_shape=tile_shape, halo=halo, |
| 60 | + ) |
| 61 | + |
| 62 | + shape = _get_shape(raw) |
| 63 | + |
| 64 | + v = napari.Viewer() |
| 65 | + |
| 66 | + v.add_image(raw) |
| 67 | + embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS) |
| 68 | + v.add_image(embedding_vis, name="embeddings", scale=scale, visible=False) |
| 69 | + v.add_labels(data=np.zeros(shape, dtype="uint32"), name="initial_segmentation") |
| 70 | + v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation") |
| 71 | + |
| 72 | + v.window.add_dock_widget(autosegment_widget) |
| 73 | + |
| 74 | + napari.run() |
| 75 | + |
| 76 | + |
| 77 | +def main(): |
| 78 | + parser = _initialize_parser( |
| 79 | + description="Run the automatic instance segmentation in interactive mode" |
| 80 | + "to determine the optimal segmentation parameters.", |
| 81 | + with_segmentation_result=False, with_show_embeddings=False, |
| 82 | + ) |
| 83 | + parser.add_argument( |
| 84 | + "-c", "--checkpoint", default=None, |
| 85 | + help="Path to alternative checkpoint instead of the standard model." |
| 86 | + ) |
| 87 | + args = parser.parse_args() |
| 88 | + raw = util.load_image_data(args.input, ndim=2, key=args.key) |
| 89 | + interactive_instance_segmentation( |
| 90 | + raw, args.embedding_path, args.model_type, args.tile_shape, args.halo, args.checkpoint |
| 91 | + ) |
0 commit comments