Skip to content

Commit 5922499

Browse files
Add SAM instance segmentation
1 parent 90e6a52 commit 5922499

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

development/instance_segmentation.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import napari
33

44
from elf.io import open_file
5-
from micro_sam.segment_instances import segment_from_embeddings
5+
from micro_sam.segment_instances import segment_instances_from_embeddings, segment_instances_sam
66
from micro_sam.visualization import compute_pca
77

88

@@ -11,15 +11,23 @@ def mito_segmentation():
1111
with open_file(input_path) as f:
1212
raw = f["*.png"][-1, :768, :768]
1313

14-
predictor = util.get_sam_model()
14+
predictor, sam = util.get_sam_model(return_sam=True)
15+
16+
print("Run SAM prediction ...")
17+
seg_sam = segment_instances_sam(sam, raw)
18+
1519
image_embeddings = util.precompute_image_embeddings(predictor, raw, "../examples/embeddings/embeddings-mito2d.zarr")
1620
embedding_pca = compute_pca(image_embeddings["features"])
1721

18-
seg, initial_seg = segment_from_embeddings(predictor, image_embeddings=image_embeddings, return_initial_seg=True)
22+
print("Run prediction from embeddings ...")
23+
seg, initial_seg = segment_instances_from_embeddings(
24+
predictor, image_embeddings=image_embeddings, return_initial_seg=True
25+
)
1926

2027
v = napari.Viewer()
2128
v.add_image(raw)
2229
v.add_image(embedding_pca, scale=(12, 12))
30+
v.add_labels(seg_sam)
2331
v.add_labels(seg)
2432
v.add_labels(initial_seg)
2533
napari.run()
@@ -32,21 +40,27 @@ def cell_segmentation():
3240

3341
frame = 11
3442

35-
predictor = util.get_sam_model()
43+
predictor, sam = util.get_sam_model(return_sam=True)
44+
45+
print("Run prediction from embeddings ...")
3646
image_embeddings = util.precompute_image_embeddings(
3747
predictor, timeseries, "../examples/embeddings/embeddings-ctc.zarr"
3848
)
3949
embedding_pca = compute_pca(image_embeddings["features"][frame])
4050

41-
seg, initial_seg = segment_from_embeddings(
51+
seg, initial_seg = segment_instances_from_embeddings(
4252
predictor, image_embeddings=image_embeddings, i=frame, return_initial_seg=True
4353
)
4454

55+
print("Run SAM prediction ...")
56+
seg_sam = segment_instances_sam(sam, timeseries[frame])
57+
4558
v = napari.Viewer()
4659
v.add_image(timeseries[frame])
4760
v.add_image(embedding_pca, scale=(8, 8))
4861
v.add_labels(seg)
4962
v.add_labels(initial_seg)
63+
v.add_labels(seg_sam)
5064
napari.run()
5165

5266

micro_sam/segment_instances.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import vigra
33

44
from elf.segmentation import embeddings as embed
5+
from segment_anything import SamAutomaticMaskGenerator
56
from skimage.transform import resize
7+
68
try:
79
from napari.utils import progress as tqdm
810
except ImportError:
@@ -17,8 +19,18 @@
1719
#
1820

1921

20-
# TODO implement automatic instance segmentation based on the functionalities from segment anything:
21-
# https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py
22+
def segment_instances_sam(sam, image, **kwargs):
23+
segmentor = SamAutomaticMaskGenerator(sam, **kwargs)
24+
25+
image_ = util._to_image(image)
26+
masks = segmentor.generate(image_)
27+
masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
28+
29+
segmentation = np.zeros(image.shape[:2], dtype="uint32")
30+
for seg_id, mask in enumerate(masks, 1):
31+
segmentation[mask["segmentation"]] = seg_id
32+
33+
return segmentation
2234

2335

2436
#
@@ -58,7 +70,7 @@ def _refine_initial_segmentation(predictor, initial_seg, image_embeddings, i, ve
5870
# - Can we get intermediate, larger embeddings from SAM?
5971
# - Can we run the encoder in a sliding window and somehow stitch the embeddings?
6072
# - Or: run the encoder in a sliding window and stitch the initial segmentation result.
61-
def segment_from_embeddings(
73+
def segment_instances_from_embeddings(
6274
predictor, image_embeddings, size_threshold=10, i=None,
6375
offsets=[[-1, 0], [0, -1], [-3, 0], [0, -3]], distance_type="l2", bias=0.0,
6476
verbose=True, return_initial_seg=False,

0 commit comments

Comments
 (0)