Skip to content

Commit c319e36

Browse files
Parallelize AIS post-processing (#851)
1 parent 959af7d commit c319e36

File tree

2 files changed

+95
-15
lines changed

2 files changed

+95
-15
lines changed

micro_sam/automatic_segmentation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,11 @@ def automatic_instance_segmentation(
120120
verbose=verbose,
121121
)
122122

123-
segmenter.initialize(image=image_data, image_embeddings=image_embeddings)
123+
# If we run AIS with tiling then we use the same tile shape for the watershed postprocessing.
124+
if isinstance(segmenter, InstanceSegmentationWithDecoder) and tile_shape is not None:
125+
generate_kwargs.update({"tile_shape": tile_shape, "halo": halo})
126+
127+
segmenter.initialize(image=image_data, image_embeddings=image_embeddings, verbose=verbose)
124128
masks = segmenter.generate(**generate_kwargs)
125129

126130
if len(masks) == 0: # instance segmentation can have no masks, hence we just save empty labels

micro_sam/instance_segmentation.py

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import vigra
1414
import numpy as np
15+
import elf.parallel as parallel
16+
from elf.parallel.filters import apply_filter
1517
from skimage.measure import label, regionprops
1618
from skimage.segmentation import relabel_sequential
1719

@@ -559,11 +561,13 @@ def generate(
559561

560562

561563
# Helper function for tiled embedding computation and checking consistent state.
562-
def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo):
564+
def _process_tiled_embeddings(predictor, image, image_embeddings, tile_shape, halo, verbose):
563565
if image_embeddings is None:
564566
if tile_shape is None or halo is None:
565567
raise ValueError("To compute tiled embeddings the parameters tile_shape and halo have to be passed.")
566-
image_embeddings = util.precompute_image_embeddings(predictor, image, tile_shape=tile_shape, halo=halo)
568+
image_embeddings = util.precompute_image_embeddings(
569+
predictor, image, tile_shape=tile_shape, halo=halo, verbose=verbose
570+
)
567571

568572
# Use tile shape and halo from the precomputed embeddings if not given.
569573
# Otherwise check that they are consistent.
@@ -650,7 +654,7 @@ def initialize(
650654
self._original_size = original_size
651655

652656
image_embeddings, tile_shape, halo = _process_tiled_embeddings(
653-
self._predictor, image, image_embeddings, tile_shape, halo
657+
self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
654658
)
655659

656660
tiling = blocking([0, 0], original_size, tile_shape)
@@ -853,6 +857,55 @@ def get_predictor_and_decoder(
853857
return predictor, decoder
854858

855859

860+
def _watershed_from_center_and_boundary_distances_parallel(
861+
center_distances,
862+
boundary_distances,
863+
foreground_map,
864+
center_distance_threshold,
865+
boundary_distance_threshold,
866+
foreground_threshold,
867+
distance_smoothing,
868+
min_size,
869+
tile_shape,
870+
halo,
871+
n_threads,
872+
verbose=False,
873+
):
874+
center_distances = apply_filter(
875+
center_distances, "gaussianSmoothing", sigma=distance_smoothing,
876+
block_shape=tile_shape, n_threads=n_threads
877+
)
878+
boundary_distances = apply_filter(
879+
boundary_distances, "gaussianSmoothing", sigma=distance_smoothing,
880+
block_shape=tile_shape, n_threads=n_threads
881+
)
882+
883+
fg_mask = foreground_map > foreground_threshold
884+
885+
marker_map = np.logical_and(
886+
center_distances < center_distance_threshold, boundary_distances < boundary_distance_threshold
887+
)
888+
marker_map[~fg_mask] = 0
889+
890+
markers = np.zeros(marker_map.shape, dtype="uint64")
891+
markers = parallel.label(
892+
marker_map, out=markers, block_shape=tile_shape, n_threads=n_threads, verbose=verbose,
893+
)
894+
895+
seg = np.zeros_like(markers, dtype="uint64")
896+
seg = parallel.seeded_watershed(
897+
boundary_distances, seeds=markers, out=seg, block_shape=tile_shape,
898+
halo=halo, n_threads=n_threads, verbose=verbose, mask=fg_mask,
899+
)
900+
901+
out = np.zeros_like(seg, dtype="uint64")
902+
out = parallel.size_filter(
903+
seg, out=out, min_size=min_size, block_shape=tile_shape, n_threads=n_threads, verbose=verbose
904+
)
905+
906+
return out
907+
908+
856909
class InstanceSegmentationWithDecoder:
857910
"""Generates an instance segmentation without prompts, using a decoder.
858911
@@ -988,6 +1041,9 @@ def generate(
9881041
distance_smoothing: float = 1.6,
9891042
min_size: int = 0,
9901043
output_mode: Optional[str] = "binary_mask",
1044+
tile_shape: Optional[Tuple[int, int]] = None,
1045+
halo: Optional[Tuple[int, int]] = None,
1046+
n_threads: Optional[int] = None,
9911047
) -> List[Dict[str, Any]]:
9921048
"""Generate instance segmentation for the currently initialized image.
9931049
@@ -1002,6 +1058,11 @@ def generate(
10021058
distance_smoothing: Sigma value for smoothing the distance predictions.
10031059
min_size: Minimal object size in the segmentation result.
10041060
output_mode: The form masks are returned in. Pass None to directly return the instance segmentation.
1061+
tile_shape: Tile shape for parallelizing the instance segmentation post-processing.
1062+
This parameter is independent from the tile shape for computing the embeddings.
1063+
If not given then post-processing will not be parallelized.
1064+
halo: Halo for parallel post-processing. See also `tile_shape`.
1065+
n_threads: Number of threads for parallel post-processing. See also `tile_shape`.
10051066
10061067
Returns:
10071068
The instance segmentation masks.
@@ -1013,16 +1074,29 @@ def generate(
10131074
foreground = vigra.filters.gaussianSmoothing(self._foreground, foreground_smoothing)
10141075
else:
10151076
foreground = self._foreground
1016-
# Further optimization: parallel implementation using elf.parallel functionality.
1017-
# (Make sure to expose n_threads to avoid over-subscription in case of outer parallelization)
1018-
segmentation = watershed_from_center_and_boundary_distances(
1019-
self._center_distances, self._boundary_distances, foreground,
1020-
center_distance_threshold=center_distance_threshold,
1021-
boundary_distance_threshold=boundary_distance_threshold,
1022-
foreground_threshold=foreground_threshold,
1023-
distance_smoothing=distance_smoothing,
1024-
min_size=min_size,
1025-
)
1077+
1078+
if tile_shape is None:
1079+
segmentation = watershed_from_center_and_boundary_distances(
1080+
self._center_distances, self._boundary_distances, foreground,
1081+
center_distance_threshold=center_distance_threshold,
1082+
boundary_distance_threshold=boundary_distance_threshold,
1083+
foreground_threshold=foreground_threshold,
1084+
distance_smoothing=distance_smoothing,
1085+
min_size=min_size,
1086+
)
1087+
else:
1088+
if halo is None:
1089+
raise ValueError("You must pass a value for halo if tile_shape is given.")
1090+
segmentation = _watershed_from_center_and_boundary_distances_parallel(
1091+
self._center_distances, self._boundary_distances, foreground,
1092+
center_distance_threshold=center_distance_threshold,
1093+
boundary_distance_threshold=boundary_distance_threshold,
1094+
foreground_threshold=foreground_threshold,
1095+
distance_smoothing=distance_smoothing,
1096+
min_size=min_size, tile_shape=tile_shape,
1097+
halo=halo, n_threads=n_threads, verbose=False,
1098+
)
1099+
10261100
if output_mode is not None:
10271101
segmentation = self._to_masks(segmentation, output_mode)
10281102
return segmentation
@@ -1094,7 +1168,7 @@ def initialize(
10941168
"""
10951169
original_size = image.shape[:2]
10961170
image_embeddings, tile_shape, halo = _process_tiled_embeddings(
1097-
self._predictor, image, image_embeddings, tile_shape, halo
1171+
self._predictor, image, image_embeddings, tile_shape, halo, verbose=verbose,
10981172
)
10991173
tiling = blocking([0, 0], original_size, tile_shape)
11001174

@@ -1127,6 +1201,8 @@ def initialize(
11271201
foreground[inner_bb] = output[0][local_bb]
11281202
center_distances[inner_bb] = output[1][local_bb]
11291203
boundary_distances[inner_bb] = output[2][local_bb]
1204+
pbar_update(1)
1205+
11301206
pbar_close()
11311207

11321208
# Set the state.

0 commit comments

Comments
 (0)