88import os
99import warnings
1010from concurrent import futures
11+ from functools import partial
1112from typing import Optional , Tuple
1213
1314import elf .parallel as parallel
1718import torch
1819import z5py
1920
20- from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper
21+ from elf .wrapper import ThresholdWrapper , SimpleTransformationWrapper , SimpleTransformationWrapperWithHalo
2122from elf .wrapper .base import MultiTransformationWrapper
2223from elf .wrapper .resized_volume import ResizedVolume
2324from elf .io import open_file
25+ from skimage .filters import gaussian
2426from torch_em .util import load_model
2527from torch_em .util .prediction import predict_with_halo
2628from tqdm import tqdm
@@ -278,6 +280,7 @@ def distance_watershed_implementation(
278280 center_distance_threshold : float = 0.4 ,
279281 boundary_distance_threshold : Optional [float ] = None ,
280282 fg_threshold : float = 0.5 ,
283+ distance_smoothing : float = 0.0 ,
281284 original_shape : Optional [Tuple [int , int , int ]] = None
282285) -> None :
283286 """Parallel implementation of the distance-prediction based watershed.
@@ -290,6 +293,8 @@ def distance_watershed_implementation(
290293 boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
291294 By default this is set to 'None', in which case the boundary distances are not used for the seeds.
292295 fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
296+ distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
297+ This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
293298 original_shape: The original shape to resize the segmentation to.
294299 """
295300 if isinstance (input_path , str ):
@@ -307,11 +312,14 @@ def distance_watershed_implementation(
307312 center_distances = SelectChannel (input_ , 1 )
308313 boundary_distances = SelectChannel (input_ , 2 )
309314
310- # Apply (lazy) smoothing to both.
311- # NOTE: this leads to issues with the parallelization, so we don't implement distance smoothing for now.
312- # smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing)
313- # center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing)
314- # boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing)
315+ # Apply (lazy) smoothing to both channels if distance smoothing was set.
316+ if distance_smoothing > 0 :
317+ smooth = partial (gaussian , sigma = distance_smoothing )
318+ # We assume that the gaussian is truncated at 5.3 sigma (tolerance of 1e-6)
319+ halo = int (np .ceil (5.3 * distance_smoothing ))
320+ halo = 3 * (halo ,)
321+ center_distances = SimpleTransformationWrapperWithHalo (center_distances , transformation = smooth , halo = halo )
322+ boundary_distances = SimpleTransformationWrapperWithHalo (boundary_distances , transformation = smooth , halo = halo )
315323
316324 # Allocate the (zarr) array for the seeds.
317325 if output_folder is None :
@@ -427,6 +435,7 @@ def run_unet_prediction(
427435 center_distance_threshold : float = 0.4 ,
428436 boundary_distance_threshold : Optional [float ] = None ,
429437 fg_threshold : float = 0.5 ,
438+ distance_smoothing : float = 0.0 ,
430439 seg_class : Optional [str ] = None ,
431440) -> None :
432441 """Run prediction and segmentation with a distance U-Net.
@@ -446,6 +455,8 @@ def run_unet_prediction(
446455 boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
447456 By default this is set to 'None', in which case the boundary distances are not used for the seeds.
448457 fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
458+ distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
459+ This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
449460 seg_class: Specifier for exclusion criterias for mask generation.
450461 """
451462 if output_folder is not None :
@@ -470,7 +481,8 @@ def run_unet_prediction(
470481 pmap_out , output_folder , min_size = min_size , original_shape = original_shape ,
471482 center_distance_threshold = center_distance_threshold ,
472483 boundary_distance_threshold = boundary_distance_threshold ,
473- fg_threshold = fg_threshold
484+ fg_threshold = fg_threshold ,
485+ distance_smoothing = distance_smoothing ,
474486 )
475487
476488 return segmentation
@@ -590,6 +602,7 @@ def run_unet_segmentation_slurm(
590602 center_distance_threshold : float = 0.4 ,
591603 boundary_distance_threshold : float = 0.5 ,
592604 fg_threshold : float = 0.5 ,
605+ distance_smoothing : float = 0.0 ,
593606) -> None :
594607 """Create segmentation from prediction.
595608
@@ -600,10 +613,13 @@ def run_unet_segmentation_slurm(
600613 boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
601614 By default this is set to 'None', in which case the boundary distances are not used for the seeds.
602615 fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
616+ distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
617+ This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
603618 """
604619 min_size = int (min_size )
605620 pmap_out = os .path .join (output_folder , "predictions.zarr" )
606621 distance_watershed_implementation (pmap_out , output_folder , center_distance_threshold = center_distance_threshold ,
607622 boundary_distance_threshold = boundary_distance_threshold ,
608623 fg_threshold = fg_threshold ,
624+ distance_smoothing = distance_smoothing ,
609625 min_size = min_size )
0 commit comments