@@ -281,7 +281,7 @@ def distance_watershed_implementation(
281281 input_path : str ,
282282 output_folder : Optional [str ] = None ,
283283 min_size : int = 1000 ,
284- center_distance_threshold : float = 0.4 ,
284+ center_distance_threshold : Optional [ float ] = 0.4 ,
285285 boundary_distance_threshold : Optional [float ] = None ,
286286 fg_threshold : float = 0.5 ,
287287 distance_smoothing : float = 0.0 ,
@@ -342,12 +342,16 @@ def distance_watershed_implementation(
342342 )
343343
344344 # Compute the seed inputs:
345- # First, threshold the center distances.
346- seed_inputs = ThresholdWrapper (center_distances , threshold = center_distance_threshold , operator = np .less )
347- # Then, if a boundary distance threshold was passed threshold the boundary distances and combine both.
348- if boundary_distance_threshold is not None :
345+ if boundary_distance_threshold is None and center_distance_threshold is None :
346+ raise ValueError ("Either boundary_distance_threshold, center_distance_threshold, or both have to be specifie." )
347+ elif boundary_distance_threshold is None :
348+ seed_inputs = ThresholdWrapper (center_distances , threshold = center_distance_threshold , operator = np .less )
349+ elif center_distance_threshold is None :
350+ seed_inputs = ThresholdWrapper (boundary_distances , threshold = boundary_distance_threshold , operator = np .less )
351+ else :
352+ seed_inputs1 = ThresholdWrapper (center_distances , threshold = center_distance_threshold , operator = np .less )
349353 seed_inputs2 = ThresholdWrapper (boundary_distances , threshold = boundary_distance_threshold , operator = np .less )
350- seed_inputs = MultiTransformationWrapper (np .logical_and , seed_inputs , seed_inputs2 )
354+ seed_inputs = MultiTransformationWrapper (np .logical_and , seed_inputs1 , seed_inputs2 )
351355
352356 # Compute the seeds via connected components on the seed inputs.
353357 parallel .label (
@@ -611,6 +615,7 @@ def run_unet_segmentation_slurm(
611615 boundary_distance_threshold : float = 0.5 ,
612616 fg_threshold : float = 0.5 ,
613617 distance_smoothing : float = 0.0 ,
618+ original_shape : Optional [Tuple [int , int , int ]] = None ,
614619) -> None :
615620 """Create segmentation from prediction.
616621
@@ -623,14 +628,16 @@ def run_unet_segmentation_slurm(
623628 fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
624629 distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
625630 This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
631+ original_shape: The original shape of the output, in case the prediction was resized.
626632 """
627633 min_size = int (min_size )
628- center_distance_threshold = float (center_distance_threshold )
634+ center_distance_threshold = None if center_distance_threshold is None else float (center_distance_threshold )
629635 boundary_distance_threshold = float (boundary_distance_threshold )
630636 distance_smoothing = float (distance_smoothing )
631637 pmap_out = os .path .join (output_folder , "predictions.zarr" )
632638 distance_watershed_implementation (pmap_out , output_folder , center_distance_threshold = center_distance_threshold ,
633639 boundary_distance_threshold = boundary_distance_threshold ,
634640 fg_threshold = fg_threshold ,
635641 distance_smoothing = distance_smoothing ,
636- min_size = min_size )
642+ min_size = min_size ,
643+ original_shape = original_shape )
0 commit comments