Skip to content

Commit 9022292

Browse files
Implement more complex watershed logic WIP
1 parent 05e3912 commit 9022292

File tree

1 file changed

+57
-21
lines changed

1 file changed

+57
-21
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import os
99
import warnings
1010
from concurrent import futures
11+
from functools import partial
1112
from typing import Optional, Tuple
1213

1314
import elf.parallel as parallel
@@ -17,7 +18,7 @@
1718
import torch
1819
import z5py
1920

20-
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
21+
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper, MultiTransformationWrapper
2122
from elf.wrapper.resized_volume import ResizedVolume
2223
from elf.io import open_file
2324
from torch_em.util import load_model
@@ -27,6 +28,11 @@
2728
import flamingo_tools.s3_utils as s3_utils
2829
from flamingo_tools.file_utils import read_image_data
2930

31+
try:
32+
import fastfilters as ff
33+
except ImportError:
34+
import vigra.filters as ff
35+
3036

3137
class SelectChannel(SimpleTransformationWrapper):
3238
"""Wrapper to select a chanel from an array-like dataset object.
@@ -217,61 +223,91 @@ def find_mask_block(block_id):
217223
list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks))
218224

219225

220-
def segmentation_impl(input_path, output_folder, min_size, original_shape=None):
221-
"""@private
226+
def distance_watershed_implementation(
227+
input_path: str,
228+
output_folder: str,
229+
min_size: int,
230+
center_distance_threshold: float = 0.4,
231+
boundary_distance_threshold: Optional[float] = None,
232+
fg_threshold: float = 0.5,
233+
distance_smoothing: float = 1.6,
234+
original_shape: Optional[Tuple[int, int, int]] = None,
235+
) -> None:
236+
"""
237+
238+
Args:
239+
input_path:
240+
output_folder:
241+
min_size:
242+
center_distance_threshold:
243+
boundary_distance_threshold:
244+
fg_threshold:
245+
distance_smoothing:
246+
original_shape:
222247
"""
223248
input_ = open_file(input_path, "r")["prediction"]
224249

225250
# Limit the number of cores for parallelization.
226251
n_threads = min(16, mp.cpu_count())
227252

228-
# The center distances as input for computing the seeds.
253+
# Get the foreground mask.
254+
mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold)
255+
256+
# Get the the center and boundary distances.
229257
center_distances = SelectChannel(input_, 1)
230-
block_shape = center_distances.chunks
258+
boundary_distances = SelectChannel(input_, 2)
259+
260+
# Apply (lazy) smoothing to both.
261+
smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing)
262+
center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
263+
boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
231264

232-
# Compute the seeds based on smoothed center distances < 0.5.
265+
# Allocate an zarr array for the seeds.
266+
block_shape = center_distances.chunks
233267
seed_path = os.path.join(output_folder, "seeds.zarr")
234268
seed_file = open_file(os.path.join(seed_path), "a")
235269
seeds = seed_file.require_dataset(
236270
"seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
237271
)
238272

239-
fg_threshold = 0.5
240-
mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold)
273+
# Compute the seed inputs:
274+
# First, threshold the center distances.
275+
seed_inputs = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less)
276+
# Then, if a boundary distance threshold was passed threshold the boundary distances and combine both.
277+
if boundary_distance_threshold is not None:
278+
seed_inputs2 = ThresholdWrapper(boundary_distances, threshold=boundary_distance_threshold, operator=np.less)
279+
seed_inputs = MultiTransformationWrapper(np.logical_and, seed_inputs, seed_inputs2)
241280

281+
# Compute the seeds via connected components on the seed inputs.
242282
parallel.label(
243-
data=ThresholdWrapper(center_distances, threshold=0.4, operator=np.less),
244-
out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
283+
data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
245284
)
246285

247-
# Run the watershed.
248-
if original_shape is None:
249-
seg_path = os.path.join(output_folder, "segmentation.zarr")
250-
else:
251-
seg_path = os.path.join(output_folder, "seg_downscaled.zarr")
252-
286+
# Allocate the zarr array for the segmentation.
287+
seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
253288
seg_file = open_file(seg_path, "a")
254289
seg = seg_file.create_dataset(
255290
"segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
256291
)
257292

258-
hmap = SelectChannel(input_, 2)
293+
# Compute the segmentation with a seeded watershed
259294
halo = (2, 8, 8)
260295
parallel.seeded_watershed(
261-
hmap, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True,
296+
boundary_distances, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True,
262297
n_threads=n_threads,
263298
)
264299

300+
# Apply size filter.
265301
if min_size > 0:
266302
parallel.size_filter(
267303
seg, seg, min_size=min_size, block_shape=block_shape, mask=mask,
268304
verbose=True, n_threads=n_threads, relabel=True,
269305
)
270306

307+
# Reshape to original shape if given.
271308
if original_shape is not None:
272309
out_path = os.path.join(output_folder, "segmentation.zarr")
273310

274-
# This logic should be refactored.
275311
output_seg = ResizedVolume(seg, shape=original_shape, order=0)
276312
with open_file(out_path, "a") as f:
277313
out_seg_volume = f.create_dataset(
@@ -350,7 +386,7 @@ def run_unet_prediction(
350386
)
351387

352388
pmap_out = os.path.join(output_folder, "predictions.zarr")
353-
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
389+
distance_watershed_implementation(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
354390

355391

356392
#
@@ -467,4 +503,4 @@ def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None:
467503
"""
468504
min_size = int(min_size)
469505
pmap_out = os.path.join(output_folder, "predictions.zarr")
470-
segmentation_impl(pmap_out, output_folder, min_size=min_size)
506+
distance_watershed_implementation(pmap_out, output_folder, min_size=min_size)

0 commit comments

Comments
 (0)