Skip to content

Commit 4b41cc7

Browse files
Add doc strings and update test
1 parent 9022292 commit 4b41cc7

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import torch
1919
import z5py
2020

21-
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper, MultiTransformationWrapper
21+
from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper
22+
from elf.wrapper.base import MultiTransformationWrapper
2223
from elf.wrapper.resized_volume import ResizedVolume
2324
from elf.io import open_file
2425
from torch_em.util import load_model
@@ -233,17 +234,18 @@ def distance_watershed_implementation(
233234
distance_smoothing: float = 1.6,
234235
original_shape: Optional[Tuple[int, int, int]] = None,
235236
) -> None:
236-
"""
237+
"""Parallel implementation of the distance-prediction based watershed.
237238
238239
Args:
239-
input_path:
240-
output_folder:
241-
min_size:
242-
center_distance_threshold:
243-
boundary_distance_threshold:
244-
fg_threshold:
245-
distance_smoothing:
246-
original_shape:
240+
input_path: The path to the zarr file with the network predictions.
241+
output_folder: The folder for storing the segmentation and intermediate results.
242+
min_size: The minimal size of objects in the segmentation.
243+
center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
244+
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
245+
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
246+
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
247+
distance_smoothing: The smoothing factor applied to the distance predictions.
248+
original_shape: The original shape to resize the segmentation to.
247249
"""
248250
input_ = open_file(input_path, "r")["prediction"]
249251

@@ -361,6 +363,10 @@ def run_unet_prediction(
361363
block_shape: Optional[Tuple[int, int, int]] = None,
362364
halo: Optional[Tuple[int, int, int]] = None,
363365
use_mask: bool = True,
366+
center_distance_threshold: float = 0.4,
367+
boundary_distance_threshold: Optional[float] = None,
368+
fg_threshold: float = 0.5,
369+
distance_smoothing: float = 1.6,
364370
) -> None:
365371
"""Run prediction and segmentation with a distance U-Net.
366372
@@ -375,6 +381,11 @@ def run_unet_prediction(
375381
block_shape: The block-shape for running the prediction.
376382
halo: The halo (= block overlap) to use for prediction.
377383
use_mask: Whether to use the masking heuristics to not run inference on empty blocks.
384+
center_distance_threshold: The threshold applied to the distance center predictions to derive seeds.
385+
boundary_distance_threshold: The threshold applied to the boundary predictions to derive seeds.
386+
By default this is set to 'None', in which case the boundary distances are not used for the seeds.
387+
fg_threshold: The threshold applied to the foreground prediction for deriving the watershed mask.
388+
distance_smoothing: The smoothing factor applied to the distance predictions.
378389
"""
379390
os.makedirs(output_folder, exist_ok=True)
380391

@@ -386,7 +397,11 @@ def run_unet_prediction(
386397
)
387398

388399
pmap_out = os.path.join(output_folder, "predictions.zarr")
389-
distance_watershed_implementation(pmap_out, output_folder, min_size=min_size, original_shape=original_shape)
400+
distance_watershed_implementation(
401+
pmap_out, output_folder, min_size=min_size, original_shape=original_shape,
402+
center_distance_threshold=center_distance_threshold, boundary_distance_threshold=boundary_distance_threshold,
403+
fg_threshold=fg_threshold, distance_smoothing=distance_smoothing,
404+
)
390405

391406

392407
#

test/test_segmentation/test_unet_prediction.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def _create_data(self, tmp_dir, use_tif):
3131
f.create_dataset(key, data=data, chunks=(32, 32, 32))
3232
return path, key
3333

34-
def _test_run_unet_prediction(self, use_tif, use_mask):
34+
def _test_run_unet_prediction(self, use_tif, use_mask, **extra_kwargs):
3535
from flamingo_tools.segmentation import run_unet_prediction
3636

3737
with tempfile.TemporaryDirectory() as tmp_dir:
@@ -42,6 +42,7 @@ def _test_run_unet_prediction(self, use_tif, use_mask):
4242
input_path, input_key, output_folder, model_path,
4343
scale=None, min_size=100,
4444
block_shape=(64, 64, 64), halo=(16, 16, 16),
45+
**extra_kwargs
4546
)
4647

4748
expected_path = os.path.join(output_folder, "segmentation.zarr")
@@ -64,6 +65,11 @@ def test_run_unet_prediction_tif(self):
6465
def test_run_unet_prediction_tif_mask(self):
6566
self._test_run_unet_prediction(use_tif=True, use_mask=True)
6667

68+
def test_run_unet_prediction_complex_watershed(self):
69+
self._test_run_unet_prediction(
70+
use_tif=False, use_mask=True, center_distance_threshold=0.5, boundary_distance_threshold=0.5,
71+
)
72+
6773

6874
if __name__ == "__main__":
6975
unittest.main()

0 commit comments

Comments
 (0)