Skip to content

Commit 01d1352

Browse files
committed
Include marker extension in SGN detection
1 parent ad032c8 commit 01d1352

File tree

4 files changed

+115
-107
lines changed

4 files changed

+115
-107
lines changed

flamingo_tools/segmentation/marker_extension.py

Lines changed: 0 additions & 84 deletions
This file was deleted.

flamingo_tools/segmentation/sgn_detection.py

Lines changed: 103 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,119 @@
1-
import multiprocessing as mp
1+
import multiprocessing
2+
import os
23
import threading
34
from concurrent import futures
4-
import os
5-
from typing import Optional, Tuple
5+
from threadpoolctl import threadpool_limits
6+
from typing import Optional, Tuple, Union
67

78
import numpy as np
9+
from numpy.typing import ArrayLike
810
import pandas as pd
11+
from scipy.ndimage import distance_transform_edt
12+
from skimage.segmentation import watershed
913
import zarr
1014

1115
from elf.io import open_file
1216
from elf.parallel.local_maxima import find_local_maxima
1317
from flamingo_tools.segmentation.unet_prediction import prediction_impl
1418
from tqdm import tqdm
1519

20+
from elf.parallel.common import get_blocking
21+
22+
23+
def distance_based_marker_extension(
24+
markers: np.ndarray,
25+
output: ArrayLike,
26+
extension_distance: float,
27+
sampling: Union[float, Tuple[float, ...]],
28+
block_shape: Tuple[int, ...],
29+
n_threads: Optional[int] = None,
30+
verbose: bool = False,
31+
roi: Optional[Tuple[slice, ...]] = None,
32+
):
33+
"""
34+
Extend SGN detection to emulate shape of SGNs for better visualization.
35+
36+
Args:
37+
markers: Array of coordinates for seeding watershed.
38+
output: Output for watershed.
39+
extension_distance: Distance in micrometer for extension.
40+
sampling: Resolution in micrometer.
41+
block_shape:
42+
n_threads:
43+
verbose:
44+
roi:
45+
"""
46+
n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
47+
blocking = get_blocking(output, block_shape, roi, n_threads)
48+
49+
lock = threading.Lock()
50+
51+
# determine the correct halo in pixels based on the sampling and the extension distance.
52+
halo = [round(extension_distance / s) + 2 for s in sampling]
53+
54+
@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
55+
def extend_block(block_id):
56+
block = blocking.getBlockWithHalo(block_id, halo)
57+
outer_block = block.outerBlock
58+
inner_block = block.innerBlock
59+
60+
# TODO get the indices and coordinates of the markers in the INNER block
61+
# markers_in_block_ids = [int(i) for i in np.unique(inner_block)[1:]]
62+
mask = (
63+
(inner_block.begin[0] <= markers[:, 0]) & (markers[:, 0] <= inner_block.end[0]) &
64+
(inner_block.begin[1] <= markers[:, 1]) & (markers[:, 1] <= inner_block.end[1]) &
65+
(inner_block.begin[2] <= markers[:, 2]) & (markers[:, 2] <= inner_block.end[2])
66+
)
67+
markers_in_block_ids = np.where(mask)[0]
68+
markers_in_block_coords = markers[markers_in_block_ids]
69+
70+
# TODO offset the marker coordinates with respect to the OUTER block
71+
markers_in_block_coords = [coord - outer_block.begin for coord in markers_in_block_coords]
72+
markers_in_block_coords = [[round(c) for c in coord] for coord in markers_in_block_coords]
73+
markers_in_block_coords = np.array(markers_in_block_coords, dtype=int)
74+
z, y, x = markers_in_block_coords.T
75+
76+
# Shift index by one so that zero is reserved for background id
77+
markers_in_block_ids += 1
78+
79+
# Create the seed volume.
80+
outer_block_shape = tuple(end - begin for begin, end in zip(outer_block.begin, outer_block.end))
81+
seeds = np.zeros(outer_block_shape, dtype="uint32")
82+
seeds[z, y, x] = markers_in_block_ids
83+
84+
# Compute the distance map.
85+
distance = distance_transform_edt(seeds == 0, sampling=sampling)
86+
87+
# And extend the seeds
88+
mask = distance < extension_distance
89+
segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)
90+
91+
# Write the segmentation. Note: we need to lock here because we write outside of our inner block
92+
bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
93+
with lock:
94+
this_output = output[bb]
95+
this_output[mask] = segmentation[mask]
96+
output[bb] = this_output
97+
98+
n_blocks = blocking.numberOfBlocks
99+
with futures.ThreadPoolExecutor(n_threads) as tp:
100+
list(tqdm(
101+
tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
102+
))
103+
16104

