66import numpy as np
77from tqdm import tqdm
88from copy import deepcopy
9- from typing import Any , Dict , List , Optional , Union
9+ from typing import Any , Dict , List , Optional , Union , Tuple
1010
1111import imageio .v3 as imageio
1212from skimage .segmentation import relabel_sequential
2020from ..instance_segmentation import (
2121 mask_data_to_segmentation , get_predictor_and_decoder ,
2222 AutomaticMaskGenerator , InstanceSegmentationWithDecoder ,
23+ TiledAutomaticMaskGenerator , TiledInstanceSegmentationWithDecoder ,
2324)
2425from . import instance_segmentation
2526from ..prompt_generators import PointAndBoxPromptGenerator , IterativePromptGenerator
@@ -539,7 +540,8 @@ def run_amg(
539540 iou_thresh_values : Optional [List [float ]] = None ,
540541 stability_score_values : Optional [List [float ]] = None ,
541542 peft_kwargs : Optional [Dict ] = None ,
542- cache_embeddings : bool = False
543+ cache_embeddings : bool = False ,
544+ tiling_window_params : Optional [Dict [str , Tuple [int , int ]]] = None ,
543545) -> str :
544546 """Run Segment Anything inference for multiple images using automatic mask generation (AMG).
545547
@@ -554,6 +556,7 @@ def run_amg(
554556 stability_score_values: Optional choice of values for grid search of `stability_score` parameter.
555557 peft_kwargs: Keyword arguments for th PEFT wrapper class.
556558 cache_embeddings: Whether to cache embeddings in experiment folder.
559+ tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
557560
558561 Returns:
559562 Filepath where the predictions have been saved.
@@ -566,7 +569,23 @@ def run_amg(
566569 embedding_folder = None
567570
568571 predictor = util .get_sam_model (model_type = model_type , checkpoint_path = checkpoint , peft_kwargs = peft_kwargs )
569- amg = AutomaticMaskGenerator (predictor )
572+
573+ # Get the AMG class.
574+ if tiling_window_params :
575+ if not isinstance (tiling_window_params , dict ):
576+ raise RuntimeError ("The tiling window parameters are expected to be provided as a dictionary of params." )
577+
578+ if "tile_shape" not in tiling_window_params :
579+ raise RuntimeError ("'tile_shape' parameter is missing from the provided parameters." )
580+
581+ if "halo" not in tiling_window_params :
582+ raise RuntimeError ("'halo' parameter is missing from the provided parameters." )
583+
584+ amg_class = TiledAutomaticMaskGenerator
585+ else :
586+ amg_class = AutomaticMaskGenerator
587+
588+ amg = amg_class (predictor )
570589 amg_prefix = "amg"
571590
572591 # where the predictions are saved
@@ -592,6 +611,7 @@ def run_amg(
592611 prediction_dir = prediction_folder ,
593612 result_dir = gs_result_folder ,
594613 experiment_folder = experiment_folder ,
614+ tiling_window_params = tiling_window_params ,
595615 )
596616 return prediction_folder
597617
@@ -610,6 +630,7 @@ def run_instance_segmentation_with_decoder(
610630 test_image_paths : List [Union [str , os .PathLike ]],
611631 peft_kwargs : Optional [Dict ] = None ,
612632 cache_embeddings : bool = False ,
633+ tiling_window_params : Optional [Dict [str , Tuple [int , int ]]] = None ,
613634) -> str :
614635 """Run Segment Anything inference for multiple images using additional automatic instance segmentation (AIS).
615636
@@ -622,6 +643,7 @@ def run_instance_segmentation_with_decoder(
622643 test_image_paths: The list of filepaths of input images for automatic instance segmentation.
623644 peft_kwargs: Keyword arguments for th PEFT wrapper class.
624645 cache_embeddings: Whether to cache embeddings in experiment folder.
646+ tiling_window_params: The parameters to decide whether to use tiling window operation for AIS.
625647
626648 Returns:
627649 Filepath where the predictions have been saved.
@@ -636,7 +658,23 @@ def run_instance_segmentation_with_decoder(
636658 predictor , decoder = get_predictor_and_decoder (
637659 model_type = model_type , checkpoint_path = checkpoint , peft_kwargs = peft_kwargs ,
638660 )
639- segmenter = InstanceSegmentationWithDecoder (predictor , decoder )
661+
662+ # Get the AIS class.
663+ if tiling_window_params :
664+ if not isinstance (tiling_window_params , dict ):
665+ raise RuntimeError ("The tiling window parameters are expected to be provided as a dictionary of params." )
666+
667+ if "tile_shape" not in tiling_window_params :
668+ raise RuntimeError ("'tile_shape' parameter is missing from the provided parameters." )
669+
670+ if "halo" not in tiling_window_params :
671+ raise RuntimeError ("'halo' parameter is missing from the provided parameters." )
672+
673+ ais_class = TiledInstanceSegmentationWithDecoder
674+ else :
675+ ais_class = InstanceSegmentationWithDecoder
676+
677+ segmenter = ais_class (predictor , decoder )
640678 seg_prefix = "instance_segmentation_with_decoder"
641679
642680 # where the predictions are saved
@@ -650,9 +688,15 @@ def run_instance_segmentation_with_decoder(
650688 grid_search_values = instance_segmentation .default_grid_search_values_instance_segmentation_with_decoder ()
651689
652690 instance_segmentation .run_instance_segmentation_grid_search_and_inference (
653- segmenter , grid_search_values ,
654- val_image_paths , val_gt_paths , test_image_paths ,
655- embedding_dir = embedding_folder , prediction_dir = prediction_folder ,
656- result_dir = gs_result_folder , experiment_folder = experiment_folder ,
691+ segmenter = segmenter ,
692+ grid_search_values = grid_search_values ,
693+ val_image_paths = val_image_paths ,
694+ val_gt_paths = val_gt_paths ,
695+ test_image_paths = test_image_paths ,
696+ embedding_dir = embedding_folder ,
697+ prediction_dir = prediction_folder ,
698+ result_dir = gs_result_folder ,
699+ experiment_folder = experiment_folder ,
700+ tiling_window_params = tiling_window_params ,
657701 )
658702 return prediction_folder
0 commit comments