| 
1 |  | -import multiprocessing as mp  | 
 | 1 | +import multiprocessing  | 
 | 2 | +import os  | 
2 | 3 | import threading  | 
3 | 4 | from concurrent import futures  | 
4 |  | -import os  | 
5 |  | -from typing import Optional, Tuple  | 
 | 5 | +from threadpoolctl import threadpool_limits  | 
 | 6 | +from typing import Optional, Tuple, Union  | 
6 | 7 | 
 
  | 
7 | 8 | import numpy as np  | 
 | 9 | +from numpy.typing import ArrayLike  | 
8 | 10 | import pandas as pd  | 
 | 11 | +from scipy.ndimage import distance_transform_edt  | 
 | 12 | +from skimage.segmentation import watershed  | 
9 | 13 | import zarr  | 
10 | 14 | 
 
  | 
11 | 15 | from elf.io import open_file  | 
12 | 16 | from elf.parallel.local_maxima import find_local_maxima  | 
13 | 17 | from flamingo_tools.segmentation.unet_prediction import prediction_impl  | 
14 | 18 | from tqdm import tqdm  | 
15 | 19 | 
 
  | 
 | 20 | +from elf.parallel.common import get_blocking  | 
 | 21 | + | 
 | 22 | + | 
 | 23 | +def distance_based_marker_extension(  | 
 | 24 | +    markers: np.ndarray,  | 
 | 25 | +    output: ArrayLike,  | 
 | 26 | +    extension_distance: float,  | 
 | 27 | +    sampling: Union[float, Tuple[float, ...]],  | 
 | 28 | +    block_shape: Tuple[int, ...],  | 
 | 29 | +    n_threads: Optional[int] = None,  | 
 | 30 | +    verbose: bool = False,  | 
 | 31 | +    roi: Optional[Tuple[slice, ...]] = None,  | 
 | 32 | +):  | 
 | 33 | +    """  | 
 | 34 | +    Extend SGN detection to emulate shape of SGNs for better visualization.  | 
 | 35 | +
  | 
 | 36 | +    Args:  | 
 | 37 | +        markers: Array of coordinates for seeding watershed.  | 
 | 38 | +        output: Output for watershed.  | 
 | 39 | +        extension_distance: Distance in micrometer for extension.  | 
 | 40 | +        sampling: Resolution in micrometer.  | 
 | 41 | +        block_shape:  | 
 | 42 | +        n_threads:  | 
 | 43 | +        verbose:  | 
 | 44 | +        roi:  | 
 | 45 | +    """  | 
 | 46 | +    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads  | 
 | 47 | +    blocking = get_blocking(output, block_shape, roi, n_threads)  | 
 | 48 | + | 
 | 49 | +    lock = threading.Lock()  | 
 | 50 | + | 
 | 51 | +    # determine the correct halo in pixels based on the sampling and the extension distance.  | 
 | 52 | +    halo = [round(extension_distance / s) + 2 for s in sampling]  | 
 | 53 | + | 
 | 54 | +    @threadpool_limits.wrap(limits=1)  # restrict the numpy threadpool to 1 to avoid oversubscription  | 
 | 55 | +    def extend_block(block_id):  | 
 | 56 | +        block = blocking.getBlockWithHalo(block_id, halo)  | 
 | 57 | +        outer_block = block.outerBlock  | 
 | 58 | +        inner_block = block.innerBlock  | 
 | 59 | + | 
 | 60 | +        # TODO get the indices and coordinates of the markers in the INNER block  | 
 | 61 | +        # markers_in_block_ids = [int(i) for i in np.unique(inner_block)[1:]]  | 
 | 62 | +        mask = (  | 
 | 63 | +            (inner_block.begin[0] <= markers[:, 0]) & (markers[:, 0] <= inner_block.end[0]) &  | 
 | 64 | +            (inner_block.begin[1] <= markers[:, 1]) & (markers[:, 1] <= inner_block.end[1]) &  | 
 | 65 | +            (inner_block.begin[2] <= markers[:, 2]) & (markers[:, 2] <= inner_block.end[2])  | 
 | 66 | +        )  | 
 | 67 | +        markers_in_block_ids = np.where(mask)[0]  | 
 | 68 | +        markers_in_block_coords = markers[markers_in_block_ids]  | 
 | 69 | + | 
 | 70 | +        # TODO offset the marker coordinates with respect to the OUTER block  | 
 | 71 | +        markers_in_block_coords = [coord - outer_block.begin for coord in markers_in_block_coords]  | 
 | 72 | +        markers_in_block_coords = [[round(c) for c in coord] for coord in markers_in_block_coords]  | 
 | 73 | +        markers_in_block_coords = np.array(markers_in_block_coords, dtype=int)  | 
 | 74 | +        z, y, x = markers_in_block_coords.T  | 
 | 75 | + | 
 | 76 | +        # Shift index by one so that zero is reserved for background id  | 
 | 77 | +        markers_in_block_ids += 1  | 
 | 78 | + | 
 | 79 | +        # Create the seed volume.  | 
 | 80 | +        outer_block_shape = tuple(end - begin for begin, end in zip(outer_block.begin, outer_block.end))  | 
 | 81 | +        seeds = np.zeros(outer_block_shape, dtype="uint32")  | 
 | 82 | +        seeds[z, y, x] = markers_in_block_ids  | 
 | 83 | + | 
 | 84 | +        # Compute the distance map.  | 
 | 85 | +        distance = distance_transform_edt(seeds == 0, sampling=sampling)  | 
 | 86 | + | 
 | 87 | +        # And extend the seeds  | 
 | 88 | +        mask = distance < extension_distance  | 
 | 89 | +        segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)  | 
 | 90 | + | 
 | 91 | +        # Write the segmentation. Note: we need to lock here because we write outside of our inner block  | 
 | 92 | +        bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))  | 
 | 93 | +        with lock:  | 
 | 94 | +            this_output = output[bb]  | 
 | 95 | +            this_output[mask] = segmentation[mask]  | 
 | 96 | +            output[bb] = this_output  | 
 | 97 | + | 
 | 98 | +    n_blocks = blocking.numberOfBlocks  | 
 | 99 | +    with futures.ThreadPoolExecutor(n_threads) as tp:  | 
 | 100 | +        list(tqdm(  | 
 | 101 | +            tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose  | 
 | 102 | +        ))  | 
 | 103 | + | 
