Skip to content

Commit ea58712

Browse files
Make tiling parameters in initialize for tiled embedding mask generator optional if embeds are passed
1 parent ba2857a commit ea58712

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

micro_sam/instance_segmentation.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import multiprocessing as mp
2+
import warnings
23
from abc import ABC
34
from concurrent import futures
45
from copy import deepcopy
@@ -619,20 +620,43 @@ def _compute_mask_data_tiled(self, image_embeddings, i, initial_segmentations, n
619620
def initialize(
620621
self,
621622
image: np.ndarray,
622-
tile_shape: List[int],
623-
halo: List[int],
624623
image_embeddings=None,
625624
i: Optional[int] = None,
625+
tile_shape: Optional[List[int]] = None,
626+
halo: Optional[List[int]] = None,
626627
verbose: bool = False,
627628
embedding_save_path: Optional[str] = None,
628629
):
629630
"""
630631
"""
631632
original_size = image.shape[:2]
632-
if image_embeddings is None:
633+
634+
have_tiling_params = (tile_shape is not None) and (halo is not None)
635+
if image_embeddings is None and have_tiling_params:
636+
if embedding_save_path is None:
637+
raise ValueError(
638+
"You have passed neither pre-computed embeddings nor a path for saving embeddings."
639+
"Embeddings with tiling can only be computed if a save path is given."
640+
)
633641
image_embeddings = util.precompute_image_embeddings(
634642
self.predictor, image, tile_shape=tile_shape, halo=halo, save_path=embedding_save_path
635643
)
644+
elif image_embeddings is None and not have_tiling_params:
645+
raise ValueError("You passed neither pre-computed embeddings nor tiling parameters (tile_shape and halo)")
646+
else:
647+
feats = image_embeddings["features"]
648+
tile_shape_, halo_ = feats.attrs["tile_shape"], feats.attrs["halo"]
649+
if have_tiling_params and (
650+
(list(tile_shape) != list(tile_shape_)) or
651+
(list(halo) != list(halo_))
652+
):
653+
warnings.warn(
654+
"You have passed both pre-computed embeddings and tiling parameters (tile_shape and halo) and"
655+
"the values of the tiling parameters from the embeddings disagree with the ones that were passed."
656+
"The tiling parameters you have passed wil be ignored."
657+
)
658+
tile_shape = tile_shape_
659+
halo = halo_
636660

637661
tiling = blocking([0, 0], original_size, tile_shape)
638662
n_tiles = tiling.numberOfBlocks

test/test_instance_segmentation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_tiled_embedding_mask_generator(self):
102102
predictor, image_embeddings = self._get_model(image, tile_shape, halo, self.embedding_path)
103103

104104
amg = TiledEmbeddingMaskGenerator(predictor)
105-
amg.initialize(image, image_embeddings=image_embeddings, tile_shape=tile_shape, halo=halo)
105+
amg.initialize(image, image_embeddings=image_embeddings)
106106
predicted = amg.generate(pred_iou_thresh=0.96)
107107
initial_seg = amg.get_initial_segmentation()
108108

0 commit comments

Comments
 (0)