
Commit f29a18d

Merge pull request #90 from computational-cell-analytics/tiled-amg
Implement TiledAutomaticMaskGenerator
2 parents 65bff2f + b840cb4 commit f29a18d
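
This PR adds a tiled variant of the automatic mask generator, so that Segment Anything's grid-based instance segmentation can run on large images via pre-computed tiled embeddings. A minimal usage sketch, distilled from the example script added in this PR (the tile_shape and halo values are the ones used there; treat this as illustrative rather than canonical API documentation):

    import imageio.v3 as imageio
    from micro_sam import instance_segmentation, util

    # Assumed inputs: a large 2d image and a zarr path for caching the tiled embeddings.
    image = imageio.imread("../data/whole-slide-example-image.tif")
    predictor = util.get_sam_model()
    embeddings = util.precompute_image_embeddings(
        predictor, image, save_path="../embeddings/whole-slide-embeddings.zarr",
        tile_shape=(1024, 1024), halo=(256, 256),
    )

    # Tiled automatic mask generation: initialize once, then re-generate cheaply
    # for different post-processing thresholds such as 'pred_iou_thresh'.
    amg = instance_segmentation.TiledAutomaticMaskGenerator(predictor)
    amg.initialize(image, embeddings, verbose=True)
    masks = amg.generate(pred_iou_thresh=0.88)
    segmentation = instance_segmentation.mask_data_to_segmentation(
        masks, shape=image.shape, with_background=True
    )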

6 files changed: +331 −55 lines

environment_cpu.yaml

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ dependencies:
 - cpuonly
 - napari
 - pooch
-- python-elf
+- python-elf >=0.4.8
 - pytorch
 - torchvision
 - tqdm

environment_gpu.yaml

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@ name:
 dependencies:
 - napari
 - pooch
-- python-elf
+- python-elf >=0.4.8
 - pytorch
 - pytorch-cuda>=11.7 # you may need to update the cuda version to match your system
 - torchvision

examples/README.md

Lines changed: 10 additions & 0 deletions

@@ -0,0 +1,10 @@
# micro_sam examples

Examples for using the micro_sam annotation tools:
- `sam_annotator_2d.py`: run the interactive 2d annotation tool
- `sam_annotator_3d.py`: run the interactive 3d annotation tool
- `sam_annotator_tracking.py`: run the interactive tracking annotation tool
- `sam_image_series_annotator.py`: run the annotation tool for a series of images

The folder `use_as_library` contains example scripts that show how `micro_sam` can be used as a python
library to apply Segment Anything on multi-dimensional data.
Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
import imageio.v3 as imageio
import napari

from micro_sam import instance_segmentation, util


def cell_segmentation():
    """Run the instance segmentation functionality from micro_sam for segmentation of
    HeLa cells. You need to run examples/sam_annotator_2d.py:hela_2d_annotator once before
    running this script so that all required data is downloaded and pre-computed.
    """
    image_path = "../data/hela-2d-image.png"
    embedding_path = "../embeddings/embeddings-hela2d.zarr"

    # Load the image, the SAM model, and the pre-computed embeddings.
    image = imageio.imread(image_path)
    predictor = util.get_sam_model()
    embeddings = util.precompute_image_embeddings(predictor, image, save_path=embedding_path)

    # Use the instance segmentation logic of Segment Anything.
    # This works by covering the image with a grid of points, getting the masks for all the points
    # and only keeping the plausible ones (according to the model predictions).
    # While the functionality here does the same as the implementation from Segment Anything,
    # we enable changing the hyperparameters, e.g. 'pred_iou_thresh', without recomputing masks and embeddings,
    # to support (interactive) evaluation of different hyperparameters.

    # Create the automatic mask generator class.
    amg = instance_segmentation.AutomaticMaskGenerator(predictor)

    # Initialize the mask generator with the image and the pre-computed embeddings.
    amg.initialize(image, embeddings, verbose=True)

    # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
    # without having to call initialize again.
    instances_amg = amg.generate(pred_iou_thresh=0.88)
    instances_amg = instance_segmentation.mask_data_to_segmentation(
        instances_amg, shape=image.shape, with_background=True
    )

    # Use the mutex watershed based instance segmentation logic.
    # Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
    # These initial masks are used as prompts for the actual instance segmentation.
    # This class uses the same overall design as 'AutomaticMaskGenerator'.

    # Create the automatic mask generator class.
    amg_mws = instance_segmentation.EmbeddingMaskGenerator(predictor, min_initial_size=10)

    # Initialize the mask generator with the image and the pre-computed embeddings.
    amg_mws.initialize(image, embeddings, verbose=True)

    # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
    # without having to call initialize again.
    # NOTE: the main advantage of this method is that it's considerably faster than the original implementation.
    instances_mws = amg_mws.generate(pred_iou_thresh=0.88)
    instances_mws = instance_segmentation.mask_data_to_segmentation(
        instances_mws, shape=image.shape, with_background=True
    )

    # Show the results.
    v = napari.Viewer()
    v.add_image(image)
    v.add_labels(instances_amg)
    v.add_labels(instances_mws)
    napari.run()


def segmentation_with_tiling():
    """Run the instance segmentation functionality from micro_sam for segmentation of
    cells in a large image. You need to run examples/sam_annotator_2d.py:wholeslide_annotator once before
    running this script so that all required data is downloaded and pre-computed.
    """
    image_path = "../data/whole-slide-example-image.tif"
    embedding_path = "../embeddings/whole-slide-embeddings.zarr"

    # Load the image, the SAM model, and the pre-computed embeddings.
    image = imageio.imread(image_path)
    predictor = util.get_sam_model()
    embeddings = util.precompute_image_embeddings(
        predictor, image, save_path=embedding_path, tile_shape=(1024, 1024), halo=(256, 256)
    )

    # Use the instance segmentation logic of Segment Anything.
    # This works by covering the image with a grid of points, getting the masks for all the points
    # and only keeping the plausible ones (according to the model predictions).
    # The functionality here is similar to the instance segmentation in Segment Anything,
    # but uses the pre-computed tiled embeddings.

    # Create the automatic mask generator class.
    amg = instance_segmentation.TiledAutomaticMaskGenerator(predictor)

    # Initialize the mask generator with the image and the pre-computed embeddings.
    amg.initialize(image, embeddings, verbose=True)

    # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
    # without having to call initialize again.
    instances_amg = amg.generate(pred_iou_thresh=0.88)
    instances_amg = instance_segmentation.mask_data_to_segmentation(
        instances_amg, shape=image.shape, with_background=True
    )

    # Use the mutex watershed based instance segmentation logic.
    # Here, we generate initial segmentation masks from the image embeddings, using the mutex watershed algorithm.
    # These initial masks are used as prompts for the actual instance segmentation.
    # This class uses the same overall design as 'AutomaticMaskGenerator'.

    # Create the automatic mask generator class.
    amg_mws = instance_segmentation.TiledEmbeddingMaskGenerator(predictor, min_initial_size=10)

    # Initialize the mask generator with the image and the pre-computed embeddings.
    amg_mws.initialize(image, embeddings, verbose=True)

    # Generate the instance segmentation. You can call this again for different values of 'pred_iou_thresh'
    # without having to call initialize again.
    # NOTE: the main advantage of this method is that it's considerably faster than the original implementation.
    instances_mws = amg_mws.generate(pred_iou_thresh=0.88)

    # Show the results.
    v = napari.Viewer()
    v.add_image(image)
    # v.add_labels(instances_amg)
    v.add_labels(instances_mws)
    napari.run()


def main():
    cell_segmentation()
    # segmentation_with_tiling()


if __name__ == "__main__":
    main()

micro_sam/instance_segmentation.py

Lines changed: 140 additions & 34 deletions

@@ -82,8 +82,8 @@ def mask_data_to_segmentation(
 #


-class _AMGBase(ABC):
-    """
+class AMGBase(ABC):
+    """Base class for the automatic mask generators.
     """
     def __init__(self):
         # the state that has to be computed by the 'initialize' method of the child classes

@@ -277,7 +277,7 @@ def set_state(self, state: Dict[str, Any]) -> None:
         self._is_initialized = True


-class AutomaticMaskGenerator(_AMGBase):
+class AutomaticMaskGenerator(AMGBase):
     """Generates an instance segmentation without prompts, using a point grid.

     This class implements the same logic as

@@ -358,8 +358,11 @@ def _process_batch(self, points, im_size):

     def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_embeddings):
         # crop the image and calculate embeddings
-        x0, y0, x1, y1 = crop_box
-        cropped_im = image[y0:y1, x0:x1, :]
+        if crop_box is None:
+            cropped_im = image
+        else:
+            x0, y0, x1, y1 = crop_box
+            cropped_im = image[y0:y1, x0:x1, :]
         cropped_im_size = cropped_im.shape[:2]

         if not precomputed_embeddings:

@@ -477,7 +480,7 @@ def generate(
            )
            data.cat(crop_data)

-        if len(self.crop_boxes) > 1:
+        if len(self.crop_boxes) > 1 and len(data["crop_boxes"]) > 0:
            # Prefer masks from smaller crops
            scores = 1 / box_area(data["crop_boxes"])
            scores = scores.to(data["boxes"].device)

@@ -494,7 +497,7 @@ def generate(
         return masks


-class EmbeddingMaskGenerator(_AMGBase):
+class EmbeddingMaskGenerator(AMGBase):
     """Generates an instance segmentation without prompts, using an initial segmentation derived from image embeddings.

     Uses an initial segmentation derived from the image embeddings via the Mutex Watershed,

@@ -718,6 +721,133 @@ def set_state(self, state: Dict[str, Any]) -> None:
         super().set_state(state)


+def _compute_tiled_embeddings(predictor, image, image_embeddings, embedding_save_path, tile_shape, halo):
+    have_tiling_params = (tile_shape is not None) and (halo is not None)
+    if image_embeddings is None and have_tiling_params:
+        if embedding_save_path is None:
+            raise ValueError(
+                "You have passed neither pre-computed embeddings nor a path for saving embeddings. "
+                "Embeddings with tiling can only be computed if a save path is given."
+            )
+        image_embeddings = util.precompute_image_embeddings(
+            predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
+        )
+    elif image_embeddings is None and not have_tiling_params:
+        raise ValueError("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)")
+    else:
+        feats = image_embeddings["features"]
+        tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
+        if have_tiling_params and (
+            (list(tile_shape) != list(tile_shape_)) or
+            (list(halo) != list(halo_))
+        ):
+            warnings.warn(
+                "You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and "
+                "the values of the tiling parameters from the embeddings disagree with the ones that were passed. "
+                "The tiling parameters you have passed will be ignored."
+            )
+        tile_shape = tile_shape_
+        halo = halo_
+
+    return image_embeddings, tile_shape, halo
+
+
+class TiledAutomaticMaskGenerator(AutomaticMaskGenerator):
+    """Generates an instance segmentation without prompts, using a point grid.
+
+    Implements the same functionality as `AutomaticMaskGenerator` but for tiled embeddings.
+
+    Args:
+        predictor: The segment anything predictor.
+        points_per_side: The number of points to be sampled along one side of the image.
+            If None, `point_grids` must provide explicit point sampling.
+        points_per_batch: The number of points run simultaneously by the model.
+            Higher numbers may be faster but use more GPU memory.
+        point_grids: A list of explicit grids of points used for sampling masks.
+            Normalized to [0, 1] with respect to the image coordinate system.
+    """
+
+    # We only expose the arguments that make sense for the tiled mask generator.
+    # Anything related to crops doesn't make sense, because we re-use that functionality
+    # for tiling, so these parameters wouldn't have any effect.
+    def __init__(
+        self,
+        predictor: SamPredictor,
+        points_per_side: Optional[int] = 32,
+        points_per_batch: int = 64,
+        point_grids: Optional[List[np.ndarray]] = None,
+    ) -> None:
+        super().__init__(
+            predictor=predictor,
+            points_per_side=points_per_side,
+            points_per_batch=points_per_batch,
+            point_grids=point_grids,
+        )
+
+    @torch.no_grad()
+    def initialize(
+        self,
+        image: np.ndarray,
+        image_embeddings: Optional[util.ImageEmbeddings] = None,
+        i: Optional[int] = None,
+        tile_shape: Optional[Tuple[int, int]] = None,
+        halo: Optional[Tuple[int, int]] = None,
+        verbose: bool = False,
+        embedding_save_path: Optional[str] = None,
+    ) -> None:
+        """Initialize image embeddings and masks for an image.
+
+        Args:
+            image: The input image, volume or timeseries.
+            image_embeddings: Optional precomputed image embeddings.
+                See `util.precompute_image_embeddings` for details.
+            i: Index for the image data. Required if `image` has three spatial dimensions
+                or a time dimension and two spatial dimensions.
+            tile_shape: The tile shape for embedding prediction.
+            halo: The overlap between tiles.
+            verbose: Whether to print computation progress.
+            embedding_save_path: Where to save the image embeddings.
+        """
+        original_size = image.shape[:2]
+        image_embeddings, tile_shape, halo = _compute_tiled_embeddings(
+            self._predictor, image, image_embeddings, embedding_save_path, tile_shape, halo
+        )
+
+        tiling = blocking([0, 0], original_size, tile_shape)
+        n_tiles = tiling.numberOfBlocks
+
+        mask_data = []
+        for tile_id in tqdm(range(n_tiles), total=n_tiles, desc="Compute masks for tile", disable=not verbose):
+            # get the bounding box for this tile and crop the image data
+            tile = tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock
+            tile_bb = tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end))
+            tile_data = image[tile_bb]
+
+            # set the pre-computed embeddings for this tile
+            features = image_embeddings["features"][tile_id]
+            tile_embeddings = {
+                "features": features,
+                "input_size": features.attrs["input_size"],
+                "original_size": features.attrs["original_size"],
+            }
+            util.set_precomputed(self._predictor, tile_embeddings, i)
+
+            # compute the mask data for this tile and append it
+            this_mask_data = self._process_crop(
+                tile_data, crop_box=None, crop_layer_idx=0, verbose=verbose, precomputed_embeddings=True
+            )
+            mask_data.append(this_mask_data)
+
+        # set the initialized data
+        self._is_initialized = True
+        self._crop_list = mask_data
+        self._original_size = original_size
+
+        # the crop box is always the full local tile
+        tiles = [tiling.getBlockWithHalo(tile_id, list(halo)).outerBlock for tile_id in range(n_tiles)]
+        self._crop_boxes = [[tile.begin[1], tile.begin[0], tile.end[1], tile.end[0]] for tile in tiles]
+
+
 class TiledEmbeddingMaskGenerator(EmbeddingMaskGenerator):
     """Generates an instance segmentation without prompts, using an initial segmentation derived from image embeddings.


@@ -812,33 +942,9 @@ def initialize(
            embedding_save_path: Where to save the image embeddings.
         """
         original_size = image.shape[:2]
-
-        have_tiling_params = (tile_shape is not None) and (halo is not None)
-        if image_embeddings is None and have_tiling_params:
-            if embedding_save_path is None:
-                raise ValueError(
-                    "You have passed neither pre-computed embeddings nor a path for saving embeddings. "
-                    "Embeddings with tiling can only be computed if a save path is given."
-                )
-            image_embeddings = util.precompute_image_embeddings(
-                self._predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
-            )
-        elif image_embeddings is None and not have_tiling_params:
-            raise ValueError("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)")
-        else:
-            feats = image_embeddings["features"]
-            tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
-            if have_tiling_params and (
-                (list(tile_shape) != list(tile_shape_)) or
-                (list(halo) != list(halo_))
-            ):
-                warnings.warn(
-                    "You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and "
-                    "the values of the tiling parameters from the embeddings disagree with the ones that were passed. "
-                    "The tiling parameters you have passed will be ignored."
-                )
-            tile_shape = tile_shape_
-            halo = halo_
+        image_embeddings, tile_shape, halo = _compute_tiled_embeddings(
+            self._predictor, image, image_embeddings, embedding_save_path, tile_shape, halo
+        )

         tiling = blocking([0, 0], original_size, tile_shape)
         n_tiles = tiling.numberOfBlocks
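
As the new `_compute_tiled_embeddings` helper above shows, the tiled generators can also compute the tiled embeddings themselves when they are given the tiling parameters and a save path instead of pre-computed embeddings. A small sketch of that second entry point, based on the `initialize` signature added in this diff (the image and save paths are placeholders):

    import imageio.v3 as imageio
    from micro_sam import instance_segmentation, util

    image = imageio.imread("../data/whole-slide-example-image.tif")  # placeholder path
    predictor = util.get_sam_model()

    amg = instance_segmentation.TiledAutomaticMaskGenerator(predictor)
    # No image_embeddings passed: initialize computes the tiled embeddings and caches them
    # at embedding_save_path, which is required because tiled embeddings are written to disk.
    amg.initialize(
        image,
        tile_shape=(1024, 1024), halo=(256, 256),
        embedding_save_path="../embeddings/whole-slide-embeddings.zarr",  # placeholder path
        verbose=True,
    )
    masks = amg.generate(pred_iou_thresh=0.88)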
