Skip to content

Commit 7273c48

Browse files
Merge pull request #69 from computational-cell-analytics/sgn_detection
Implement SGN detection
2 parents 936e59f + 47f501c commit 7273c48

File tree

5 files changed

+384
-0
lines changed

5 files changed

+384
-0
lines changed

environment.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ channels:
66
dependencies:
77
- cluster_tools
88
- scikit-image
9+
- pooch
910
- pybdv
1011
- pytorch
1112
- s3fs
Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
import multiprocessing
2+
import os
3+
import threading
4+
from concurrent import futures
5+
from threadpoolctl import threadpool_limits
6+
from typing import Optional, Tuple, Union
7+
8+
import numpy as np
9+
from numpy.typing import ArrayLike
10+
import pandas as pd
11+
from scipy.ndimage import distance_transform_edt
12+
from skimage.segmentation import watershed
13+
import zarr
14+
15+
from elf.io import open_file
16+
from elf.parallel.local_maxima import find_local_maxima
17+
from flamingo_tools.segmentation.unet_prediction import prediction_impl
18+
from tqdm import tqdm
19+
20+
from elf.parallel.common import get_blocking
21+
22+
23+
def distance_based_marker_extension(
    markers: np.ndarray,
    output: ArrayLike,
    extension_distance: float,
    sampling: Union[float, Tuple[float, ...]],
    block_shape: Tuple[int, ...],
    n_threads: Optional[int] = None,
    verbose: bool = False,
    roi: Optional[Tuple[slice, ...]] = None,
):
    """Extend SGN detections to emulate the shape of SGNs for better visualization.

    Each marker is used as a watershed seed and grown up to `extension_distance`
    around its position. The computation runs block-wise in a thread pool and the
    result is written into `output`. The label written for a marker is its row
    index in `markers` plus one, so that zero stays reserved for background.

    Args:
        markers: Array of marker coordinates with shape (n_markers, n_dims), given in
            pixel coordinates and in the axis order of `output`.
        output: Array-like object the extended segmentation is written to.
            It is read and written via slicing.
        extension_distance: Distance for the extension, in the same physical unit
            as `sampling` (micrometer).
        sampling: Resolution in micrometer. Either one value per axis or a single
            value that is used for all axes.
        block_shape: Shape of the inner blocks used for the block-wise computation.
        n_threads: Number of threads used to process blocks in parallel.
            Defaults to the number of CPU cores.
        verbose: Whether to show a progress bar.
        roi: Optional region of interest that is forwarded to `get_blocking`.
    """
    n_threads = multiprocessing.cpu_count() if n_threads is None else n_threads

    markers = np.asarray(markers)
    if markers.size == 0:
        # Nothing to extend. This also avoids indexing errors for an empty 1d marker array.
        return

    # The halo computation iterates over the sampling, so expand a scalar to one value per axis.
    if np.isscalar(sampling):
        sampling = (float(sampling),) * markers.shape[1]

    blocking = get_blocking(output, block_shape, roi, n_threads)
    lock = threading.Lock()

    # Determine the correct halo in pixels based on the sampling and the extension distance.
    halo = [round(extension_distance / s) + 2 for s in sampling]

    @threadpool_limits.wrap(limits=1)  # restrict the numpy threadpool to 1 to avoid oversubscription
    def extend_block(block_id):
        block = blocking.getBlockWithHalo(block_id, halo)
        outer_block = block.outerBlock
        inner_block = block.innerBlock

        # Select the markers in the INNER block. The block end is exclusive (it is used as a
        # slice stop below), so a half-open interval assigns every marker to exactly one block.
        inner_begin = np.asarray(inner_block.begin)
        inner_end = np.asarray(inner_block.end)
        in_block = np.all((inner_begin <= markers) & (markers < inner_end), axis=1)
        marker_ids = np.flatnonzero(in_block)

        # Only proceed if detections fall within the inner block.
        if len(marker_ids) == 0:
            return

        outer_begin = np.asarray(outer_block.begin)
        outer_shape = tuple(int(end - begin) for begin, end in zip(outer_block.begin, outer_block.end))

        # Convert to integer coordinates relative to the outer block. Rounding a float marker
        # can land one pixel past the block when the halo is cut off at the volume border,
        # so clip to the valid index range.
        coords = np.rint(markers[marker_ids] - outer_begin).astype(int)
        coords = np.clip(coords, 0, np.asarray(outer_shape) - 1)

        # Create the seed volume. The index is shifted by one so that zero is reserved
        # for the background id.
        seeds = np.zeros(outer_shape, dtype="uint32")
        seeds[tuple(coords.T)] = marker_ids + 1

        # Compute the distance map.
        distance = distance_transform_edt(seeds == 0, sampling=sampling)

        # And extend the seeds.
        mask = distance < extension_distance
        segmentation = watershed(distance.max() - distance, markers=seeds, mask=mask)

        # Write the segmentation. Note: we need to lock here because we write outside of our inner block.
        bb = tuple(slice(begin, end) for begin, end in zip(outer_block.begin, outer_block.end))
        with lock:
            this_output = output[bb]
            this_output[mask] = segmentation[mask]
            output[bb] = this_output

    n_blocks = blocking.numberOfBlocks
    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(
            tp.map(extend_block, range(n_blocks)), total=n_blocks, desc="Marker extension", disable=not verbose
        ))
