Skip to content

Commit 7f3b265

Browse files
Update unet prediction code
1 parent 3ec31ad commit 7f3b265

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def distance_watershed_implementation(
281281
input_path: str,
282282
output_folder: Optional[str] = None,
283283
min_size: int = 1000,
284-
center_distance_threshold: float = 0.4,
284+
center_distance_threshold: Optional[float] = 0.4,
285285
boundary_distance_threshold: Optional[float] = None,
286286
fg_threshold: float = 0.5,
287287
distance_smoothing: float = 0.0,
@@ -342,12 +342,16 @@ def distance_watershed_implementation(
342342
)
343343

344344
# Compute the seed inputs:
345-
# First, threshold the center distances.
346-
seed_inputs = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less)
347-
# Then, if a boundary distance threshold was passed threshold the boundary distances and combine both.
348-
if boundary_distance_threshold is not None:
345+
if boundary_distance_threshold is None and center_distance_threshold is None:
346+
raise ValueError("Either boundary_distance_threshold, center_distance_threshold, or both have to be specifie.")
347+
elif boundary_distance_threshold is None:
348+
seed_inputs = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less)
349+
elif center_distance_threshold is None:
350+
seed_inputs = ThresholdWrapper(boundary_distances, threshold=boundary_distance_threshold, operator=np.less)
351+
else:
352+
seed_inputs1 = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less)
349353
seed_inputs2 = ThresholdWrapper(boundary_distances, threshold=boundary_distance_threshold, operator=np.less)
350-
seed_inputs = MultiTransformationWrapper(np.logical_and, seed_inputs, seed_inputs2)
354+
seed_inputs = MultiTransformationWrapper(np.logical_and, seed_inputs1, seed_inputs2)
351355

352356
# Compute the seeds via connected components on the seed inputs.
353357
parallel.label(
@@ -611,6 +615,7 @@ def run_unet_segmentation_slurm(
611615
boundary_distance_threshold: float = 0.5,
612616
fg_threshold: float = 0.5,
613617
distance_smoothing: float = 0.0,
618+
original_shape: Optional[Tuple[int, int, int]] = None,
614619
) -> None:
615620
"""Create segmentation from prediction.
616621
@@ -623,14 +628,16 @@ def run_unet_segmentation_slurm(
623628
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
624629
distance_smoothing: The sigma value for smoothing the distance predictions with a gaussian kernel.
625630
This may help to reduce border artifacts. If set to 0 (the default) smoothing is not applied.
631+
original_shape: The original shape of the output, in case the prediction was resized.
626632
"""
627633
min_size = int(min_size)
628-
center_distance_threshold = float(center_distance_threshold)
634+
center_distance_threshold = None if center_distance_threshold is None else float(center_distance_threshold)
629635
boundary_distance_threshold = float(boundary_distance_threshold)
630636
distance_smoothing = float(distance_smoothing)
631637
pmap_out = os.path.join(output_folder, "predictions.zarr")
632638
distance_watershed_implementation(pmap_out, output_folder, center_distance_threshold=center_distance_threshold,
633639
boundary_distance_threshold=boundary_distance_threshold,
634640
fg_threshold=fg_threshold,
635641
distance_smoothing=distance_smoothing,
636-
min_size=min_size)
642+
min_size=min_size,
643+
original_shape=original_shape)

0 commit comments

Comments
 (0)