16 | 104 | 
 
  | 
17 | 105 | def sgn_detection(  | 
18 | 106 |     input_path: str,  | 
19 | 107 |     input_key: str,  | 
20 | 108 |     output_folder: str,  | 
21 | 109 |     model_path: str,  | 
 | 110 | +    extension_distance: float,  | 
 | 111 | +    sampling: Union[float, Tuple[float, ...]],  | 
22 | 112 |     block_shape: Optional[Tuple[int, int, int]] = None,  | 
23 | 113 |     halo: Optional[Tuple[int, int, int]] = None,  | 
24 |  | -    spot_radius: int = 2,  | 
 | 114 | +    n_threads: Optional[int] = None,  | 
25 | 115 | ):  | 
26 |  | -    """Run prediction for sgn detection.  | 
 | 116 | +    """Run prediction for SGN detection.  | 
27 | 117 | 
  | 
28 | 118 |     Args:  | 
29 | 119 |         input_path: Input path to image channel for SGN detection.  | 
@@ -70,22 +160,14 @@ def sgn_detection(  | 
70 | 160 |             chunks=chunks, compression="gzip"  | 
71 | 161 |         )  | 
72 | 162 | 
 
  | 
73 |  | -        lock = threading.Lock()  | 
74 |  | - | 
75 |  | -        def add_halo_segm(detection_index):  | 
76 |  | -            """Create a segmentation volume around all detected spots.  | 
77 |  | -            """  | 
78 |  | -            coord = list(detections[detection_index])  | 
79 |  | -            block_begin = [round(c) - spot_radius for c in coord]  | 
80 |  | -            block_end = [round(c) + spot_radius for c in coord]  | 
81 |  | -            volume_index = tuple(slice(beg, end) for beg, end in zip(block_begin, block_end))  | 
82 |  | -            with lock:  | 
83 |  | -                output_dataset[volume_index] = int(detection_index) + 1  | 
84 |  | - | 
85 |  | -        # Limit the number of cores for parallelization.  | 
86 |  | -        n_threads = min(16, mp.cpu_count())  | 
87 |  | -        with futures.ThreadPoolExecutor(n_threads) as filter_pool:  | 
88 |  | -            list(tqdm(filter_pool.map(add_halo_segm, range(len(detections))), total=len(detections)))  | 
 | 163 | +        distance_based_marker_extension(  | 
 | 164 | +            markers=detections,  | 
 | 165 | +            output=output_dataset,  | 
 | 166 | +            extension_distance=extension_distance,  | 
 | 167 | +            sampling=sampling,  | 
 | 168 | +            block_shape=(128, 128, 128),  | 
 | 169 | +            n_threads=n_threads,  | 
 | 170 | +        )  | 
89 | 171 | 
 
  | 
90 | 172 |         # Save the result in mobie compatible format.  | 
91 | 173 |         detections = np.concatenate(  | 
 | 
0 commit comments