104+
105+
106+
def sgn_detection(
    input_path: str,
    input_key: str,
    output_folder: str,
    model_path: str,
    extension_distance: float,
    sampling: Union[float, Tuple[float, ...]],
    block_shape: Optional[Tuple[int, int, int]] = None,
    halo: Optional[Tuple[int, int, int]] = None,
    n_threads: Optional[int] = None,
):
    """Run prediction for SGN detection.

    The function runs the network prediction (unless `output_folder/predictions.zarr`
    already contains it), finds local maxima in the prediction, saves them as a table in
    `output_folder/SGN_detection.tsv` and extends the detections into a segmentation that is
    saved in `output_folder/segmentation.zarr`. Detection and extension are only run
    if the detection table does not exist yet.

    Args:
        input_path: Input path to image channel for SGN detection.
        input_key: Input key for resolution of image channel and mask channel.
        output_folder: Output folder for SGN segmentation.
        model_path: Path to model for SGN detection.
        extension_distance: Distance in micrometer by which each detection is extended.
        sampling: Resolution in micrometer, used for the extension of the detections.
        block_shape: The block-shape for running the prediction.
        halo: The halo (= block overlap) to use for prediction.
        n_threads: Number of threads for the detection and the extension.
            If None, the detection uses 16 threads and the extension uses all CPU cores.
    """
    if block_shape is None:
        block_shape = (12, 128, 128)
    if halo is None:
        halo = (10, 64, 64)

    # Skip existing prediction, which is saved in output_folder/predictions.zarr
    output_path = os.path.join(output_folder, "predictions.zarr")
    prediction_key = "prediction"
    skip_prediction = os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r")

    if not skip_prediction:
        prediction_impl(
            input_path, input_key, output_folder, model_path,
            scale=None, block_shape=block_shape, halo=halo,
            apply_postprocessing=False, output_channels=1,
        )

    detection_path = os.path.join(output_folder, "SGN_detection.tsv")
    if not os.path.exists(detection_path):
        input_ = zarr.open(output_path, "r")[prediction_key]
        detections_maxima = find_local_maxima(
            input_, block_shape=block_shape, min_distance=4, threshold_abs=0.5, verbose=True,
            # Keep the previous hard-coded default of 16 threads unless the caller chose a value.
            n_threads=16 if n_threads is None else n_threads,
        )

        # Save the result in mobie compatible format.
        # The column order of the maxima is reversed to match the x, y, z column names.
        detections = np.concatenate(
            [np.arange(1, len(detections_maxima) + 1)[:, None], detections_maxima[:, ::-1]], axis=1
        )
        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
        detections.to_csv(detection_path, index=False, sep="\t")

        # Extend the detections. This stays inside this branch because it relies on
        # `input_` and `detections_maxima`, which are only defined here.
        shape = input_.shape
        chunks = (128, 128, 128)
        segmentation_path = os.path.join(output_folder, "segmentation.zarr")
        output = open_file(segmentation_path, mode="a")
        segmentation_key = "segmentation"
        output_dataset = output.create_dataset(
            segmentation_key, shape=shape, dtype=np.uint64,
            chunks=chunks, compression="gzip"
        )

        distance_based_marker_extension(
            markers=detections_maxima,
            output=output_dataset,
            extension_distance=extension_distance,
            sampling=sampling,
            block_shape=(128, 128, 128),
            n_threads=n_threads,
        )
