Skip to content

Commit b929d09

Browse files
Add support for RGB images in the 2d annotator
1 parent 58fe511 commit b929d09

File tree

2 files changed

+29
-10
lines changed

2 files changed

+29
-10
lines changed

micro_sam/sam_annotator/annotator_2d.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -35,7 +35,7 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
3535
global PREDICTOR, IMAGE_EMBEDDINGS
3636

3737
PREDICTOR = util.get_sam_model()
38-
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path)
38+
IMAGE_EMBEDDINGS = util.precompute_image_embeddings(PREDICTOR, raw, save_path=embedding_path, ndim=2)
3939
util.set_precomputed(PREDICTOR, IMAGE_EMBEDDINGS)
4040

4141
#
@@ -45,16 +45,23 @@ def annotator_2d(raw, embedding_path=None, show_embeddings=False, segmentation_r
4545
v = Viewer()
4646

4747
v.add_image(raw)
48-
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="auto_segmentation")
48+
if raw.ndim == 2:
49+
shape = raw.shape
50+
elif raw.ndim == 3 and raw.shape[-1] == 3:
51+
shape = raw.shape[:2]
52+
else:
53+
raise ValueError(f"Invalid input image of shape {raw.shape}. Expect either 2D grayscale or 3D RGB image.")
54+
55+
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="auto_segmentation")
4956
if segmentation_result is None:
50-
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="committed_objects")
57+
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="committed_objects")
5158
else:
5259
v.add_labels(segmentation_result, name="committed_objects")
53-
v.add_labels(data=np.zeros(raw.shape, dtype="uint32"), name="current_object")
60+
v.add_labels(data=np.zeros(shape, dtype="uint32"), name="current_object")
5461

5562
# show the PCA of the image embeddings
5663
if show_embeddings:
57-
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], raw.shape)
64+
embedding_vis, scale = project_embeddings_for_visualization(IMAGE_EMBEDDINGS["features"], shape)
5865
v.add_image(embedding_vis, name="embeddings", scale=scale)
5966

6067
labels = ["positive", "negative"]

micro_sam/util.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -84,8 +84,18 @@ def get_sam_model(device=None, model_type="vit_h", checkpoint_path=None, return_
8484
return predictor
8585

8686

87+
def _to_image(input_):
    """Convert a 2D grayscale or an RGB input into a three-channel image.

    The result is what the 2D embedding helpers hand to ``predictor.set_image``.

    Args:
        input_: Array-like exposing ``ndim`` and ``shape``. Either 2D
            (grayscale) or 3D with a last axis of length 3 (RGB).

    Returns:
        For 2D input, a new array of shape ``input_.shape + (3,)`` in which
        the grayscale values are repeated along the new trailing axis.
        For RGB input, ``input_`` itself (no copy is made).

    Raises:
        ValueError: If the input is neither 2D nor 3D with three trailing
            channels.
    """
    if input_.ndim == 2:
        # Replicate the single channel three times along a new last axis.
        return np.concatenate([input_[..., None]] * 3, axis=-1)

    if input_.ndim == 3 and input_.shape[-1] == 3:
        # Already channels-last RGB: pass through unchanged.
        return input_

    raise ValueError(f"Invalid input image of shape {input_.shape}. Expect either 2D grayscale or 3D RGB image.")
95+
96+
8797
def _compute_2d(input_, predictor):
88-
image = np.concatenate([input_[..., None]] * 3, axis=-1)
98+
image = _to_image(input_)
8999
predictor.set_image(image)
90100
features = predictor.get_image_embedding()
91101
original_size = predictor.original_size
@@ -103,7 +113,7 @@ def _precompute_2d(input_, predictor, save_path):
103113
features = f["features"][:]
104114
original_size, input_size = f.attrs["original_size"], f.attrs["input_size"]
105115
else:
106-
image = np.concatenate([input_[..., None]] * 3, axis=-1)
116+
image = _to_image(input_)
107117
predictor.set_image(image)
108118
features = predictor.get_image_embedding()
109119
original_size, input_size = predictor.original_size, predictor.input_size
@@ -186,7 +196,7 @@ def _precompute_3d(input_, predictor, save_path, lazy_loading):
186196
return image_embeddings
187197

188198

189-
def precompute_image_embeddings(predictor, input_, save_path=None, lazy_loading=False):
199+
def precompute_image_embeddings(predictor, input_, save_path=None, lazy_loading=False, ndim=None):
190200
"""Compute the image embeddings (output of the encoder) for the input.
191201
192202
If save_path is given the embeddings will be loaded/saved in a zarr container.
@@ -198,13 +208,15 @@ def precompute_image_embeddings(predictor, input_, save_path=None, lazy_loading=
198208
lazy_loading [bool] - whether to load all embeddings into memory or return an
199209
object to load them on demand when required. This only has an effect if 'save_path'
200210
is given and if the input is 3D. (default: False)
211+
ndim [int] - the dimensionality of the data. If not given will be deduced from the input data. (default: None)
201212
"""
202213

203-
if input_.ndim == 2:
214+
ndim = input_.ndim if ndim is None else ndim
215+
if ndim == 2:
204216
image_embeddings = _compute_2d(input_, predictor) if save_path is None else\
205217
_precompute_2d(input_, predictor, save_path)
206218

207-
elif input_.ndim == 3:
219+
elif ndim == 3:
208220
image_embeddings = _compute_3d(input_, predictor) if save_path is None else\
209221
_precompute_3d(input_, predictor, save_path, lazy_loading)
210222

0 commit comments

Comments
 (0)