|
| 1 | +import numpy as np |
| 2 | +import vigra |
| 3 | + |
| 4 | +from elf.segmentation import embeddings as embed |
| 5 | +from skimage.transform import resize |
| 6 | +try: |
| 7 | + from napari.utils import progress as tqdm |
| 8 | +except ImportError: |
| 9 | + from tqdm import tqdm |
| 10 | + |
| 11 | +from . import util |
| 12 | +from .segment_from_prompts import segment_from_mask |
| 13 | + |
| 14 | + |
| 15 | +# |
| 16 | +# Original SegmentAnything instance segmentation functionality |
| 17 | +# |
| 18 | + |
| 19 | + |
| 20 | +# TODO implement automatic instance segmentation based on the functionalities from segment anything: |
| 21 | +# https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py |
| 22 | + |
| 23 | + |
| 24 | +# |
| 25 | +# Instance segmentation from embeddings |
| 26 | +# |
| 27 | + |
| 28 | + |
def _refine_initial_segmentation(predictor, initial_seg, image_embeddings, i, verbose):
    """Refine each object of an initial low-resolution segmentation with SAM.

    Each object in ``initial_seg`` is used as a mask prompt for
    ``segment_from_mask`` and the refined mask is written into an output
    segmentation at the original image size.

    Arguments:
        predictor: The SAM predictor (with image embeddings precomputed).
        initial_seg: The initial segmentation; must be at the 256 x 256
            mask-prompt resolution expected by SAM.
        image_embeddings: The precomputed image embeddings; must contain
            the key "original_size".
        i: Optional index for a slice/frame within the embeddings.
        verbose: Whether to show a progress bar.

    Returns:
        The refined segmentation as a uint32 label image of the original size.
    """
    util.set_precomputed(predictor, image_embeddings, i)

    original_size = image_embeddings["original_size"]
    seg = np.zeros(original_size, dtype="uint32")

    seg_ids = np.unique(initial_seg)
    # TODO be smarter for overlapping masks, (use automatic_mask_generation from SAM as template)
    for seg_id in tqdm(seg_ids[1:], disable=not verbose, desc="Refine masks for automatic instance segmentation"):
        mask = (initial_seg == seg_id)
        # SAM expects mask prompts at a fixed 256 x 256 resolution.
        assert mask.shape == (256, 256)
        refined = segment_from_mask(predictor, mask, original_size=original_size).squeeze()
        assert refined.shape == seg.shape
        # NOTE: overlapping refined masks overwrite earlier ids (see TODO above).
        # ('refined' is already squeezed above; no second squeeze needed.)
        seg[refined] = seg_id

    return seg
| 51 | + |
| 52 | + |
| 53 | +# This is a first prototype for generating automatic instance segmentations from the image embeddings |
| 54 | +# predicted by the segment anything image encoder. |
| 55 | + |
| 56 | +# Main challenge: the larger the image the worse this will get because of the fixed embedding size. |
| 57 | +# Ideas: |
| 58 | +# - Can we get intermediate, larger embeddings from SAM? |
| 59 | +# - Can we run the encoder in a sliding window and somehow stitch the embeddings? |
| 60 | +# - Or: run the encoder in a sliding window and stitch the initial segmentation result. |
def segment_from_embeddings(
    predictor, image_embeddings, size_threshold=10, i=None,
    offsets=None, distance_type="l2", bias=0.0,
    verbose=True, return_initial_seg=False,
):
    """Compute an instance segmentation from SAM image embeddings.

    Runs a mutex-watershed segmentation (``elf``) on the 64 x 64 SAM image
    embeddings, filters out small objects, upsamples the result to the
    256 x 256 mask-prompt resolution and refines each object with SAM.

    Arguments:
        predictor: The SAM predictor.
        image_embeddings: The precomputed image embeddings.
        size_threshold: Minimal object size (in 64 x 64 pixels) to keep.
        i: Optional index for a slice/frame within the embeddings.
        offsets: The affinity offsets for the mutex watershed; defaults to
            ``[[-1, 0], [0, -1], [-3, 0], [0, -3]]``.
        distance_type: The distance metric used on the embeddings.
        bias: The bias term for the mutex watershed.
        verbose: Whether to show a progress bar during refinement.
        return_initial_seg: Whether to also return the unrefined
            segmentation (resized to the output shape).

    Returns:
        The refined segmentation, and additionally the initial segmentation
        if ``return_initial_seg`` is set.
    """
    # Use a None sentinel instead of a mutable default argument.
    if offsets is None:
        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]

    util.set_precomputed(predictor, image_embeddings, i)

    embeddings = predictor.get_image_embedding().squeeze().cpu().numpy()
    assert embeddings.shape == (256, 64, 64), f"{embeddings.shape}"
    initial_seg = embed.segment_embeddings_mws(
        embeddings, distance_type=distance_type, offsets=offsets, bias=bias
    ).astype("uint32")
    assert initial_seg.shape == (64, 64), f"{initial_seg.shape}"

    # filter out small objects
    seg_ids, sizes = np.unique(initial_seg, return_counts=True)
    initial_seg[np.isin(initial_seg, seg_ids[sizes < size_threshold])] = 0
    vigra.analysis.relabelConsecutive(initial_seg, out=initial_seg)

    # resize to 256 x 256, which is the mask input expected by SAM
    initial_seg = resize(
        initial_seg, (256, 256), order=0, preserve_range=True, anti_aliasing=False
    ).astype(initial_seg.dtype)
    seg = _refine_initial_segmentation(predictor, initial_seg, image_embeddings, i, verbose)

    if return_initial_seg:
        initial_seg = resize(
            initial_seg, seg.shape, order=0, preserve_range=True, anti_aliasing=False
        ).astype(seg.dtype)
        return seg, initial_seg
    else:
        return seg
| 93 | + |
| 94 | + |
| 95 | +# TODO |
# TODO
def segment_from_embeddings_with_tiling(
    predictor, image, image_embeddings, tile_shape=(256, 256), tile_overlap=(32, 32),
    size_threshold=10, i=None,
    offsets=None, distance_type="l2", bias=0.0,
    verbose=True, return_initial_seg=False,
):
    """Compute an instance segmentation from tiled SAM image embeddings.

    Planned tiled variant of ``segment_from_embeddings`` (run the encoder
    in a sliding window and stitch the results); not yet implemented.

    Raises:
        NotImplementedError: Always, until the tiled implementation exists.
    """
    # Use a None sentinel instead of a mutable default argument
    # (mirrors segment_from_embeddings).
    if offsets is None:
        offsets = [[-1, 0], [0, -1], [-3, 0], [0, -3]]
    # Fail loudly instead of silently returning None from the stub.
    raise NotImplementedError("Tiled embedding-based segmentation is not implemented yet.")