Skip to content

Commit 41dbc51

Browse files
committed
Marker extension of SGN segmentation
1 parent 20369b7 commit 41dbc51

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from threadpoolctl import threadpool_limits
2+
3+
import multiprocessing
4+
from concurrent import futures
5+
from threading import Lock
6+
from typing import Optional, Tuple, Union
7+
8+
import numpy as np
9+
from numpy.typing import ArrayLike
10+
11+
from scipy.ndimage import distance_transform_edt
12+
from skimage.segmentation import watershed
13+
from tqdm import tqdm
14+
15+
from elf.parallel.common import get_blocking
16+
17+
18+
def distance_based_marker_extension(
19+
markers: np.ndarray,
20+
output: ArrayLike,
21+
extension_distance: float,
22+
sampling: Union[float, Tuple[float, ...]],
23+
block_shape: Tuple[int, ...],
24+
n_threads: Optional[int] = None,
25+
verbose: bool = False,
26+
roi: Optional[Tuple[slice, ...]] = None,
27+
):
28+
n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads
29+
blocking = get_blocking(output, block_shape, roi, n_threads)
30+
31+
lock = Lock()
32+
33+
# determine the correct halo in pixels based on the sampling and the extension distance.
34+
halo = [round(extension_distance / s) + 2 for s in sampling]
35+
36+
@threadpool_limits.wrap(limits=1) # restrict the numpy threadpool to 1 to avoid oversubscription
37+
def extend_block(block_id):
38+
block = blocking.getBlockWithHalo(block_id, halo)
39+
outer_block = block.outerBlock
40+
inner_block = block.innerBlock
41+
42+
# TODO get the indices and coordinates of the markers in the INNER block
43+
# markers_in_block_ids = [int(i) for i in np.unique(inner_block)[1:]]
44+
mask = (
45+
(inner_block.begin[0] <= markers[:, 0]) & (markers[:, 0] <= inner_block.end[0]) &
46+
(inner_block.begin[1] <= markers[:, 1]) & (markers[:, 1] <= inner_block.end[1]) &
47+
(inner_block.begin[2] <= markers[:, 2]) & (markers[:, 2] <= inner_block.end[2])
48+
)
49+
markers_in_block_ids = np.where(mask)[0]
50+
markers_in_block_coords = markers[markers_in_block_ids]
51+
52+
# TODO offset the marker coordinates with respect to the OUTER block
53+
markers_in_block_coords = [coord - outer_block.begin for coord in markers_in_block_coords]
54+
markers_in_block_coords = [[round(c) for c in coord] for coord in markers_in_block_coords]
55+
markers_in_block_coords = np.array(markers_in_block_coords, dtype=int)
56+
z, y, x = markers_in_block_coords.T
57+
58+
# Shift index by one so that zero is reserved for background id
59+
markers_in_block_ids += 1
60+
61+
# Create the seed volume.
62+
outer_block_shape = tuple(end - begin for begin, end in zip(outer_block.begin, outer_block.end))
63+
seeds = np.zeros(outer_block_shape, dtype="uint32")
64+
seeds[z, y, x] = markers_in_block_ids
65+
66+
# Compute the distance map.
67+
distance = distance_transform_edt(seeds == 0, sampling=sampling)
68+
69+
# And extend the seeds
70+
mask = distance < extension_distance
71+
segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)
72+
73+
# Write the segmentation. Note: we need to lock here because we write outside of our inner block
74+
bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
75+
with lock:
76+
this_output = output[bb]
77+
this_output[mask] = segmentation[mask]
78+
output[bb] = this_output
79+
80+
n_blocks = blocking.numberOfBlocks
81+
with futures.ThreadPoolExecutor(n_threads) as tp:
82+
list(tqdm(
83+
tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
84+
))
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import argparse
2+
import os
3+
4+
import numpy as np
5+
import pandas as pd
6+
import zarr
7+
from elf.io import open_file
8+
import scipy.ndimage as ndimage
9+
10+
from flamingo_tools.s3_utils import get_s3_path
11+
from flamingo_tools.segmentation.marker_extension import distance_based_marker_extension
12+
from flamingo_tools.file_utils import read_image_data
13+
14+
15+
def main():
16+
parser = argparse.ArgumentParser(
17+
description="Script for the extension of an SGN detection. "
18+
"Either locally or on an S3 bucket.")
19+
20+
parser.add_argument("-c", "--cochlea", required=True, help="Cochlea in MoBIE.")
21+
parser.add_argument("-s", "--seg_channel", required=True, help="Segmentation channel.")
22+
parser.add_argument("-o", "--output", required=True, help="Output directory for segmentation.")
23+
parser.add_argument("--input", default=None, help="Input tif.")
24+
25+
parser.add_argument("--component_labels", type=int, nargs="+", default=[1],
26+
help="Component labels of SGN_detect.")
27+
parser.add_argument("-d", "--extension_distance", type=float, default=8, help="Extension distance.")
28+
parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
29+
help="Resolution of input in micrometer.")
30+
31+
args = parser.parse_args()
32+
33+
block_shape = (128, 128, 128)
34+
chunks = (128, 128, 128)
35+
36+
if len(args.resolution) == 1:
37+
resolution = tuple(args.resolution, args.resolution, args.resolution)
38+
else:
39+
resolution = tuple(args.resolution)
40+
41+
if args.input is not None:
42+
data = read_image_data(args.input, None)
43+
shape = data.shape
44+
# Compute centers of mass for each label (excluding background = 0)
45+
markers = ndimage.center_of_mass(np.ones_like(data), data, index=np.unique(data[data > 0]))
46+
markers = np.array(markers)
47+
48+
else:
49+
50+
s3_path = os.path.join(f"{args.cochlea}", "tables", f"{args.seg_channel}", "default.tsv")
51+
tsv_path, fs = get_s3_path(s3_path)
52+
with fs.open(tsv_path, 'r') as f:
53+
table = pd.read_csv(f, sep="\t")
54+
55+
table = table.loc[table["component_labels"].isin(args.component_labels)]
56+
markers = list(zip(table["anchor_x"] / resolution[0],
57+
table["anchor_y"] / resolution[1],
58+
table["anchor_z"] / resolution[2]))
59+
markers = np.array(markers)
60+
61+
s3_path = os.path.join(f"{args.cochlea}", "images", "ome-zarr", f"{args.seg_channel}.ome.zarr")
62+
input_key = "s0"
63+
s3_store, fs = get_s3_path(s3_path)
64+
with zarr.open(s3_store, mode="r") as f:
65+
data = f[input_key][:].astype("float32")
66+
67+
shape = data.shape
68+
69+
output_key = "extended_segmentation"
70+
output_path = os.path.join(args.output, f"{args.cochlea}-{args.seg_channel}.zarr")
71+
72+
output = open_file(output_path, mode="a")
73+
output_dataset = output.create_dataset(
74+
output_key, shape=shape, dtype=data.dtype,
75+
chunks=chunks, compression="gzip"
76+
)
77+
78+
distance_based_marker_extension(
79+
markers=markers,
80+
output=output_dataset,
81+
extension_distance=args.extension_distance,
82+
sampling=resolution,
83+
block_shape=block_shape,
84+
n_threads=16,
85+
)
86+
87+
88+
if __name__ == "__main__":
89+
main()

0 commit comments

Comments
 (0)