
Commit 90087c4

Implement SGN detection
1 parent ff08ecc commit 90087c4

2 files changed, +138 -0 lines changed

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
import multiprocessing as mp
from concurrent import futures
import os
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import zarr

from elf.io import open_file
from elf.parallel.local_maxima import find_local_maxima
from flamingo_tools.segmentation.unet_prediction import prediction_impl
from tqdm import tqdm


def sgn_detection(
    input_path: str,
    input_key: str,
    output_folder: str,
    model_path: str,
    block_shape: Optional[Tuple[int, int, int]] = None,
    halo: Optional[Tuple[int, int, int]] = None,
    spot_radius: int = 4,
):
    """Run prediction for SGN detection.

    Args:
        input_path: Input path to the image channel for SGN detection.
        input_key: Input key for the resolution of the image channel and mask channel.
        output_folder: Output folder for the SGN segmentation.
        model_path: Path to the model for SGN detection.
        block_shape: The block shape for running the prediction.
        halo: The halo (= block overlap) to use for prediction.
        spot_radius: Radius in pixels for converting SGN spot detections into a volume.
    """
    if block_shape is None:
        block_shape = (24, 256, 256)
    if halo is None:
        halo = (12, 64, 64)

    # Skip the prediction if it already exists in output_folder/predictions.zarr.
    skip_prediction = False
    output_path = os.path.join(output_folder, "predictions.zarr")
    prediction_key = "prediction"
    if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
        skip_prediction = True

    if not skip_prediction:
        prediction_impl(
            input_path, input_key, output_folder, model_path,
            scale=None, block_shape=block_shape, halo=halo,
            apply_postprocessing=False, output_channels=1,
        )

    detection_path = os.path.join(output_folder, "SGN_detection.tsv")
    if not os.path.exists(detection_path):
        # Detect SGN candidates as local maxima of the prediction.
        input_ = zarr.open(output_path, "r")[prediction_key]
        detections = find_local_maxima(
            input_, block_shape=block_shape, min_distance=4, threshold_abs=0.5, verbose=True, n_threads=16,
        )
        print(detections.shape)

        # Create the output dataset for the spot segmentation.
        shape = input_.shape
        chunks = (128, 128, 128)
        segmentation_path = os.path.join(output_folder, "segmentation.zarr")
        output = open_file(segmentation_path, mode="a")
        segmentation_key = "segmentation"
        output_dataset = output.create_dataset(
            segmentation_key, shape=shape, dtype=input_.dtype,
            chunks=chunks, compression="gzip"
        )

        def add_halo_segm(detection_index):
            """Create a segmentation volume around a detected spot."""
            coord = detections[detection_index]
            block_begin = [round(c) - spot_radius for c in coord]
            block_end = [round(c) + spot_radius for c in coord]
            volume_index = tuple(slice(beg, end) for beg, end in zip(block_begin, block_end))
            output_dataset[volume_index] = detection_index + 1

        # Limit the number of cores for parallelization.
        n_threads = min(16, mp.cpu_count())
        with futures.ThreadPoolExecutor(n_threads) as filter_pool:
            list(tqdm(filter_pool.map(add_halo_segm, range(len(detections))), total=len(detections)))

        # Save the result as a MoBIE compatible spot table (coordinates reversed from zyx to xyz).
        detections = np.concatenate(
            [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1
        )
        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
        detections.to_csv(detection_path, index=False, sep="\t")
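
A minimal usage sketch of the new function; the paths and the internal dataset key below are placeholders, not part of the commit, and block_shape, halo, and spot_radius fall back to the defaults defined above. The run writes predictions.zarr, segmentation.zarr, and SGN_detection.tsv into the output folder.

from flamingo_tools.segmentation.sgn_detection import sgn_detection

# All paths and the key are hypothetical placeholders.
sgn_detection(
    input_path="/data/cochlea_stack.n5",   # image channel to run SGN detection on
    input_key="setup0/timepoint0/s0",      # internal dataset key at the desired resolution
    output_folder="/data/sgn_output",      # receives predictions.zarr, segmentation.zarr, SGN_detection.tsv
    model_path="/models/sgn_unet",         # trained model passed to prediction_impl
)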
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import argparse

import flamingo_tools.s3_utils as s3_utils
from flamingo_tools.segmentation.sgn_detection import sgn_detection


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", required=True, help="Path to image data to be segmented.")
    parser.add_argument("-o", "--output_folder", required=True, help="Path to output folder.")
    parser.add_argument("-m", "--model", required=True,
                        help="Path to SGN detection model.")
    parser.add_argument("-k", "--input_key", default=None,
                        help="The key / internal path to image data.")

    parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
    parser.add_argument("--s3_credentials", type=str, default=None,
                        help="Input file containing S3 credentials. "
                             "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
    parser.add_argument("--s3_bucket_name", type=str, default=None,
                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
    parser.add_argument("--s3_service_endpoint", type=str, default=None,
                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

    args = parser.parse_args()

    block_shape = (24, 256, 256)
    halo = (12, 64, 64)

    if args.s3:
        input_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name,
                                              service_endpoint=args.s3_service_endpoint,
                                              credential_file=args.s3_credentials)
    else:
        input_path = args.input

    sgn_detection(input_path=input_path, input_key=args.input_key, output_folder=args.output_folder,
                  model_path=args.model, block_shape=block_shape, halo=halo)


if __name__ == "__main__":
    main()
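
Not part of the commit, but after a run the resulting spot table can be sanity-checked with pandas; the path is a placeholder matching the sketch above, and the tab-separated columns spot_id, x, y, z follow the table written by sgn_detection.

import pandas as pd

# Hypothetical output folder; SGN_detection.tsv is written by sgn_detection.
detections = pd.read_csv("/data/sgn_output/SGN_detection.tsv", sep="\t")
print(len(detections), "SGN detections")
print(detections[["spot_id", "x", "y", "z"]].head())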
