diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 8b754d3..6759cc2 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -18,6 +18,7 @@ import z5py from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper +from elf.wrapper.base import MultiTransformationWrapper from elf.wrapper.resized_volume import ResizedVolume from elf.io import open_file from torch_em.util import load_model @@ -217,61 +218,91 @@ def find_mask_block(block_id): list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks)) -def segmentation_impl(input_path, output_folder, min_size, original_shape=None): - """@private +def distance_watershed_implementation( + input_path: str, + output_folder: str, + min_size: int, + center_distance_threshold: float = 0.4, + boundary_distance_threshold: Optional[float] = None, + fg_threshold: float = 0.5, + original_shape: Optional[Tuple[int, int, int]] = None, +) -> None: + """Parallel implementation of the distance-prediction based watershed. + + Args: + input_path: The path to the zarr file with the network predictions. + output_folder: The folder for storing the segmentation and intermediate results. + min_size: The minimal size of objects in the segmentation. + center_distance_threshold: The threshold applied to the distance center predictions to derive seeds. + boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds. + By default this is set to 'None', in which case the boundary distances are not used for the seeds. + fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask. + original_shape: The original shape to resize the segmentation to. """ input_ = open_file(input_path, "r")["prediction"] # Limit the number of cores for parallelization. n_threads = min(16, mp.cpu_count()) - # The center distances as input for computing the seeds. + # Get the foreground mask. + mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold) + + # Get the the center and boundary distances. center_distances = SelectChannel(input_, 1) - block_shape = center_distances.chunks + boundary_distances = SelectChannel(input_, 2) - # Compute the seeds based on smoothed center distances < 0.5. + # Apply (lazy) smoothing to both. + # NOTE: this leads to issues with the parallelization, so we don't implement distance smoothing for now. + # smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing) + # center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing) + # boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing) + + # Allocate an zarr array for the seeds. + block_shape = center_distances.chunks seed_path = os.path.join(output_folder, "seeds.zarr") seed_file = open_file(os.path.join(seed_path), "a") seeds = seed_file.require_dataset( "seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64" ) - fg_threshold = 0.5 - mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold) + # Compute the seed inputs: + # First, threshold the center distances. + seed_inputs = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less) + # Then, if a boundary distance threshold was passed threshold the boundary distances and combine both. + if boundary_distance_threshold is not None: + seed_inputs2 = ThresholdWrapper(boundary_distances, threshold=boundary_distance_threshold, operator=np.less) + seed_inputs = MultiTransformationWrapper(np.logical_and, seed_inputs, seed_inputs2) + # Compute the seeds via connected components on the seed inputs. parallel.label( - data=ThresholdWrapper(center_distances, threshold=0.4, operator=np.less), - out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads + data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads ) - # Run the watershed. - if original_shape is None: - seg_path = os.path.join(output_folder, "segmentation.zarr") - else: - seg_path = os.path.join(output_folder, "seg_downscaled.zarr") - + # Allocate the zarr array for the segmentation. + seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr") seg_file = open_file(seg_path, "a") seg = seg_file.create_dataset( "segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64" ) - hmap = SelectChannel(input_, 2) + # Compute the segmentation with a seeded watershed halo = (2, 8, 8) parallel.seeded_watershed( - hmap, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True, + boundary_distances, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True, n_threads=n_threads, ) + # Apply size filter. if min_size > 0: parallel.size_filter( seg, seg, min_size=min_size, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads, relabel=True, ) + # Reshape to original shape if given. if original_shape is not None: out_path = os.path.join(output_folder, "segmentation.zarr") - # This logic should be refactored. output_seg = ResizedVolume(seg, shape=original_shape, order=0) with open_file(out_path, "a") as f: out_seg_volume = f.create_dataset( @@ -325,6 +356,9 @@ def run_unet_prediction( block_shape: Optional[Tuple[int, int, int]] = None, halo: Optional[Tuple[int, int, int]] = None, use_mask: bool = True, + center_distance_threshold: float = 0.4, + boundary_distance_threshold: Optional[float] = None, + fg_threshold: float = 0.5, ) -> None: """Run prediction and segmentation with a distance U-Net. @@ -339,6 +373,10 @@ def run_unet_prediction( block_shape: The block-shape for running the prediction. halo: The halo (= block overlap) to use for prediction. use_mask: Whether to use the masking heuristics to not run inference on empty blocks. + center_distance_threshold: The threshold applied to the distance center predictions to derive seeds. + boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds. + By default this is set to 'None', in which case the boundary distances are not used for the seeds. + fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask. """ os.makedirs(output_folder, exist_ok=True) @@ -350,7 +388,12 @@ def run_unet_prediction( ) pmap_out = os.path.join(output_folder, "predictions.zarr") - segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape) + distance_watershed_implementation( + pmap_out, output_folder, min_size=min_size, original_shape=original_shape, + center_distance_threshold=center_distance_threshold, + boundary_distance_threshold=boundary_distance_threshold, + fg_threshold=fg_threshold, + ) # @@ -467,4 +510,4 @@ def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None: """ min_size = int(min_size) pmap_out = os.path.join(output_folder, "predictions.zarr") - segmentation_impl(pmap_out, output_folder, min_size=min_size) + distance_watershed_implementation(pmap_out, output_folder, min_size=min_size) diff --git a/test/test_segmentation/test_unet_prediction.py b/test/test_segmentation/test_unet_prediction.py index 9038bcc..63495f3 100644 --- a/test/test_segmentation/test_unet_prediction.py +++ b/test/test_segmentation/test_unet_prediction.py @@ -31,7 +31,7 @@ def _create_data(self, tmp_dir, use_tif): f.create_dataset(key, data=data, chunks=(32, 32, 32)) return path, key - def _test_run_unet_prediction(self, use_tif, use_mask): + def _test_run_unet_prediction(self, use_tif, use_mask, **extra_kwargs): from flamingo_tools.segmentation import run_unet_prediction with tempfile.TemporaryDirectory() as tmp_dir: @@ -42,6 +42,7 @@ def _test_run_unet_prediction(self, use_tif, use_mask): input_path, input_key, output_folder, model_path, scale=None, min_size=100, block_shape=(64, 64, 64), halo=(16, 16, 16), + **extra_kwargs ) expected_path = os.path.join(output_folder, "segmentation.zarr") @@ -64,6 +65,11 @@ def test_run_unet_prediction_tif(self): def test_run_unet_prediction_tif_mask(self): self._test_run_unet_prediction(use_tif=True, use_mask=True) + def test_run_unet_prediction_complex_watershed(self): + self._test_run_unet_prediction( + use_tif=False, use_mask=True, center_distance_threshold=0.5, boundary_distance_threshold=0.5, + ) + if __name__ == "__main__": unittest.main()