|
8 | 8 | import os |
9 | 9 | import warnings |
10 | 10 | from concurrent import futures |
| 11 | +from functools import partial |
11 | 12 | from typing import Optional, Tuple |
12 | 13 |
|
13 | 14 | import elf.parallel as parallel |
|
17 | 18 | import torch |
18 | 19 | import z5py |
19 | 20 |
|
20 | | -from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper |
| 21 | +from elf.wrapper import ThresholdWrapper, SimpleTransformationWrapper, MultiTransformationWrapper |
21 | 22 | from elf.wrapper.resized_volume import ResizedVolume |
22 | 23 | from elf.io import open_file |
23 | 24 | from torch_em.util import load_model |
|
27 | 28 | import flamingo_tools.s3_utils as s3_utils |
28 | 29 | from flamingo_tools.file_utils import read_image_data |
29 | 30 |
|
| 31 | +try: |
| 32 | + import fastfilters as ff |
| 33 | +except ImportError: |
| 34 | + import vigra.filters as ff |
| 35 | + |
30 | 36 |
|
31 | 37 | class SelectChannel(SimpleTransformationWrapper): |
32 | 38 | """Wrapper to select a channel from an array-like dataset object. |
@@ -217,61 +223,91 @@ def find_mask_block(block_id): |
217 | 223 | list(tqdm(tp.map(find_mask_block, range(n_blocks)), total=n_blocks)) |
218 | 224 |
|
219 | 225 |
|
220 | | -def segmentation_impl(input_path, output_folder, min_size, original_shape=None): |
221 | | - """@private |
| 226 | +def distance_watershed_implementation( |
| 227 | + input_path: str, |
| 228 | + output_folder: str, |
| 229 | + min_size: int, |
| 230 | + center_distance_threshold: float = 0.4, |
| 231 | + boundary_distance_threshold: Optional[float] = None, |
| 232 | + fg_threshold: float = 0.5, |
| 233 | + distance_smoothing: float = 1.6, |
| 234 | + original_shape: Optional[Tuple[int, int, int]] = None, |
| 235 | +) -> None: |
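| 236 | + """Segment the predictions with a seeded watershed, using seeds derived from the predicted distances. |
| 237 | +
|
| 238 | + Args: |
| 239 | + input_path: Path to the file with the "prediction" dataset. Channel 0 is used for the foreground mask, channel 1 as center distances and channel 2 as boundary distances. |
| 240 | + output_folder: Folder where the seeds ("seeds.zarr") and the segmentation ("segmentation.zarr") are written. |
| 241 | + min_size: Minimal size passed to the size filter. The size filter is only applied if this value is larger than 0. |
| 242 | + center_distance_threshold: Smoothed center distances below this threshold are used as seed inputs. |
| 243 | + boundary_distance_threshold: If given, the seed inputs are additionally restricted to smoothed boundary distances below this threshold. |
| 244 | + fg_threshold: Threshold applied to channel 0 of the predictions to obtain the foreground mask. |
| 245 | + distance_smoothing: Sigma of the gaussian smoothing applied to the center and boundary distances. |
| 246 | + original_shape: If given, the segmentation is first computed in "seg_downscaled.zarr" and then resized to this shape. |
222 | 247 | """ |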
223 | 248 | input_ = open_file(input_path, "r")["prediction"] |
224 | 249 |
|
225 | 250 | # Limit the number of cores for parallelization. |
226 | 251 | n_threads = min(16, mp.cpu_count()) |
227 | 252 |
|
228 | | - # The center distances as input for computing the seeds. |
| 253 | + # Get the foreground mask. |
| 254 | + mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold) |
| 255 | + |
| 256 | + # Get the center and boundary distances. |
229 | 257 | center_distances = SelectChannel(input_, 1) |
230 | | - block_shape = center_distances.chunks |
| 258 | + boundary_distances = SelectChannel(input_, 2) |
| 259 | + |
| 260 | + # Apply (lazy) smoothing to both. |
| 261 | + smoothing = partial(ff.gaussianSmoothing, sigma=distance_smoothing) |
| 262 | + center_distances = SimpleTransformationWrapper(center_distances, transformation=smoothing) |
| 263 | + boundary_distances = SimpleTransformationWrapper(boundary_distances, transformation=smoothing) |
231 | 264 |
|
232 | | - # Compute the seeds based on smoothed center distances < 0.5. |
| 265 | + # Allocate a zarr array for the seeds. |
| 266 | + block_shape = center_distances.chunks |
233 | 267 | seed_path = os.path.join(output_folder, "seeds.zarr") |
234 | 268 | seed_file = open_file(os.path.join(seed_path), "a") |
235 | 269 | seeds = seed_file.require_dataset( |
236 | 270 | "seeds", shape=center_distances.shape, chunks=block_shape, compression="gzip", dtype="uint64" |
237 | 271 | ) |
238 | 272 |
|
239 | | - fg_threshold = 0.5 |
240 | | - mask = ThresholdWrapper(SelectChannel(input_, 0), threshold=fg_threshold) |
| 273 | + # Compute the seed inputs: |
| 274 | + # First, threshold the center distances. |
| 275 | + seed_inputs = ThresholdWrapper(center_distances, threshold=center_distance_threshold, operator=np.less) |
| 276 | + # Then, if a boundary distance threshold was passed, threshold the boundary distances and combine both. |
| 277 | + if boundary_distance_threshold is not None: |
| 278 | + seed_inputs2 = ThresholdWrapper(boundary_distances, threshold=boundary_distance_threshold, operator=np.less) |
| 279 | + seed_inputs = MultiTransformationWrapper(np.logical_and, seed_inputs, seed_inputs2) |
241 | 280 |
|
| 281 | + # Compute the seeds via connected components on the seed inputs. |
242 | 282 | parallel.label( |
243 | | - data=ThresholdWrapper(center_distances, threshold=0.4, operator=np.less), |
244 | | - out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads |
| 283 | + data=seed_inputs, out=seeds, block_shape=block_shape, mask=mask, verbose=True, n_threads=n_threads |
245 | 284 | ) |
246 | 285 |
|
247 | | - # Run the watershed. |
248 | | - if original_shape is None: |
249 | | - seg_path = os.path.join(output_folder, "segmentation.zarr") |
250 | | - else: |
251 | | - seg_path = os.path.join(output_folder, "seg_downscaled.zarr") |
252 | | - |
| 286 | + # Allocate the zarr array for the segmentation. |
| 287 | + seg_path = os.path.join(output_folder, "segmentation.zarr" if original_shape is None else "seg_downscaled.zarr") |
253 | 288 | seg_file = open_file(seg_path, "a") |
254 | 289 | seg = seg_file.create_dataset( |
255 | 290 | "segmentation", shape=seeds.shape, chunks=block_shape, compression="gzip", dtype="uint64" |
256 | 291 | ) |
257 | 292 |
|
258 | | - hmap = SelectChannel(input_, 2) |
| 293 | + # Compute the segmentation with a seeded watershed |
259 | 294 | halo = (2, 8, 8) |
260 | 295 | parallel.seeded_watershed( |
261 | | - hmap, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True, |
| 296 | + boundary_distances, seeds, out=seg, block_shape=block_shape, halo=halo, mask=mask, verbose=True, |
262 | 297 | n_threads=n_threads, |
263 | 298 | ) |
264 | 299 |
|
| 300 | + # Apply size filter. |
265 | 301 | if min_size > 0: |
266 | 302 | parallel.size_filter( |
267 | 303 | seg, seg, min_size=min_size, block_shape=block_shape, mask=mask, |
268 | 304 | verbose=True, n_threads=n_threads, relabel=True, |
269 | 305 | ) |
270 | 306 |
|
| 307 | + # Reshape to original shape if given. |
271 | 308 | if original_shape is not None: |
272 | 309 | out_path = os.path.join(output_folder, "segmentation.zarr") |
273 | 310 |
|
274 | | - # This logic should be refactored. |
275 | 311 | output_seg = ResizedVolume(seg, shape=original_shape, order=0) |
276 | 312 | with open_file(out_path, "a") as f: |
277 | 313 | out_seg_volume = f.create_dataset( |
@@ -350,7 +386,7 @@ def run_unet_prediction( |
350 | 386 | ) |
351 | 387 |
|
352 | 388 | pmap_out = os.path.join(output_folder, "predictions.zarr") |
353 | | - segmentation_impl(pmap_out, output_folder, min_size=min_size, original_shape=original_shape) |
| 389 | + distance_watershed_implementation(pmap_out, output_folder, min_size=min_size, original_shape=original_shape) |
354 | 390 |
|
355 | 391 |
|
356 | 392 | # |
@@ -467,4 +503,4 @@ def run_unet_segmentation_slurm(output_folder: str, min_size: int) -> None: |
467 | 503 | """ |
468 | 504 | min_size = int(min_size) |
469 | 505 | pmap_out = os.path.join(output_folder, "predictions.zarr") |
470 | | - segmentation_impl(pmap_out, output_folder, min_size=min_size) |
| 506 | + distance_watershed_implementation(pmap_out, output_folder, min_size=min_size) |
0 commit comments