1818import torch
1919import z5py
2020
21- from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper , MultiTransformationWrapper
21+ from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper
22+ from elf .wrapper .base import MultiTransformationWrapper
2223from elf .wrapper .resized_volume import ResizedVolume
2324from elf .io import open_file
2425from torch_em .util import load_model
@@ -233,17 +234,18 @@ def distance_watershed_implementation(
233234 distance_smoothing : float = 1.6 ,
234235 original_shape : Optional [Tuple [int , int , int ]] = None ,
235236) -> None :
236- """
237+ """Parallel implementation of the distance-prediction based watershed.
237238
238239 Args:
239- input_path:
240- output_folder:
241- min_size:
242- center_distance_threshold:
243- boundary_distance_threshold:
244- fg_threshold:
245- distance_smoothing:
246- original_shape:
240+ input_path: The path to the zarr file with the network predictions.
241+ output_folder: The folder for storing the segmentation and intermediate results.
242+ min_size: The minimal size of objects in the segmentation.
243+ center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
244+ boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
245+ By default this is set to 'None', in which case the boundary distances are not used for the seeds.
246+ fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
247+ distance_smoothing: The smoothing factor applied to the distance predictions.
248+ original_shape: The original shape to resize the segmentation to.
247249 """
248250 input_ = open_file (input_path , "r" )["prediction" ]
249251
@@ -361,6 +363,10 @@ def run_unet_prediction(
361363 block_shape : Optional [Tuple [int , int , int ]] = None ,
362364 halo : Optional [Tuple [int , int , int ]] = None ,
363365 use_mask : bool = True ,
366+ center_distance_threshold : float = 0.4 ,
367+ boundary_distance_threshold : Optional [float ] = None ,
368+ fg_threshold : float = 0.5 ,
369+ distance_smoothing : float = 1.6 ,
364370) -> None :
365371 """Run prediction and segmentation with a distance U-Net.
366372
@@ -375,6 +381,11 @@ def run_unet_prediction(
375381 block_shape: The block-shape for running the prediction.
376382 halo: The halo (= block overlap) to use for prediction.
377383 use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
384+ center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
385+ boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
386+ By default this is set to 'None', in which case the boundary distances are not used for the seeds.
387+ fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
388+ distance_smoothing: The smoothing factor applied to the distance predictions.
378389 """
379390 os .makedirs (output_folder , exist_ok = True )
380391
@@ -386,7 +397,11 @@ def run_unet_prediction(
386397 )
387398
388399 pmap_out = os .path .join (output_folder , "predictions.zarr" )
389- distance_watershed_implementation (pmap_out , output_folder , min_size = min_size , original_shape = original_shape )
400+ distance_watershed_implementation (
401+ pmap_out , output_folder , min_size = min_size , original_shape = original_shape ,
402+ center_distance_threshold = center_distance_threshold , boundary_distance_threshold = boundary_distance_threshold ,
403+ fg_threshold = fg_threshold , distance_smoothing = distance_smoothing ,
404+ )
390405
391406
392407#
0 commit comments