import multiprocessing
import os
import threading
from concurrent import futures
from typing import Optional, Tuple, Union

import numpy as np
import pandas as pd
import zarr
from numpy.typing import ArrayLike
from scipy.ndimage import distance_transform_edt
from skimage.segmentation import watershed
from threadpoolctl import threadpool_limits
from tqdm import tqdm

from elf.io import open_file
from elf.parallel.common import get_blocking
from elf.parallel.local_maxima import find_local_maxima
from flamingo_tools.segmentation.unet_prediction import prediction_impl


def distance_based_marker_extension(
    markers: np.ndarray,
    output: ArrayLike,
    extension_distance: float,
    sampling: Union[float, Tuple[float, ...]],
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False,
    roi: Optional[Tuple[slice, ...]] = None,
):
| 33 | + """ |
| 34 | + Extend SGN detection to emulate shape of SGNs for better visualization. |
| 35 | +
|
| 36 | + Args: |
| 37 | + markers: Array of coordinates for seeding watershed. |
| 38 | + output: Output for watershed. |
| 39 | + extension_distance: Distance in micrometer for extension. |
| 40 | + sampling: Resolution in micrometer. |
| 41 | + block_shape: |
| 42 | + n_threads: |
| 43 | + verbose: |
| 44 | + roi: |
| 45 | + """ |
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
    blocking = get_blocking(output, block_shape, roi, n_threads)

    lock = threading.Lock()

    # Normalize the sampling so that it can be passed either as a scalar or per axis.
    if isinstance(sampling, (int, float)):
        sampling = (sampling,) * output.ndim

    # Determine the correct halo in pixels based on the sampling and the extension distance.
    halo = [round(extension_distance / s) + 2 for s in sampling]

    @threadpool_limits.wrap(limits=1)  # Restrict the numpy threadpool to 1 to avoid oversubscription.
    def extend_block(block_id):
        block = blocking.getBlockWithHalo(block_id, halo)
        outer_block = block.outerBlock
        inner_block = block.innerBlock

        # Get the indices and coordinates of the markers in the INNER block.
        # Blocks are half-open intervals, so the upper bound is exclusive
        # to avoid assigning a marker on a block boundary to two blocks.
        mask = (
            (inner_block.begin[0] <= markers[:, 0]) & (markers[:, 0] < inner_block.end[0]) &
            (inner_block.begin[1] <= markers[:, 1]) & (markers[:, 1] < inner_block.end[1]) &
            (inner_block.begin[2] <= markers[:, 2]) & (markers[:, 2] < inner_block.end[2])
        )
        markers_in_block_ids = np.where(mask)[0]
        markers_in_block_coords = markers[markers_in_block_ids]

        # Proceed only if detections fall within the inner block.
        if len(markers_in_block_coords) > 0:
            # Shift the marker coordinates into the local coordinate system of the outer block
            # and round them to integer pixel positions.
            markers_in_block_coords = np.round(
                markers_in_block_coords - np.array(outer_block.begin)
            ).astype(int)
            z, y, x = markers_in_block_coords.T

            # Shift the indices by one so that zero is reserved for the background id.
            markers_in_block_ids += 1

            # Create the seed volume.
            outer_block_shape = tuple(end - begin for begin, end in zip(outer_block.begin, outer_block.end))
            seeds = np.zeros(outer_block_shape, dtype="uint32")
            seeds[z, y, x] = markers_in_block_ids

            # Compute the distance map to the seeds.
            distance = distance_transform_edt(seeds == 0, sampling=sampling)

            # And extend the seeds via a watershed restricted to the area within the extension distance.
            mask = distance < extension_distance
            segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)

            # Write the segmentation. Note: we need to lock here because we write outside of our inner block.
            bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
            with lock:
                this_output = output[bb]
                this_output[mask] = segmentation[mask]
                output[bb] = this_output

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(
            tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
        ))

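# Illustrative usage sketch for `distance_based_marker_extension` (not part of the pipeline):
# all values below (markers, shape, distance, sampling) are made-up placeholders.
#
#   markers = np.array([[10.0, 32.0, 32.0], [10.0, 96.0, 96.0]])   # zyx marker coordinates
#   volume = np.zeros((20, 128, 128), dtype="uint64")              # in-memory output volume
#   distance_based_marker_extension(
#       markers, volume, extension_distance=5.0, sampling=(1.0, 0.5, 0.5),
#       block_shape=(20, 64, 64), n_threads=2, verbose=True,
#   )
#   # Afterwards each marker is represented by a blob with label id = marker index + 1.

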
def sgn_detection(
    input_path: str,
    input_key: str,
    output_folder: str,
    model_path: str,
    extension_distance: float,
    sampling: Union[float, Tuple[float, ...]],
    block_shape: Optional[Tuple[int, int, int]] = None,
    halo: Optional[Tuple[int, int, int]] = None,
    n_threads: Optional[int] = None,
):
| 117 | + """Run prediction for SGN detection. |
| 118 | +
|
| 119 | + Args: |
| 120 | + input_path: Input path to image channel for SGN detection. |
| 121 | + input_key: Input key for resolution of image channel and mask channel. |
| 122 | + output_folder: Output folder for SGN segmentation. |
| 123 | + model_path: Path to model for SGN detection. |
| 124 | + block_shape: The block-shape for running the prediction. |
| 125 | + halo: The halo (= block overlap) to use for prediction. |
| 126 | + spot_radius: Radius in pixel to convert spot detection of SGNs into a volume. |
| 127 | + """ |
    if block_shape is None:
        block_shape = (12, 128, 128)
    if halo is None:
        halo = (10, 64, 64)

    # Skip the prediction if it already exists; it is saved in output_folder/predictions.zarr.
    skip_prediction = False
    output_path = os.path.join(output_folder, "predictions.zarr")
    prediction_key = "prediction"
    if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
        skip_prediction = True

    if not skip_prediction:
        prediction_impl(
            input_path, input_key, output_folder, model_path,
            scale=None, block_shape=block_shape, halo=halo,
            apply_postprocessing=False, output_channels=1,
        )

    input_ = zarr.open(output_path, "r")[prediction_key]

    detection_path = os.path.join(output_folder, "SGN_detection.tsv")
    if os.path.exists(detection_path):
        # Load previously computed detections and convert them back from xyz to zyx order.
        detections = pd.read_csv(detection_path, sep="\t")
        detections_maxima = detections[["z", "y", "x"]].values
    else:
        detections_maxima = find_local_maxima(
            input_, block_shape=block_shape, min_distance=4, threshold_abs=0.5, verbose=True, n_threads=n_threads,
        )

        # Save the result in MoBIE-compatible format (spot id and xyz coordinates).
        detections = np.concatenate(
            [np.arange(1, len(detections_maxima) + 1)[:, None], detections_maxima[:, ::-1]], axis=1
        )
        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
        detections.to_csv(detection_path, index=False, sep="\t")

    # Extend the detections into a segmentation via the distance-based watershed.
    shape = input_.shape
    chunks = (128, 128, 128)
    segmentation_path = os.path.join(output_folder, "segmentation.zarr")
    output = open_file(segmentation_path, mode="a")
    segmentation_key = "segmentation"
    # Re-use the dataset if it already exists, so that the function can be re-run.
    if segmentation_key in output:
        output_dataset = output[segmentation_key]
    else:
        output_dataset = output.create_dataset(
            segmentation_key, shape=shape, dtype=np.uint64,
            chunks=chunks, compression="gzip"
        )

    distance_based_marker_extension(
        markers=detections_maxima,
        output=output_dataset,
        extension_distance=extension_distance,
        sampling=sampling,
        block_shape=(128, 128, 128),
        n_threads=n_threads,
    )
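

# Minimal usage sketch (hypothetical paths and values; adapt to the actual data and trained model).
if __name__ == "__main__":
    sgn_detection(
        input_path="/path/to/image_channel.zarr",    # placeholder input volume
        input_key="s0",                              # placeholder dataset key
        output_folder="/path/to/output",             # placeholder output folder
        model_path="/path/to/sgn_detection_model",   # placeholder trained model
        extension_distance=8.0,                      # placeholder extension distance in micrometers
        sampling=(3.0, 0.75, 0.75),                  # placeholder voxel size in micrometers (zyx)
        n_threads=8,
    )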