Skip to content

Commit e417023

Browse files
authored
Add support for tiling window in evaluation scripts (#844)
Adds support for tiling window based prediciton by allowing an optional choice for tile shape and overlap shape
1 parent 031d9fb commit e417023

File tree

2 files changed

+85
-16
lines changed

2 files changed

+85
-16
lines changed

micro_sam/evaluation/inference.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
from tqdm import tqdm
88
from copy import deepcopy
9-
from typing import Any, Dict, List, Optional, Union
9+
from typing import Any, Dict, List, Optional, Union, Tuple
1010

1111
import imageio.v3 as imageio
1212
from skimage.segmentation import relabel_sequential
@@ -20,6 +20,7 @@
2020
from ..instance_segmentation import (
2121
mask_data_to_segmentation, get_predictor_and_decoder,
2222
AutomaticMaskGenerator, InstanceSegmentationWithDecoder,
23+
TiledAutomaticMaskGenerator, TiledInstanceSegmentationWithDecoder,
2324
)
2425
from . import instance_segmentation
2526
from ..prompt_generators import PointAndBoxPromptGenerator, IterativePromptGenerator
@@ -539,7 +540,8 @@ def run_amg(
539540
iou_thresh_values: Optional[List[float]] = None,
540541
stability_score_values: Optional[List[float]] = None,
541542
peft_kwargs: Optional[Dict] = None,
542-
cache_embeddings: bool = False
543+
cache_embeddings: bool = False,
544+
tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
543545
) -> str:
544546
"""Run Segment Anything inference for multiple images using automatic mask generation (AMG).
545547
@@ -554,6 +556,7 @@ def run_amg(
554556
stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
555557
peft_kwargs: Keyword arguments for th PEFT wrapper class.
556558
cache_embeddings: Whether to cache embeddings in experiment folder.
559+
tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
557560
558561
Returns:
559562
Filepath where the predictions have been saved.
@@ -566,7 +569,23 @@ def run_amg(
566569
embedding_folder = None
567570

568571
predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs)
569-
amg = AutomaticMaskGenerator(predictor)
572+
573+
# Get the AMG class.
574+
if tiling_window_params:
575+
if not isinstance(tiling_window_params, dict):
576+
raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")
577+
578+
if "tile_shape" not in tiling_window_params:
579+
raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")
580+
581+
if "halo" not in tiling_window_params:
582+
raise RuntimeError("'halo' parameter is missing from the provided parameters.")
583+
584+
amg_class = TiledAutomaticMaskGenerator
585+
else:
586+
amg_class = AutomaticMaskGenerator
587+
588+
amg = amg_class(predictor)
570589
amg_prefix = "amg"
571590

572591
# where the predictions are saved
@@ -592,6 +611,7 @@ def run_amg(
592611
prediction_dir=prediction_folder,
593612
result_dir=gs_result_folder,
594613
experiment_folder=experiment_folder,
614+
tiling_window_params=tiling_window_params,
595615
)
596616
return prediction_folder
597617

@@ -610,6 +630,7 @@ def run_instance_segmentation_with_decoder(
610630
test_image_paths: List[Union[str, os.PathLike]],
611631
peft_kwargs: Optional[Dict] = None,
612632
cache_embeddings: bool = False,
633+
tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
613634
) -> str:
614635
"""Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).
615636
@@ -622,6 +643,7 @@ def run_instance_segmentation_with_decoder(
622643
test_image_paths: The list of filepaths of input images for automatic instance segmentation.
623644
peft_kwargs: Keyword arguments for th PEFT wrapper class.
624645
cache_embeddings: Whether to cache embeddings in experiment folder.
646+
tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
625647
626648
Returns:
627649
Filepath where the predictions have been saved.
@@ -636,7 +658,23 @@ def run_instance_segmentation_with_decoder(
636658
predictor, decoder = get_predictor_and_decoder(
637659
model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs,
638660
)
639-
segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
661+
662+
# Get the AIS class.
663+
if tiling_window_params:
664+
if not isinstance(tiling_window_params, dict):
665+
raise RuntimeError("The tiling window parameters are expected to be provided as a dictionary of params.")
666+
667+
if "tile_shape" not in tiling_window_params:
668+
raise RuntimeError("'tile_shape' parameter is missing from the provided parameters.")
669+
670+
if "halo" not in tiling_window_params:
671+
raise RuntimeError("'halo' parameter is missing from the provided parameters.")
672+
673+
ais_class = TiledInstanceSegmentationWithDecoder
674+
else:
675+
ais_class = InstanceSegmentationWithDecoder
676+
677+
segmenter = ais_class(predictor, decoder)
640678
seg_prefix = "instance_segmentation_with_decoder"
641679

642680
# where the predictions are saved
@@ -650,9 +688,15 @@ def run_instance_segmentation_with_decoder(
650688
grid_search_values = instance_segmentation.default_grid_search_values_instance_segmentation_with_decoder()
651689

652690
instance_segmentation.run_instance_segmentation_grid_search_and_inference(
653-
segmenter, grid_search_values,
654-
val_image_paths, val_gt_paths, test_image_paths,
655-
embedding_dir=embedding_folder, prediction_dir=prediction_folder,
656-
result_dir=gs_result_folder, experiment_folder=experiment_folder,
691+
segmenter=segmenter,
692+
grid_search_values=grid_search_values,
693+
val_image_paths=val_image_paths,
694+
val_gt_paths=val_gt_paths,
695+
test_image_paths=test_image_paths,
696+
embedding_dir=embedding_folder,
697+
prediction_dir=prediction_folder,
698+
result_dir=gs_result_folder,
699+
experiment_folder=experiment_folder,
700+
tiling_window_params=tiling_window_params,
657701
)
658702
return prediction_folder

micro_sam/evaluation/instance_segmentation.py

Lines changed: 33 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def run_instance_segmentation_grid_search(
156156
image_key: Optional[str] = None,
157157
gt_key: Optional[str] = None,
158158
rois: Optional[Tuple[slice, ...]] = None,
159+
tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
159160
) -> None:
160161
"""Run grid search for automatic mask generation.
161162
@@ -188,6 +189,7 @@ def run_instance_segmentation_grid_search(
188189
gt_key: Key for loading the ground-truth data from a more complex file format like HDF5.
189190
If not given a simple image format like tif is assumed.
190191
rois: Region of interests to resetrict the evaluation to.
192+
tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
191193
"""
192194
verbose_embeddings = False
193195

@@ -233,11 +235,14 @@ def run_instance_segmentation_grid_search(
233235
assert predictor is not None
234236
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
235237

238+
if tiling_window_params is None:
239+
tiling_window_params = {}
240+
236241
image_embeddings = util.precompute_image_embeddings(
237-
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
242+
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
238243
)
239244

240-
segmenter.initialize(image, image_embeddings)
245+
segmenter.initialize(image, image_embeddings, **tiling_window_params)
241246

242247
_grid_search_iteration(
243248
segmenter, gs_combinations, gt, image_name,
@@ -251,6 +256,7 @@ def run_instance_segmentation_inference(
251256
embedding_dir: Optional[Union[str, os.PathLike]],
252257
prediction_dir: Union[str, os.PathLike],
253258
generate_kwargs: Optional[Dict[str, Any]] = None,
259+
tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
254260
) -> None:
255261
"""Run inference for automatic mask generation.
256262
@@ -260,6 +266,8 @@ def run_instance_segmentation_inference(
260266
embedding_dir: Folder to cache the image embeddings.
261267
prediction_dir: Folder to save the predictions.
262268
generate_kwargs: The keyword arguments for the `generate` method of the segmenter.
269+
tiling_window_params: The parameters to decide whether to use tiling window operation
270+
for automatic segmentation.
263271
"""
264272

265273
verbose_embeddings = False
@@ -285,11 +293,14 @@ def run_instance_segmentation_inference(
285293
assert predictor is not None
286294
embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr")
287295

296+
if tiling_window_params is None:
297+
tiling_window_params = {}
298+
288299
image_embeddings = util.precompute_image_embeddings(
289-
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings
300+
predictor, image, embedding_path, ndim=2, verbose=verbose_embeddings, **tiling_window_params
290301
)
291302

292-
segmenter.initialize(image, image_embeddings)
303+
segmenter.initialize(image, image_embeddings, **tiling_window_params)
293304

294305
masks = segmenter.generate(**generate_kwargs)
295306

@@ -372,6 +383,7 @@ def run_instance_segmentation_grid_search_and_inference(
372383
result_dir: Union[str, os.PathLike],
373384
fixed_generate_kwargs: Optional[Dict[str, Any]] = None,
374385
verbose_gs: bool = True,
386+
tiling_window_params: Optional[Dict[str, Tuple[int, int]]] = None,
375387
) -> None:
376388
"""Run grid search and inference for automatic mask generation.
377389
@@ -390,11 +402,19 @@ def run_instance_segmentation_grid_search_and_inference(
390402
result_dir: Folder to cache the evaluation results per image.
391403
fixed_generate_kwargs: Fixed keyword arguments for the `generate` method of the segmenter.
392404
verbose_gs: Whether to run the gridsearch for individual images in a verbose mode.
405+
tiling_window_params: The parameters to decide whether to use tiling window operation
406+
for automatic segmentation.
393407
"""
394408
run_instance_segmentation_grid_search(
395-
segmenter, grid_search_values, val_image_paths, val_gt_paths,
396-
result_dir=result_dir, embedding_dir=embedding_dir,
397-
fixed_generate_kwargs=fixed_generate_kwargs, verbose_gs=verbose_gs,
409+
segmenter=segmenter,
410+
grid_search_values=grid_search_values,
411+
image_paths=val_image_paths,
412+
gt_paths=val_gt_paths,
413+
result_dir=result_dir,
414+
embedding_dir=embedding_dir,
415+
fixed_generate_kwargs=fixed_generate_kwargs,
416+
verbose_gs=verbose_gs,
417+
tiling_window_params=tiling_window_params,
398418
)
399419

400420
best_kwargs, best_msa = evaluate_instance_segmentation_grid_search(result_dir, list(grid_search_values.keys()))
@@ -408,5 +428,10 @@ def run_instance_segmentation_grid_search_and_inference(
408428
generate_kwargs.update(best_kwargs)
409429

410430
run_instance_segmentation_inference(
411-
segmenter, test_image_paths, embedding_dir, prediction_dir, generate_kwargs
431+
segmenter=segmenter,
432+
image_paths=test_image_paths,
433+
embedding_dir=embedding_dir,
434+
prediction_dir=prediction_dir,
435+
generate_kwargs=generate_kwargs,
436+
tiling_window_params=tiling_window_params,
412437
)

0 commit comments

Comments
 (0)