Commit 98385f4

Update embedding visualization and add support for ROIs in SAM datasets (#1065)

* Update embedding visualization code
* Add ROI support for SAM datasets
* Update doc strings

1 parent 63f9c00 commit 98385f4
2 files changed: +20, -10 lines
micro_sam/training/training.py

Lines changed: 4 additions & 0 deletions

@@ -581,6 +581,7 @@ def default_sam_dataset(
     is_train: bool = True,
     min_size: int = 25,
     max_sampling_attempts: Optional[int] = None,
+    rois: Optional[Union[slice, Tuple[slice, ...]]] = None,
     **kwargs,
 ) -> Dataset:
     """Create a PyTorch Dataset for training a SAM model.
@@ -606,6 +607,7 @@ def default_sam_dataset(
         is_train: Whether this dataset is used for training or validation. By default, set to 'True'.
         min_size: Minimal object size. Smaller objects will be filtered. By default, set to '25'.
         max_sampling_attempts: Number of sampling attempts to make from a dataset.
+        rois: The region of interest(s) for the data.
         kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

     Returns:
@@ -702,6 +704,7 @@ def default_sam_dataset(
         ndim=2,
         is_seg_dataset=is_seg_dataset,
         raw_transform=raw_transform,
+        rois=rois,
         **kwargs
     )
     n_samples = max(len(loader), 100 if is_train else 5)
@@ -719,6 +722,7 @@ def default_sam_dataset(
         sampler=sampler,
         n_samples=n_samples,
         is_seg_dataset=is_seg_dataset,
+        rois=rois,
         **kwargs,
     )
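Not part of the commit: a minimal usage sketch for the new rois argument. It assumes the remaining default_sam_dataset arguments keep their usual names (raw_paths, raw_key, label_paths, label_key, patch_shape, with_segmentation_decoder); the file path and dataset keys below are placeholders. Since rois is forwarded to torch_em, it can for example split a single volume into a training and a validation portion.

import numpy as np
from micro_sam.training import default_sam_dataset

# Hypothetical input: a single hdf5 file with the raw data under "raw"
# and the instance labels under "labels" (path and keys are placeholders).
data_path = "data/volume.h5"

# Split the volume along the first axis: the first 80 slices for training,
# the remaining slices for validation.
train_roi = np.s_[:80, :, :]
val_roi = np.s_[80:, :, :]

train_dataset = default_sam_dataset(
    raw_paths=data_path, raw_key="raw",
    label_paths=data_path, label_key="labels",
    patch_shape=(1, 512, 512),
    with_segmentation_decoder=True,
    is_train=True,
    rois=train_roi,  # new in this commit: restricts the dataset to this region
)
val_dataset = default_sam_dataset(
    raw_paths=data_path, raw_key="raw",
    label_paths=data_path, label_key="labels",
    patch_shape=(1, 512, 512),
    with_segmentation_decoder=True,
    is_train=False,
    rois=val_roi,
)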

micro_sam/visualization.py

Lines changed: 16 additions & 10 deletions

@@ -17,21 +17,23 @@
 # PCA visualization for the image embeddings
 #

-def compute_pca(embeddings: np.ndarray) -> np.ndarray:
+def compute_pca(embeddings: np.ndarray, n_components: int = 3, as_rgb: bool = True) -> np.ndarray:
     """Compute the pca projection of the embeddings to visualize them as RGB image.

     Args:
         embeddings: The embeddings. For example predicted by the SAM image encoder.
+        n_components: The number of PCA components to use for dimensionality reduction.
+        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.

     Returns:
         PCA of the embeddings, mapped to the pixels.
     """
     if embeddings.ndim == 4:
-        pca = embedding_pca(embeddings.squeeze()).transpose((1, 2, 0))
+        pca = embedding_pca(embeddings.squeeze(), n_components=n_components, as_rgb=as_rgb).transpose((1, 2, 0))
     elif embeddings.ndim == 5:
         pca = []
         for embed in embeddings:
-            vis = embedding_pca(embed.squeeze()).transpose((1, 2, 0))
+            vis = embedding_pca(embed.squeeze(), n_components=n_components, as_rgb=as_rgb).transpose((1, 2, 0))
             pca.append(vis)
         pca = np.stack(pca)
     else:
@@ -53,10 +55,10 @@ def _get_crop(embed_shape, shape):
     return crop


-def _project_embeddings(embeddings, shape, apply_crop=True):
+def _project_embeddings(embeddings, shape, apply_crop=True, n_components=3, as_rgb=True):
     assert embeddings.ndim == len(shape) + 2, f"{embeddings.shape}, {shape}"

-    embedding_vis = compute_pca(embeddings)
+    embedding_vis = compute_pca(embeddings, n_components=n_components, as_rgb=as_rgb)
     if not apply_crop:
         pass
     elif len(shape) == 2:
@@ -107,7 +109,7 @@ def resize_shape(shape):
     return np.concatenate([resize(arr, resize_shape(arr.shape)) for arr in arrays], axis=axis)


-def _project_tiled_embeddings(image_embeddings):
+def _project_tiled_embeddings(image_embeddings, n_components, as_rgb):
     features = image_embeddings["features"]
     tile_shape, halo, shape = features.attrs["tile_shape"], features.attrs["halo"], features.attrs["shape"]
     tiling = blocking([0, 0], shape, tile_shape)
@@ -141,30 +143,34 @@ def _project_tiled_embeddings(image_embeddings):

     if features["0"].ndim == 5:
         shape = (features["0"].shape[0],) + tuple(shape)
-    embedding_vis, scale = _project_embeddings(embeds, shape, apply_crop=False)
+    embedding_vis, scale = _project_embeddings(
+        embeds, shape, n_components=n_components, as_rgb=as_rgb, apply_crop=False
+    )
     return embedding_vis, scale


 def project_embeddings_for_visualization(
-    image_embeddings: ImageEmbeddings
+    image_embeddings: ImageEmbeddings, n_components: int = 3, as_rgb: bool = True,
 ) -> Tuple[np.ndarray, Tuple[float, ...]]:
     """Project image embeddings to pixel-wise PCA.

     Args:
         image_embeddings: The image embeddings.
+        n_components: The number of PCA components to use for dimensionality reduction.
+        as_rgb: Whether to normalize the projected embeddings so that they can be displayed as RGB.

     Returns:
         The PCA of the embeddings.
         The scale factor for resizing to the original image size.
     """
     is_tiled = image_embeddings["input_size"] is None
     if is_tiled:
-        embedding_vis, scale = _project_tiled_embeddings(image_embeddings)
+        embedding_vis, scale = _project_tiled_embeddings(image_embeddings, n_components, as_rgb)
     else:
         embeddings = image_embeddings["features"]
         shape = tuple(image_embeddings["original_size"])
         if embeddings.ndim == 5:
             shape = (embeddings.shape[0],) + shape
-        embedding_vis, scale = _project_embeddings(embeddings, shape)
+        embedding_vis, scale = _project_embeddings(embeddings, shape, n_components=n_components, as_rgb=as_rgb)

     return embedding_vis, scale
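Not part of the commit: a sketch of how the new n_components and as_rgb arguments could be used, assuming the embeddings are first computed with micro_sam.util.get_sam_model and micro_sam.util.precompute_image_embeddings; the image path and model type are placeholders.

import imageio.v3 as imageio
from micro_sam import util
from micro_sam.visualization import project_embeddings_for_visualization

# Placeholder input image; any 2D image can be used here.
image = imageio.imread("data/example_image.tif")

# Compute the image embeddings with the SAM image encoder.
predictor = util.get_sam_model(model_type="vit_b")
image_embeddings = util.precompute_image_embeddings(predictor, image)

# Project the embeddings to a pixel-wise PCA. The new arguments control how many
# PCA components are computed and whether the result is normalized so that it
# can be displayed as an RGB image.
embedding_vis, scale = project_embeddings_for_visualization(
    image_embeddings, n_components=3, as_rgb=True
)
print(embedding_vis.shape, scale)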
