Skip to content

Commit 748277b

Browse files
Simplify initialize functions by removing embedding_path
1 parent 5295f91 commit 748277b

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

micro_sam/instance_segmentation.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _process_crop(self, image, crop_box, crop_layer_idx, verbose, precomputed_em
320320
return data
321321

322322
@torch.no_grad()
323-
def initialize(self, image: np.ndarray, image_embeddings=None, i=None, embedding_path=None, verbose=False):
323+
def initialize(self, image: np.ndarray, image_embeddings=None, i=None, verbose=False):
324324
"""
325325
"""
326326
original_size = image.shape[:2]
@@ -333,7 +333,7 @@ def initialize(self, image: np.ndarray, image_embeddings=None, i=None, embedding
333333
# otherwise we have to recompute the embeddings for each crop and can't precompute
334334
if len(crop_boxes) == 1:
335335
if image_embeddings is None:
336-
image_embeddings = util.precompute_image_embeddings(self.predictor, image, save_path=embedding_path)
336+
image_embeddings = util.precompute_image_embeddings(self.predictor, image)
337337
util.set_precomputed(self.predictor, image_embeddings, i=i)
338338
precomputed_embeddings = True
339339
else:
@@ -475,13 +475,13 @@ def _compute_mask_data(self, initial_segmentation, original_size, verbose):
475475
return mask_data
476476

477477
@torch.no_grad()
478-
def initialize(self, image: np.ndarray, image_embeddings=None, i=None, embedding_path=None, verbose=False):
478+
def initialize(self, image: np.ndarray, image_embeddings=None, i=None, verbose=False):
479479
"""
480480
"""
481481
original_size = image.shape[:2]
482482

483483
if image_embeddings is None:
484-
image_embeddings = util.precompute_image_embeddings(self.predictor, image, save_path=embedding_path)
484+
image_embeddings = util.precompute_image_embeddings(self.predictor, image,)
485485
util.set_precomputed(self.predictor, image_embeddings, i=i)
486486

487487
# compute the initial segmentation via embedding based MWS and then refine the masks
@@ -600,16 +600,13 @@ def initialize(
600600
halo: List[int],
601601
image_embeddings=None,
602602
i=None,
603-
embedding_path=None,
604603
verbose=False,
605604
):
606605
"""
607606
"""
608607
original_size = image.shape[:2]
609608
if image_embeddings is None:
610-
image_embeddings = util.precompute_image_embeddings(
611-
self.predictor, image, save_path=embedding_path, tile_shape=tile_shape, halo=halo,
612-
)
609+
image_embeddings = util.precompute_image_embeddings(self.predictor, image, tile_shape=tile_shape, halo=halo)
613610

614611
tiling = blocking([0, 0], original_size, tile_shape)
615612
n_tiles = tiling.numberOfBlocks

0 commit comments

Comments
 (0)