17105
def sgn_detection(
18106
input_path: str,
19107
input_key: str,
20108
output_folder: str,
21109
model_path: str,
110+
extension_distance: float,
111+
sampling: Union[float, Tuple[float, ...]],
22112
block_shape: Optional[Tuple[int, int, int]] = None,
23113
halo: Optional[Tuple[int, int, int]] = None,
24-
spot_radius: int = 2,
114+
n_threads: Optional[int] = None,
25115
):
26-
"""Run prediction for sgn detection.
116+
"""Run prediction for SGN detection.
27117
28118
Args:
29119
input_path: Input path to image channel for SGN detection.
@@ -70,22 +160,14 @@ def sgn_detection(
70160
chunks=chunks, compression="gzip"
71161
)
72162

73-
lock = threading.Lock()
74-
75-
def add_halo_segm(detection_index):
76-
"""Create a segmentation volume around all detected spots.
77-
"""
78-
coord = list(detections[detection_index])
79-
block_begin = [round(c) - spot_radius for c in coord]
80-
block_end = [round(c) + spot_radius for c in coord]
81-
volume_index = tuple(slice(beg, end) for beg, end in zip(block_begin, block_end))
82-
with lock:
83-
output_dataset[volume_index] = int(detection_index) + 1
84-
85-
# Limit the number of cores for parallelization.
86-
n_threads = min(16, mp.cpu_count())
87-
with futures.ThreadPoolExecutor(n_threads) as filter_pool:
88-
list(tqdm(filter_pool.map(add_halo_segm, range(len(detections))), total=len(detections)))
163+
distance_based_marker_extension(
164+
markers=detections,
165+
output=output_dataset,
166+
extension_distance=extension_distance,
167+
sampling=sampling,
168+
block_shape=(128, 128, 128),
169+
n_threads=n_threads,
170+
)
89171

90172
# Save the result in mobie compatible format.
91173
detections = np.concatenate(

scripts/sgn_detection/sgn_detection.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ def main():
1414
parser.add_argument("-k", "--input_key", default=None,
1515
help="The key / internal path to image data.")
1616

17+
parser.add_argument("-d", "--extension_distance", type=float, default=12, help="Extension distance.")
18+
parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
19+
help="Resolution of input in micrometer.")
20+
1721
parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
1822
parser.add_argument("--s3_credentials", type=str, default=None,
1923
help="Input file containing S3 credentials. "
@@ -28,6 +32,11 @@ def main():
2832
block_shape = (12, 128, 128)
2933
halo = (10, 64, 64)
3034

35+
if len(args.resolution) == 1:
36+
resolution = tuple(args.resolution, args.resolution, args.resolution)
37+
else:
38+
resolution = tuple(args.resolution)
39+
3140
if args.s3:
3241
input_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name,
3342
service_endpoint=args.s3_service_endpoint,
@@ -37,7 +46,8 @@ def main():
3746
input_path = args.input
3847

3948
sgn_detection(input_path=input_path, input_key=args.input_key, output_folder=args.output_folder,
40-
model_path=args.model, block_shape=block_shape, halo=halo)
49+
model_path=args.model, block_shape=block_shape, halo=halo,
50+
extension_distance=args.extension_distance, sampling=resolution)
4151

4252

4353
if __name__ == "__main__":

scripts/sgn_detection/sgn_marker_extension.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import scipy.ndimage as ndimage
99

1010
from flamingo_tools.s3_utils import get_s3_path
11-
from flamingo_tools.segmentation.marker_extension import distance_based_marker_extension
11+
from flamingo_tools.segmentation.sgn_detection import distance_based_marker_extension
1212
from flamingo_tools.file_utils import read_image_data
1313

1414

0 commit comments

Comments
 (0)