180+
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
#!/bin/bash
#SBATCH --job-name=synapse-detect
#SBATCH -t 03:00:00                  # estimated time, adapt to your needs
#SBATCH --mail-type=FAIL             # send mail when the job fails

#SBATCH -p grete:shared              # the partition
#SBATCH -G A100:1                    # For requesting 1 A100 GPU.
#SBATCH -A nim00007
#SBATCH -c 4
#SBATCH --mem 32G

# Usage: sbatch <this script> <cochlea> <model_version>
# Both positional arguments are required, because they are used to build the
# input, output and model paths below.
if [[ $# -lt 2 ]]; then
    echo "Usage: sbatch $0 <cochlea> <model_version>" >&2
    exit 1
fi

source ~/.bashrc
# micromamba activate micro-sam_gpu
micromamba activate sam

# Print out some info.
echo "Submitting job with sbatch from directory: ${SLURM_SUBMIT_DIR}"
echo "Home directory: ${HOME}"
echo "Working directory: $PWD"
echo "Current node: ${SLURM_NODELIST}"

# Run the script
#python myprogram.py $SLURM_ARRAY_TASK_ID

# SCRIPT_REPO=/user/schilling40/u15000/flamingo-tools
SCRIPT_REPO=/user/pape41/u12086/Work/my_projects/flamingo-tools
cd "$SCRIPT_REPO"/flamingo_tools/segmentation/ || exit

export SCRIPT_DIR=$SCRIPT_REPO/scripts

# name of cochlea, as it appears in MoBIE and the NHR
COCHLEA=$1
# model of SGN detection, e.g. v5b
MODEL_VERSION=$2

# data on NHR
MOBIE_DIR=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/
export INPUT_PATH="$MOBIE_DIR"/"$COCHLEA"/images/ome-zarr/PV.ome.zarr

# data on MoBIE
# export INPUT_PATH="$COCHLEA"/images/ome-zarr/PV.ome.zarr
# use --s3 flag for script

export OUTPUT_FOLDER=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/"$COCHLEA"/SGN_detect-"$MODEL_VERSION"

# mkdir -p succeeds if the directory already exists, so no existence check is needed.
mkdir -p "$OUTPUT_FOLDER"

export MODEL=/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-"$MODEL_VERSION".pt
INPUT_KEY="s0"

echo "OUTPUT_FOLDER $OUTPUT_FOLDER"
echo "MODEL $MODEL"

# NOTE(review): this path is independent of SCRIPT_REPO / SCRIPT_DIR set above;
# confirm which checkout of flamingo-tools is supposed to be run.
python ~/flamingo-tools/scripts/sgn_detection/sgn_detection.py \
    --input "$INPUT_PATH" \
    --input_key "$INPUT_KEY" \
    --output_folder "$OUTPUT_FOLDER" \
    --model "$MODEL"
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import argparse
2+
3+
import flamingo_tools.s3_utils as s3_utils
4+
from flamingo_tools.segmentation.sgn_detection import sgn_detection
5+
6+
7+
def main():
    """Command line interface for running the SGN detection.

    Parses the command line arguments, optionally resolves the input path on an
    S3 bucket and then calls `sgn_detection`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", required=True, help="Path to image data to be segmented.")
    parser.add_argument("-o", "--output_folder", required=True, help="Path to output folder.")
    parser.add_argument("-m", "--model", required=True,
                        help="Path to SGN detection model.")
    parser.add_argument("-k", "--input_key", default=None,
                        help="The key / internal path to image data.")

    parser.add_argument("-d", "--extension_distance", type=float, default=12, help="Extension distance.")
    parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
                        help="Resolution of input in micrometer.")

    parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
    parser.add_argument("--s3_credentials", type=str, default=None,
                        help="Input file containing S3 credentials. "
                        "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
    parser.add_argument("--s3_bucket_name", type=str, default=None,
                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
    parser.add_argument("--s3_service_endpoint", type=str, default=None,
                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

    args = parser.parse_args()

    block_shape = (12, 128, 128)
    halo = (10, 64, 64)

    # A single value is interpreted as an isotropic resolution.
    # Note that tuple() only accepts a single iterable, so the one-element list is repeated.
    if len(args.resolution) == 1:
        resolution = tuple(args.resolution * 3)
    elif len(args.resolution) == 3:
        resolution = tuple(args.resolution)
    else:
        parser.error("--resolution expects either one value or three values.")

    if args.s3:
        # The returned file system object is not needed here.
        input_path, _ = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name,
                                             service_endpoint=args.s3_service_endpoint,
                                             credential_file=args.s3_credentials)
    else:
        input_path = args.input

    sgn_detection(input_path=input_path, input_key=args.input_key, output_folder=args.output_folder,
                  model_path=args.model, block_shape=block_shape, halo=halo,
                  extension_distance=args.extension_distance, sampling=resolution)


if __name__ == "__main__":
    main()
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import argparse
2+
import os
3+
4+
import numpy as np
5+
import pandas as pd
6+
import zarr
7+
from elf.io import open_file
8+
import scipy.ndimage as ndimage
9+
10+
from flamingo_tools.s3_utils import get_s3_path
11+
from flamingo_tools.segmentation.sgn_detection import distance_based_marker_extension
12+
from flamingo_tools.file_utils import read_image_data
13+
14+
15+
def main():
    """Command line interface for extending an SGN detection into a segmentation.

    The markers are either computed from a local label image (`--input`) or read
    from the segmentation table of the cochlea on the S3 bucket. They are then
    extended with `distance_based_marker_extension` and written to a zarr file
    in the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Script for the extension of an SGN detection. "
        "Either locally or on an S3 bucket.")

    parser.add_argument("-c", "--cochlea", required=True, help="Cochlea in MoBIE.")
    parser.add_argument("-s", "--seg_channel", required=True, help="Segmentation channel.")
    parser.add_argument("-o", "--output", required=True, help="Output directory for segmentation.")
    parser.add_argument("--input", default=None, help="Input tif.")

    parser.add_argument("--component_labels", type=int, nargs="+", default=[1],
                        help="Component labels of SGN_detect.")
    parser.add_argument("-d", "--extension_distance", type=float, default=12, help="Extension distance.")
    parser.add_argument("-r", "--resolution", type=float, nargs="+", default=[3.0, 1.887779, 1.887779],
                        help="Resolution of input in micrometer.")

    args = parser.parse_args()

    block_shape = (128, 128, 128)
    chunks = (128, 128, 128)

    # A single value is interpreted as an isotropic resolution.
    # Note that tuple() only accepts a single iterable, so the one-element list is repeated.
    if len(args.resolution) == 1:
        resolution = tuple(args.resolution * 3)
    elif len(args.resolution) == 3:
        resolution = tuple(args.resolution)
    else:
        parser.error("--resolution expects either one value or three values.")

    if args.input is not None:
        data = read_image_data(args.input, None)
        shape = data.shape
        # Compute centers of mass for each label (excluding background = 0)
        markers = ndimage.center_of_mass(np.ones_like(data), data, index=np.unique(data[data > 0]))
        markers = np.array(markers)

    else:
        s3_path = os.path.join(f"{args.cochlea}", "tables", f"{args.seg_channel}", "default.tsv")
        tsv_path, fs = get_s3_path(s3_path)
        with fs.open(tsv_path, 'r') as f:
            table = pd.read_csv(f, sep="\t")

        table = table.loc[table["component_labels"].isin(args.component_labels)]

        # Convert the anchors from physical to pixel coordinates. The markers must be in the
        # axis order of the volume (z, y, x), like the center of mass markers in the branch above,
        # and the resolution is given in the same order because it is also passed on as sampling.
        markers = np.stack([
            table["anchor_z"].to_numpy() / resolution[0],
            table["anchor_y"].to_numpy() / resolution[1],
            table["anchor_x"].to_numpy() / resolution[2],
        ], axis=1)

        s3_path = os.path.join(f"{args.cochlea}", "images", "ome-zarr", f"{args.seg_channel}.ome.zarr")
        input_key = "s0"
        s3_store, _ = get_s3_path(s3_path)
        # Only the shape is needed to create the output, so the volume itself is not loaded.
        with zarr.open(s3_store, mode="r") as f:
            shape = f[input_key].shape

    output_key = "extended_segmentation"
    output_path = os.path.join(args.output, f"{args.cochlea}-{args.seg_channel}.zarr")

    output = open_file(output_path, mode="a")
    output_dataset = output.create_dataset(
        output_key, shape=shape, dtype=np.dtype("uint32"),
        chunks=chunks, compression="gzip"
    )

    distance_based_marker_extension(
        markers=markers,
        output=output_dataset,
        extension_distance=args.extension_distance,
        sampling=resolution,
        block_shape=block_shape,
        n_threads=16,
    )


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)