1212
1313import vigra
1414import numpy as np
15+ import elf .parallel as parallel
16+ from elf .parallel .filters import apply_filter
1517from skimage .measure import label , regionprops
1618from skimage .segmentation import relabel_sequential
1719
@@ -559,11 +561,13 @@ def generate(
559561
560562
561563# Helper function for tiled embedding computation and checking consistent state.
562- def _process_tiled_embeddings (predictor , image , image_embeddings , tile_shape , halo ):
564+ def _process_tiled_embeddings (predictor , image , image_embeddings , tile_shape , halo , verbose ):
563565 if image_embeddings is None :
564566 if tile_shape is None or halo is None :
565567 raise ValueError ("To compute tiled embeddings the parameters tile_shape and halo have to be passed." )
566- image_embeddings = util .precompute_image_embeddings (predictor , image , tile_shape = tile_shape , halo = halo )
568+ image_embeddings = util .precompute_image_embeddings (
569+ predictor , image , tile_shape = tile_shape , halo = halo , verbose = verbose
570+ )
567571
568572 # Use tile shape and halo from the precomputed embeddings if not given.
569573 # Otherwise check that they are consistent.
@@ -650,7 +654,7 @@ def initialize(
650654 self ._original_size = original_size
651655
652656 image_embeddings , tile_shape , halo = _process_tiled_embeddings (
653- self ._predictor , image , image_embeddings , tile_shape , halo
657+ self ._predictor , image , image_embeddings , tile_shape , halo , verbose = verbose ,
654658 )
655659
656660 tiling = blocking ([0 , 0 ], original_size , tile_shape )
@@ -853,6 +857,55 @@ def get_predictor_and_decoder(
853857 return predictor , decoder
854858
855859
860+ def _watershed_from_center_and_boundary_distances_parallel (
861+ center_distances ,
862+ boundary_distances ,
863+ foreground_map ,
864+ center_distance_threshold ,
865+ boundary_distance_threshold ,
866+ foreground_threshold ,
867+ distance_smoothing ,
868+ min_size ,
869+ tile_shape ,
870+ halo ,
871+ n_threads ,
872+ verbose = False ,
873+ ):
874+ center_distances = apply_filter (
875+ center_distances , "gaussianSmoothing" , sigma = distance_smoothing ,
876+ block_shape = tile_shape , n_threads = n_threads
877+ )
878+ boundary_distances = apply_filter (
879+ boundary_distances , "gaussianSmoothing" , sigma = distance_smoothing ,
880+ block_shape = tile_shape , n_threads = n_threads
881+ )
882+
883+ fg_mask = foreground_map > foreground_threshold
884+
885+ marker_map = np .logical_and (
886+ center_distances < center_distance_threshold , boundary_distances < boundary_distance_threshold
887+ )
888+ marker_map [~ fg_mask ] = 0
889+
890+ markers = np .zeros (marker_map .shape , dtype = "uint64" )
891+ markers = parallel .label (
892+ marker_map , out = markers , block_shape = tile_shape , n_threads = n_threads , verbose = verbose ,
893+ )
894+
895+ seg = np .zeros_like (markers , dtype = "uint64" )
896+ seg = parallel .seeded_watershed (
897+ boundary_distances , seeds = markers , out = seg , block_shape = tile_shape ,
898+ halo = halo , n_threads = n_threads , verbose = verbose , mask = fg_mask ,
899+ )
900+
901+ out = np .zeros_like (seg , dtype = "uint64" )
902+ out = parallel .size_filter (
903+ seg , out = out , min_size = min_size , block_shape = tile_shape , n_threads = n_threads , verbose = verbose
904+ )
905+
906+ return out
907+
908+
856909class InstanceSegmentationWithDecoder :
857910 """Generates an instance segmentation without prompts, using a decoder.
858911
@@ -988,6 +1041,9 @@ def generate(
9881041 distance_smoothing : float = 1.6 ,
9891042 min_size : int = 0 ,
9901043 output_mode : Optional [str ] = "binary_mask" ,
1044+ tile_shape : Optional [Tuple [int , int ]] = None ,
1045+ halo : Optional [Tuple [int , int ]] = None ,
1046+ n_threads : Optional [int ] = None ,
9911047 ) -> List [Dict [str , Any ]]:
9921048 """Generate instance segmentation for the currently initialized image.
9931049
@@ -1002,6 +1058,11 @@ def generate(
10021058 distance_smoothing: Sigma value for smoothing the distance predictions.
10031059 min_size: Minimal object size in the segmentation result.
10041060 output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1061+ tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1062+ This parameter is independent from the tile shape for computing the embeddings.
1063+ If not given then post-processing will not be parallelized.
1064+ halo: Halo for parallel post-processing. See also `tile_shape`.
1065+ n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
10051066
10061067 Returns:
10071068 The instance segmentation masks.
@@ -1013,16 +1074,29 @@ def generate(
10131074 foreground = vigra .filters .gaussianSmoothing (self ._foreground , foreground_smoothing )
10141075 else :
10151076 foreground = self ._foreground
1016- # Further optimization: parallel implementation using elf.parallel functionality.
1017- # (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1018- segmentation = watershed_from_center_and_boundary_distances (
1019- self ._center_distances , self ._boundary_distances , foreground ,
1020- center_distance_threshold = center_distance_threshold ,
1021- boundary_distance_threshold = boundary_distance_threshold ,
1022- foreground_threshold = foreground_threshold ,
1023- distance_smoothing = distance_smoothing ,
1024- min_size = min_size ,
1025- )
1077+
1078+ if tile_shape is None :
1079+ segmentation = watershed_from_center_and_boundary_distances (
1080+ self ._center_distances , self ._boundary_distances , foreground ,
1081+ center_distance_threshold = center_distance_threshold ,
1082+ boundary_distance_threshold = boundary_distance_threshold ,
1083+ foreground_threshold = foreground_threshold ,
1084+ distance_smoothing = distance_smoothing ,
1085+ min_size = min_size ,
1086+ )
1087+ else :
1088+ if halo is None :
1089+ raise ValueError ("You must pass a value for halo if tile_shape is given." )
1090+ segmentation = _watershed_from_center_and_boundary_distances_parallel (
1091+ self ._center_distances , self ._boundary_distances , foreground ,
1092+ center_distance_threshold = center_distance_threshold ,
1093+ boundary_distance_threshold = boundary_distance_threshold ,
1094+ foreground_threshold = foreground_threshold ,
1095+ distance_smoothing = distance_smoothing ,
1096+ min_size = min_size , tile_shape = tile_shape ,
1097+ halo = halo , n_threads = n_threads , verbose = False ,
1098+ )
1099+
10261100 if output_mode is not None :
10271101 segmentation = self ._to_masks (segmentation , output_mode )
10281102 return segmentation
@@ -1094,7 +1168,7 @@ def initialize(
10941168 """
10951169 original_size = image .shape [:2 ]
10961170 image_embeddings , tile_shape , halo = _process_tiled_embeddings (
1097- self ._predictor , image , image_embeddings , tile_shape , halo
1171+ self ._predictor , image , image_embeddings , tile_shape , halo , verbose = verbose ,
10981172 )
10991173 tiling = blocking ([0 , 0 ], original_size , tile_shape )
11001174
@@ -1127,6 +1201,8 @@ def initialize(
11271201 foreground [inner_bb ] = output [0 ][local_bb ]
11281202 center_distances [inner_bb ] = output [1 ][local_bb ]
11291203 boundary_distances [inner_bb ] = output [2 ][local_bb ]
1204+ pbar_update (1 )
1205+
11301206 pbar_close ()
11311207
11321208 # Set the state.
0 commit comments