Skip to content

Commit c549e65

Browse files
authored
Merge pull request #32 from computational-cell-analytics/more-complex-watershed
Implement more complex watershed logic. The newly implemented watershed logic works well for the IHC segmentation (with a value of 0.5 for the boundary distance threshold).
2 parents 05e3912 + bf267d9 commit c549e65

File tree

2 files changed

+70
-21
lines changed

2 files changed

+70
-21
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 63 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import z5py
1919

2020
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
21+
from elf.wrapper.base import MultiTransformationWrapper
2122
from elf.wrapper.resized_volume import ResizedVolume
2223
from elf.io import open_file
2324
from torch_em.util import load_model
@@ -217,61 +218,91 @@ def find_mask_block(block_id):
217218
list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks))
218219

219220

220-
def segmentation_impl(input_path, output_folder, min_size, original_shape=None):
221-
"""@private
221+
def distance_watershed_implementation(
222+
input_path: str,
223+
output_folder: str,
224+
min_size: int,
225+
center_distance_threshold: float = 0.4,
226+
boundary_distance_threshold: Optional[float] = None,
227+
fg_threshold: float = 0.5,
228+
original_shape: Optional[Tuple[int, int, int]] = None,
229+
) -> None:
230+
"""Parallel implementation of the distance-prediction based watershed.
231+
232+
Args:
233+
input_path: The path to the zarr file with the network predictions.
234+
output_folder: The folder for storing the segmentation and intermediate results.
235+
min_size: The minimal size of objects in the segmentation.
236+
center_distance_threshold: The threshold applied to the center distance predictions to derive seeds.
237+
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
238+
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
239+
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
240+
original_shape: The original shape to resize the segmentation to.
222241
"""
223242
input_ = open_file(input_path, "r")["prediction"]
224243

225244
# Limit the number of cores for parallelization.
226245
n_threads = min(16, mp.cpu_count())
227246

228-
# The center distances as input for computing the seeds.
247+
# Get the foreground mask.
248+
mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold)
249+
250+
# Get the center and boundary distances.
229251
center_distances = SelectChannel(input_, 1)
230-
block_shape = center_distances.chunks
252+
boundary_distances = SelectChannel(input_, 2)
231253

232-
# Compute the seeds based on smoothed center distances < 0.5.
254+
# Apply (lazy) smoothing to both.
255+
# NOTE: this leads to issues with the parallelization, so we don't implement distance smoothing for now.
256+
# smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing)
257+
# center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
258+
# boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
259+
260+
# Allocate a zarr array for the seeds.
261+
block_shape = center_distances.chunks
233262
seed_path = os.path.join(output_folder, "seeds.zarr")
234263
seed_file = open_file(os.path.join(seed_path), "a")
235264
seeds = seed_file.require_dataset(
236265
"seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
237266
)
238267

239-
fg_threshold = 0.5
240-
mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold)
268+
# Compute the seed inputs:
269+
# First, threshold the center distances.
270+
seed_inputs = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less)
271+
# Then, if a boundary distance threshold was passed, threshold the boundary distances and combine both.
272+
if boundary_distance_threshold is not None:
273+
seed_inputs2 = ThresholdWrapper(boundary_distances, threshold=boundary_distance_threshold, operator=np.less)
274+
seed_inputs = MultiTransformationWrapper(np.logical_and, seed_inputs, seed_inputs2)
241275

276+
# Compute the seeds via connected components on the seed inputs.
242277
parallel.label(
243-
data=ThresholdWrapper(center_distances, threshold=0.4, operator=np.less),
244-
out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
278+
data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
245279
)
246280

247-
# Run the watershed.
248-
if original_shape is None:
249-
seg_path = os.path.join(output_folder, "segmentation.zarr")
250-
else:
251-
seg_path = os.path.join(output_folder, "seg_downscaled.zarr")
252-
281+
# Allocate the zarr array for the segmentation.
282+
seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
253283
seg_file = open_file(seg_path, "a")
254284
seg = seg_file.create_dataset(
255285
"segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
256286
)
257287

258-
hmap = SelectChannel(input_, 2)
288+
# Compute the segmentation with a seeded watershed.
259289
halo = (2, 8, 8)
260290
parallel.seeded_watershed(
261-
hmap, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True,
291+
boundary_distances, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True,
262292
n_threads=n_threads,
263293
)
264294

295+
# Apply size filter.
265296
if min_size > 0:
266297
parallel.size_filter(
267298
seg, seg, min_size=min_size, block_shape=block_shape, mask=mask,
268299
verbose=True, n_threads=n_threads, relabel=True,
269300
)
270301

302+
# Reshape to original shape if given.
271303
if original_shape is not None:
272304
out_path = os.path.join(output_folder, "segmentation.zarr")
273305

274-
# This logic should be refactored.
275306
output_seg = ResizedVolume(seg, shape=original_shape, order=0)
276307
with open_file(out_path, "a") as f:
277308
out_seg_volume = f.create_dataset(
@@ -325,6 +356,9 @@ def run_unet_prediction(
325356
block_shape: Optional[Tuple[int, int, int]] = None,
326357
halo: Optional[Tuple[int, int, int]] = None,
327358
use_mask: bool = True,
359+
center_distance_threshold: float = 0.4,
360+
boundary_distance_threshold: Optional[float] = None,
361+
fg_threshold: float = 0.5,
328362
) -> None:
329363
"""Run prediction and segmentation with a distance U-Net.
330364
@@ -339,6 +373,10 @@ def run_unet_prediction(
339373
block_shape: The block-shape for running the prediction.
340374
halo: The halo (= block overlap) to use for prediction.
341375
use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
376+
center_distance_threshold: The threshold applied to the center distance predictions to derive seeds.
377+
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
378+
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
379+
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
342380
"""
343381
os.makedirs(output_folder, exist_ok=True)
344382

@@ -350,7 +388,12 @@ def run_unet_prediction(
350388
)
351389

352390
pmap_out = os.path.join(output_folder, "predictions.zarr")
353-
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
391+
distance_watershed_implementation(
392+
pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
393+
center_distance_threshold=center_distance_threshold,
394+
boundary_distance_threshold=boundary_distance_threshold,
395+
fg_threshold=fg_threshold,
396+
)
354397

355398

356399
#
@@ -467,4 +510,4 @@ def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None:
467510
"""
468511
min_size = int(min_size)
469512
pmap_out = os.path.join(output_folder, "predictions.zarr")
470-
segmentation_impl(pmap_out, output_folder, min_size=min_size)
513+
distance_watershed_implementation(pmap_out, output_folder, min_size=min_size)

test/test_segmentation/test_unet_prediction.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _create_data(self, tmp_dir, use_tif):
3131
f.create_dataset(key, data=data, chunks=(32, 32, 32))
3232
return path, key
3333

34-
def _test_run_unet_prediction(self, use_tif, use_mask):
34+
def _test_run_unet_prediction(self, use_tif, use_mask, **extra_kwargs):
3535
from flamingo_tools.segmentation import run_unet_prediction
3636

3737
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -42,6 +42,7 @@ def _test_run_unet_prediction(self, use_tif, use_mask):
4242
input_path, input_key, output_folder, model_path,
4343
scale=None, min_size=100,
4444
block_shape=(64, 64, 64), halo=(16, 16, 16),
45+
**extra_kwargs
4546
)
4647

4748
expected_path = os.path.join(output_folder, "segmentation.zarr")
@@ -64,6 +65,11 @@ def test_run_unet_prediction_tif(self):
6465
def test_run_unet_prediction_tif_mask(self):
6566
self._test_run_unet_prediction(use_tif=True, use_mask=True)
6667

68+
def test_run_unet_prediction_complex_watershed(self):
69+
self._test_run_unet_prediction(
70+
use_tif=False, use_mask=True, center_distance_threshold=0.5, boundary_distance_threshold=0.5,
71+
)
72+
6773

6874
if __name__ == "__main__":
6975
unittest.main()

0 commit comments

Comments
 (0)