Skip to content

Commit 01756a0

Browse files
authored
Extend support for automatic seg functionality to return embeddings (#855)
Extend support for automatic segmentation functionality to return embeddings
1 parent 78f1c8e commit 01756a0

File tree

3 files changed: 25 additions and 8 deletions

micro_sam/automatic_segmentation.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ def automatic_instance_segmentation(
7474
tile_shape: Optional[Tuple[int, int]] = None,
7575
halo: Optional[Tuple[int, int]] = None,
7676
verbose: bool = True,
77+
return_embeddings: bool = False,
7778
**generate_kwargs
7879
) -> np.ndarray:
7980
"""Run automatic segmentation for the input image.
@@ -92,6 +93,7 @@ def automatic_instance_segmentation(
9293
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
9394
halo: Overlap of the tiles for tiled prediction.
9495
verbose: Verbosity flag.
96+
return_embeddings: Whether to return the precomputed image embeddings.
9597
generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class.
9698
9799
Returns:
@@ -142,23 +144,32 @@ def automatic_instance_segmentation(
142144
if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3):
143145
raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}")
144146

145-
instances = automatic_3d_segmentation(
147+
outputs = automatic_3d_segmentation(
146148
volume=image_data,
147149
predictor=predictor,
148150
segmentor=segmenter,
149151
embedding_path=embedding_path,
150152
tile_shape=tile_shape,
151153
halo=halo,
152154
verbose=verbose,
155+
return_embeddings=return_embeddings,
153156
**generate_kwargs
154157
)
155158

159+
if return_embeddings:
160+
instances, image_embeddings = outputs
161+
else:
162+
instances = outputs
163+
164+
# Save the instance segmentation, if 'output_path' provided.
156165
if output_path is not None:
157-
# Save the instance segmentation
158166
output_path = Path(output_path).with_suffix(".tif")
159167
imageio.imwrite(output_path, instances, compression="zlib")
160168

161-
return instances
169+
if return_embeddings:
170+
return instances, image_embeddings
171+
else:
172+
return instances
162173

163174

164175
def main():
@@ -194,8 +205,7 @@ def main():
194205
help=f"The segment anything model that will be used, one of {available_models}."
195206
)
196207
parser.add_argument(
197-
"-c", "--checkpoint", default=None,
198-
help="Checkpoint from which the SAM model will be loaded."
208+
"-c", "--checkpoint", default=None, help="Checkpoint from which the SAM model will be loaded."
199209
)
200210
parser.add_argument(
201211
"--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None

micro_sam/instance_segmentation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,7 @@ def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, ha
572572
# Use tile shape and halo from the precomputed embeddings if not given.
573573
# Otherwise check that they are consistent.
574574
feats = image_embeddings["features"]
575-
tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
575+
tile_shape_, halo_ = tuple(feats.attrs["tile_shape"]), tuple(feats.attrs["halo"])
576576
if tile_shape is None:
577577
tile_shape = tile_shape_
578578
elif tile_shape != tile_shape_:
@@ -835,7 +835,7 @@ def get_predictor_and_decoder(
835835
model_type: The type of the image encoder used in the SAM model.
836836
checkpoint_path: Path to the checkpoint from which to load the data.
837837
device: The device.
838-
peft_kwargs: Keyword arguments for th PEFT wrapper class.
838+
peft_kwargs: Keyword arguments for the PEFT wrapper class.
839839
840840
Returns:
841841
The SAM predictor.
@@ -1160,6 +1160,8 @@ def initialize(
11601160
See `util.precompute_image_embeddings` for details.
11611161
i: Index for the image data. Required if `image` has three spatial dimensions
11621162
or a time dimension and two spatial dimensions.
1163+
tile_shape: Shape of the tiles for precomputing image embeddings.
1164+
halo: Overlap of the tiles for tiled precomputation of image embeddings.
11631165
verbose: Dummy input to be compatible with other function signatures.
11641166
pbar_init: Callback to initialize an external progress bar. Must accept number of steps and description.
11651167
Can be used together with pbar_update to handle napari progress bar in other thread.

micro_sam/multi_dimensional_segmentation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def automatic_3d_segmentation(
368368
tile_shape: Optional[Tuple[int, int]] = None,
369369
halo: Optional[Tuple[int, int]] = None,
370370
verbose: bool = True,
371+
return_embeddings: bool = False,
371372
**kwargs,
372373
) -> np.ndarray:
373374
"""Segment volume in 3d.
@@ -388,6 +389,7 @@ def automatic_3d_segmentation(
388389
tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling.
389390
halo: Overlap of the tiles for tiled prediction.
390391
verbose: Verbosity flag.
392+
return_embeddings: Whether to return the precomputed image embeddings.
391393
kwargs: Keyword arguments for the 'generate' method of the 'segmentor'.
392394
393395
Returns:
@@ -430,4 +432,7 @@ def automatic_3d_segmentation(
430432
verbose=verbose,
431433
)
432434

433-
return segmentation
435+
if return_embeddings:
436+
return segmentation, image_embeddings
437+
else:
438+
return segmentation

0 commit comments

Comments
 (0)