1818import z5py
1919
2020from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper
21+ from elf .wrapper .base import MultiTransformationWrapper
2122from elf .wrapper .resized_volume import ResizedVolume
2223from elf .io import open_file
2324from torch_em .util import load_model
@@ -217,61 +218,91 @@ def find_mask_block(block_id):
217218 list (tqdm (tp .map (find_mask_block , range (n_blocks )), total = n_blocks ))
218219
219220
220- def segmentation_impl (input_path , output_folder , min_size , original_shape = None ):
221- """@private
221+ def distance_watershed_implementation (
222+ input_path : str ,
223+ output_folder : str ,
224+ min_size : int ,
225+ center_distance_threshold : float = 0.4 ,
226+ boundary_distance_threshold : Optional [float ] = None ,
227+ fg_threshold : float = 0.5 ,
228+ original_shape : Optional [Tuple [int , int , int ]] = None ,
229+ ) -> None :
230+ """Parallel implementation of the distance-prediction based watershed.
231+
232+ Args:
233+ input_path: The path to the zarr file with the network predictions.
234+ output_folder: The folder for storing the segmentation and intermediate results.
235+ min_size: The minimal size of objects in the segmentation.
236+ center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
237+ boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
238+ By default this is set to 'None', in which case the boundary distances are not used for the seeds.
239+ fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
240+ original_shape: The original shape to resize the segmentation to.
222241 """
223242 input_ = open_file (input_path , "r" )["prediction" ]
224243
225244 # Limit the number of cores for parallelization.
226245 n_threads = min (16 , mp .cpu_count ())
227246
228- # The center distances as input for computing the seeds.
247+ # Get the foreground mask.
248+ mask = ThresholdWrapper (SelectChannel (input_ , 0 ), threshold = fg_threshold )
249+
250+ # Get the the center and boundary distances.
229251 center_distances = SelectChannel (input_ , 1 )
230- block_shape = center_distances . chunks
252+ boundary_distances = SelectChannel ( input_ , 2 )
231253
232- # Compute the seeds based on smoothed center distances < 0.5.
254+ # Apply (lazy) smoothing to both.
255+ # NOTE: this leads to issues with the parallelization, so we don't implement distance smoothing for now.
256+ # smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing)
257+ # center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
258+ # boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
259+
260+ # Allocate an zarr array for the seeds.
261+ block_shape = center_distances .chunks
233262 seed_path = os .path .join (output_folder , "seeds.zarr" )
234263 seed_file = open_file (os .path .join (seed_path ), "a" )
235264 seeds = seed_file .require_dataset (
236265 "seeds" , shape = center_distances .shape , chunks = block_shape , compression = "gzip" , dtype = "uint64"
237266 )
238267
239- fg_threshold = 0.5
240- mask = ThresholdWrapper (SelectChannel (input_ , 0 ), threshold = fg_threshold )
268+ # Compute the seed inputs:
269+ # First, threshold the center distances.
270+ seed_inputs = ThresholdWrapper (center_distances , threshold = center_distance_threshold , operator = np .less )
271+ # Then, if a boundary distance threshold was passed threshold the boundary distances and combine both.
272+ if boundary_distance_threshold is not None :
273+ seed_inputs2 = ThresholdWrapper (boundary_distances , threshold = boundary_distance_threshold , operator = np .less )
274+ seed_inputs = MultiTransformationWrapper (np .logical_and , seed_inputs , seed_inputs2 )
241275
276+ # Compute the seeds via connected components on the seed inputs.
242277 parallel .label (
243- data = ThresholdWrapper (center_distances , threshold = 0.4 , operator = np .less ),
244- out = seeds , block_shape = block_shape , mask = mask , verbose = True , n_threads = n_threads
278+ data = seed_inputs , out = seeds , block_shape = block_shape , mask = mask , verbose = True , n_threads = n_threads
245279 )
246280
247- # Run the watershed.
248- if original_shape is None :
249- seg_path = os .path .join (output_folder , "segmentation.zarr" )
250- else :
251- seg_path = os .path .join (output_folder , "seg_downscaled.zarr" )
252-
281+ # Allocate the zarr array for the segmentation.
282+ seg_path = os .path .join (output_folder , "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr" )
253283 seg_file = open_file (seg_path , "a" )
254284 seg = seg_file .create_dataset (
255285 "segmentation" , shape = seeds .shape , chunks = block_shape , compression = "gzip" , dtype = "uint64"
256286 )
257287
258- hmap = SelectChannel ( input_ , 2 )
288+ # Compute the segmentation with a seeded watershed
259289 halo = (2 , 8 , 8 )
260290 parallel .seeded_watershed (
261- hmap , seeds , out = seg , block_shape = block_shape , halo = halo , mask = mask , verbose = True ,
291+ boundary_distances , seeds , out = seg , block_shape = block_shape , halo = halo , mask = mask , verbose = True ,
262292 n_threads = n_threads ,
263293 )
264294
295+ # Apply size filter.
265296 if min_size > 0 :
266297 parallel .size_filter (
267298 seg , seg , min_size = min_size , block_shape = block_shape , mask = mask ,
268299 verbose = True , n_threads = n_threads , relabel = True ,
269300 )
270301
302+ # Reshape to original shape if given.
271303 if original_shape is not None :
272304 out_path = os .path .join (output_folder , "segmentation.zarr" )
273305
274- # This logic should be refactored.
275306 output_seg = ResizedVolume (seg , shape = original_shape , order = 0 )
276307 with open_file (out_path , "a" ) as f :
277308 out_seg_volume = f .create_dataset (
@@ -325,6 +356,9 @@ def run_unet_prediction(
325356 block_shape : Optional [Tuple [int , int , int ]] = None ,
326357 halo : Optional [Tuple [int , int , int ]] = None ,
327358 use_mask : bool = True ,
359+ center_distance_threshold : float = 0.4 ,
360+ boundary_distance_threshold : Optional [float ] = None ,
361+ fg_threshold : float = 0.5 ,
328362) -> None :
329363 """Run prediction and segmentation with a distance U-Net.
330364
@@ -339,6 +373,10 @@ def run_unet_prediction(
339373 block_shape: The block-shape for running the prediction.
340374 halo: The halo (= block overlap) to use for prediction.
341375 use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
376+ center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
377+ boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
378+ By default this is set to 'None', in which case the boundary distances are not used for the seeds.
379+ fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
342380 """
343381 os .makedirs (output_folder , exist_ok = True )
344382
@@ -350,7 +388,12 @@ def run_unet_prediction(
350388 )
351389
352390 pmap_out = os .path .join (output_folder , "predictions.zarr" )
353- segmentation_impl (pmap_out , output_folder , min_size = min_size , original_shape = original_shape )
391+ distance_watershed_implementation (
392+ pmap_out , output_folder , min_size = min_size , original_shape = original_shape ,
393+ center_distance_threshold = center_distance_threshold ,
394+ boundary_distance_threshold = boundary_distance_threshold ,
395+ fg_threshold = fg_threshold ,
396+ )
354397
355398
356399#
@@ -467,4 +510,4 @@ def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None:
467510 """
468511 min_size = int (min_size )
469512 pmap_out = os .path .join (output_folder , "predictions.zarr" )
470- segmentation_impl (pmap_out , output_folder , min_size = min_size )
513+ distance_watershed_implementation (pmap_out , output_folder , min_size = min_size )
0 commit comments