Skip to content

Commit b70f0ae

Browse files
Ensure 2d embeddings in micro_sam.evaluation
1 parent c44d33d commit b70f0ae

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

micro_sam/evaluation/automatic_mask_generation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def run_amg_grid_search(
122122
gt = imageio.imread(gt_path)
123123

124124
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
125-
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path)
125+
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
126126
amg.initialize(image, image_embeddings)
127127

128128
_grid_search(
@@ -170,7 +170,7 @@ def run_amg_inference(
170170
image = imageio.imread(image_path)
171171

172172
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
173-
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path)
173+
image_embeddings = util.precompute_image_embeddings(predictor, image, embedding_path, ndim=2)
174174

175175
amg.initialize(image, image_embeddings)
176176
masks = amg.generate(**amg_generate_kwargs)

micro_sam/evaluation/inference.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def precompute_all_embeddings(
211211
predictor: SamPredictor,
212212
image_paths: List[Union[str, os.PathLike]],
213213
embedding_dir: Union[str, os.PathLike],
214-
):
214+
) -> None:
215215
"""Precompute all image embeddings.
216216
217217
To enable running different inference tasks in parallel afterwards.
@@ -225,7 +225,7 @@ def precompute_all_embeddings(
225225
image_name = os.path.basename(image_path)
226226
im = imageio.imread(image_path)
227227
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
228-
util.precompute_image_embeddings(predictor, im, embedding_path)
228+
util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)
229229

230230

231231
def _precompute_prompts(gt_path, use_points, use_boxes, n_positives, n_negatives, dilation, transform_function):
@@ -392,7 +392,7 @@ def run_inference_with_prompts(
392392
gt = relabel_sequential(gt)[0]
393393

394394
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
395-
image_embeddings = util.precompute_image_embeddings(predictor, im, embedding_path)
395+
image_embeddings = util.precompute_image_embeddings(predictor, im, embedding_path, ndim=2)
396396
util.set_precomputed(predictor, image_embeddings)
397397

398398
this_prompts, cached_point_prompts, cached_box_prompts = _load_prompts(

0 commit comments

Comments
 (0)