Merged
83 changes: 63 additions & 20 deletions flamingo_tools/segmentation/unet_prediction.py
@@ -18,6 +18,7 @@
import z5py

from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
from elf.wrapper.base import MultiTransformationWrapper
from elf.wrapper.resized_volume import ResizedVolume
from elf.io import open_file
from torch_em.util import load_model
@@ -217,61 +218,91 @@ def find_mask_block(block_id):
list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks))


def segmentation_impl(input_path, output_folder, min_size, original_shape=None):
"""@private
def distance_watershed_implementation(
input_path: str,
output_folder: str,
min_size: int,
center_distance_threshold: float = 0.4,
boundary_distance_threshold: Optional[float] = None,
fg_threshold: float = 0.5,
original_shape: Optional[Tuple[int, int, int]] = None,
) -> None:
"""Parallel implementation of the distance-prediction based watershed.

Args:
input_path: The path to the zarr file with the network predictions.
output_folder: The folder for storing the segmentation and intermediate results.
min_size: The minimal size of objects in the segmentation.
center_distance_threshold: The threshold applied to the center distance predictions to derive seeds.
boundary_distance_threshold: The threshold applied to the boundary distance predictions to derive seeds.
By default, this is set to 'None', in which case the boundary distances are not used for the seeds.
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
original_shape: The original shape to resize the segmentation to.
"""
input_ = open_file(input_path, "r")["prediction"]

# Limit the number of cores for parallelization.
n_threads = min(16, mp.cpu_count())

# The center distances as input for computing the seeds.
# Get the foreground mask.
mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold)

# Get the center and boundary distances.
center_distances = SelectChannel(input_, 1)
block_shape = center_distances.chunks
boundary_distances = SelectChannel(input_, 2)

# Compute the seeds based on smoothed center distances < 0.5.
# Ideally, (lazy) smoothing would be applied to both distance channels here.
# NOTE: this leads to issues with the parallelization, so distance smoothing is not implemented for now.
# smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing)
# center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
# boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)

# Allocate a zarr array for the seeds.
block_shape = center_distances.chunks
seed_path = os.path.join(output_folder, "seeds.zarr")
seed_file = open_file(seed_path, "a")
seeds = seed_file.require_dataset(
"seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64"
)

fg_threshold = 0.5
mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold)
# Compute the seed inputs:
# First, threshold the center distances.
seed_inputs = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less)
# Then, if a boundary distance threshold was passed, threshold the boundary distances and combine both.
if boundary_distance_threshold is not None:
seed_inputs2 = ThresholdWrapper(boundary_distances, threshold=boundary_distance_threshold, operator=np.less)
seed_inputs = MultiTransformationWrapper(np.logical_and, seed_inputs, seed_inputs2)

# Compute the seeds via connected components on the seed inputs.
parallel.label(
data=ThresholdWrapper(center_distances, threshold=0.4, operator=np.less),
out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads
)

# Run the watershed.
if original_shape is None:
seg_path = os.path.join(output_folder, "segmentation.zarr")
else:
seg_path = os.path.join(output_folder, "seg_downscaled.zarr")

# Allocate the zarr array for the segmentation.
seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr")
seg_file = open_file(seg_path, "a")
seg = seg_file.create_dataset(
"segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64"
)

hmap = SelectChannel(input_, 2)
# Compute the segmentation with a seeded watershed.
halo = (2, 8, 8)
parallel.seeded_watershed(
hmap, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True,
boundary_distances, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True,
n_threads=n_threads,
)

# Apply size filter.
if min_size > 0:
parallel.size_filter(
seg, seg, min_size=min_size, block_shape=block_shape, mask=mask,
verbose=True, n_threads=n_threads, relabel=True,
)

# Resize to the original shape if given.
if original_shape is not None:
out_path = os.path.join(output_folder, "segmentation.zarr")

# This logic should be refactored.
output_seg = ResizedVolume(seg, shape=original_shape, order=0)
with open_file(out_path, "a") as f:
out_seg_volume = f.create_dataset(
Expand Down Expand Up @@ -325,6 +356,9 @@ def run_unet_prediction(
block_shape: Optional[Tuple[int, int, int]] = None,
halo: Optional[Tuple[int, int, int]] = None,
use_mask: bool = True,
center_distance_threshold: float = 0.4,
boundary_distance_threshold: Optional[float] = None,
fg_threshold: float = 0.5,
) -> None:
"""Run prediction and segmentation with a distance U-Net.

@@ -339,6 +373,10 @@ def run_unet_prediction(
block_shape: The block-shape for running the prediction.
halo: The halo (= block overlap) to use for prediction.
use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
center_distance_threshold: The threshold applied to the center distance predictions to derive seeds.
boundary_distance_threshold: The threshold applied to the boundary distance predictions to derive seeds.
By default, this is set to 'None', in which case the boundary distances are not used for the seeds.
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
"""
os.makedirs(output_folder, exist_ok=True)

@@ -350,7 +388,12 @@
)

pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
distance_watershed_implementation(
pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
center_distance_threshold=center_distance_threshold,
boundary_distance_threshold=boundary_distance_threshold,
fg_threshold=fg_threshold,
)


#
@@ -467,4 +510,4 @@ def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None:
"""
min_size = int(min_size)
pmap_out = os.path.join(output_folder, "predictions.zarr")
segmentation_impl(pmap_out, output_folder, min_size=min_size)
distance_watershed_implementation(pmap_out, output_folder, min_size=min_size)
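
For orientation, the lazy wrappers used above (`SelectChannel`, `ThresholdWrapper`, `MultiTransformationWrapper`) compute block-wise what the following eager sketch does on a small in-memory prediction array. This is a minimal illustration under assumptions, not the repository's implementation: the array `pred`, the helper name `eager_distance_watershed`, and the use of scikit-image's `label` and `watershed` are placeholders for the example; the size filter and the resizing to the original shape are omitted.

import numpy as np
from skimage.measure import label
from skimage.segmentation import watershed


def eager_distance_watershed(pred, center_distance_threshold=0.4, boundary_distance_threshold=None, fg_threshold=0.5):
    # pred is assumed to have channels: 0 = foreground, 1 = center distances, 2 = boundary distances.
    foreground, center_distances, boundary_distances = pred[0], pred[1], pred[2]

    # Foreground mask from the foreground channel.
    mask = foreground > fg_threshold

    # Seed inputs: thresholded center distances, optionally combined with
    # thresholded boundary distances via a logical and.
    seed_inputs = center_distances < center_distance_threshold
    if boundary_distance_threshold is not None:
        seed_inputs = np.logical_and(seed_inputs, boundary_distances < boundary_distance_threshold)

    # Seeds via connected components, restricted to the foreground mask.
    seeds = label(np.logical_and(seed_inputs, mask))

    # Seeded watershed on the boundary distances, restricted to the same mask.
    return watershed(boundary_distances, markers=seeds, mask=mask)
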
8 changes: 7 additions & 1 deletion test/test_segmentation/test_unet_prediction.py
@@ -31,7 +31,7 @@ def _create_data(self, tmp_dir, use_tif):
f.create_dataset(key, data=data, chunks=(32, 32, 32))
return path, key

def _test_run_unet_prediction(self, use_tif, use_mask):
def _test_run_unet_prediction(self, use_tif, use_mask, **extra_kwargs):
from flamingo_tools.segmentation import run_unet_prediction

with tempfile.TemporaryDirectory() as tmp_dir:
@@ -42,6 +42,7 @@ def _test_run_unet_prediction(self, use_tif, use_mask):
input_path, input_key, output_folder, model_path,
scale=None, min_size=100,
block_shape=(64, 64, 64), halo=(16, 16, 16),
**extra_kwargs
)

expected_path = os.path.join(output_folder, "segmentation.zarr")
@@ -64,6 +65,11 @@ def test_run_unet_prediction_tif(self):
def test_run_unet_prediction_tif_mask(self):
self._test_run_unet_prediction(use_tif=True, use_mask=True)

def test_run_unet_prediction_complex_watershed(self):
self._test_run_unet_prediction(
use_tif=False, use_mask=True, center_distance_threshold=0.5, boundary_distance_threshold=0.5,
)


if __name__ == "__main__":
unittest.main()
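
As a usage note, the new keyword arguments are forwarded from `run_unet_prediction` to `distance_watershed_implementation`, so the stricter seeding can be enabled from the top-level entry point. A hypothetical call mirroring the new test, with placeholder paths and input key:

from flamingo_tools.segmentation import run_unet_prediction

run_unet_prediction(
    "/data/volume.n5", "raw",        # placeholder input path and key
    "/data/unet_output",             # placeholder output folder
    "/models/distance_unet.pt",      # placeholder model path
    scale=None, min_size=100,
    block_shape=(64, 64, 64), halo=(16, 16, 16),
    # Options added in this PR:
    center_distance_threshold=0.5,
    boundary_distance_threshold=0.5,
    fg_threshold=0.5,
)
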