
Commit e4ddc57

Added workflow for marker detection; Refactoring
1 parent aefeba1 commit e4ddc57

File tree

5 files changed: +304, -120 lines


flamingo_tools/segmentation/marker_detection.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

Lines changed: 220 additions & 0 deletions
@@ -0,0 +1,220 @@
import os
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import zarr
from scipy.ndimage import binary_dilation

from elf.parallel.local_maxima import find_local_maxima
from elf.parallel.distance_transform import map_points_to_objects
from flamingo_tools.file_utils import read_image_data
from flamingo_tools.segmentation.unet_prediction import prediction_impl


def map_and_filter_detections(
    segmentation: np.ndarray,
    detections: pd.DataFrame,
    max_distance: float,
    resolution: float = 0.38,
    n_threads: Optional[int] = None,
    verbose: bool = True,
) -> pd.DataFrame:
    """Map synapse detections to segmented IHCs and filter out detections above a distance threshold to the IHCs.

    Args:
        segmentation: The IHC segmentation.
        detections: The synapse marker detections.
        max_distance: The maximal distance for a valid match of synapse markers to IHCs.
        resolution: The resolution / voxel size of the data in micrometers.
        n_threads: The number of threads for parallelizing the mapping of detections to objects.
        verbose: Whether to print the progress of the mapping procedure.

    Returns:
        The filtered dataframe with the detections mapped to the segmentation.
    """
    # Get the point coordinates.
    points = detections[["z", "y", "x"]].values.astype("int")

    # Set the block shape (this could also be exposed as a parameter; it should not matter much though).
    block_shape = (64, 256, 256)

    # Determine the halo. We set it to 2 pixels + the max-distance in pixels, to ensure all distances
    # that are smaller than the max distance are measured.
    halo = (2 + int(np.ceil(max_distance / resolution)),) * 3

    # Map the detections to the objects in the (IHC) segmentation.
    object_ids, object_distances = map_points_to_objects(
        segmentation=segmentation,
        points=points,
        block_shape=block_shape,
        halo=halo,
        sampling=resolution,
        n_threads=n_threads,
        verbose=verbose,
    )
    assert len(object_ids) == len(points)
    assert len(object_distances) == len(points)

    # Add the matched ids and distances to the dataframe.
    detections["matched_ihc"] = object_ids
    detections["distance_to_ihc"] = object_distances

    # Filter the dataframe by the max distance.
    detections = detections[detections.distance_to_ihc < max_distance]
    return detections

def run_prediction(
    input_path: str,
    input_key: str,
    output_folder: str,
    model_path: str,
    block_shape: Optional[Tuple[int, int, int]] = None,
    halo: Optional[Tuple[int, int, int]] = None,
):
    """Run prediction for synapse detection.

    Args:
        input_path: Input path to image channel for synapse detection.
        input_key: Input key for resolution of image channel and mask channel.
        output_folder: Output folder for synapse segmentation and marker detection.
        model_path: Path to model for synapse detection.
        block_shape: The block-shape for running the prediction.
        halo: The halo (= block overlap) to use for prediction.
    """
    if block_shape is None:
        block_shape = (64, 256, 256)
    if halo is None:
        halo = (16, 64, 64)

    # Skip the prediction if it already exists; it is saved in output_folder/predictions.zarr.
    skip_prediction = False
    output_path = os.path.join(output_folder, "predictions.zarr")
    prediction_key = "prediction"
    if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
        skip_prediction = True

    if not skip_prediction:
        prediction_impl(
            input_path, input_key, output_folder, model_path,
            scale=None, block_shape=block_shape, halo=halo,
            apply_postprocessing=False, output_channels=1,
        )

    detection_path = os.path.join(output_folder, "synapse_detection.tsv")
    if not os.path.exists(detection_path):
        input_ = zarr.open(output_path, "r")[prediction_key]
        detections = find_local_maxima(
            input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16,
        )
        # Save the result in MoBIE-compatible format.
        detections = np.concatenate(
            [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1
        )
        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
        detections.to_csv(detection_path, index=False, sep="\t")

def marker_detection(
    input_path: str,
    input_key: str,
    mask_path: str,
    output_folder: str,
    model_path: str,
    mask_key: str = "s4",
    max_distance: float = 20,
    resolution: float = 0.38,
):
    """Streamlined workflow for marker detection, mapping, and filtering.

    Args:
        input_path: Input path to image channel for synapse detection.
        input_key: Input key for resolution of image channel and mask channel.
        mask_path: Path to IHC segmentation used to mask input.
        output_folder: Output folder for synapse segmentation and marker detection.
        model_path: Path to model for synapse detection.
        mask_key: Key to undersampled IHC segmentation for masking input for synapse detection.
        max_distance: The maximal distance for a valid match of synapse markers to IHCs.
        resolution: The resolution / voxel size of the data in micrometers.
    """
    # 1.) Determine the mask for inference based on the IHC segmentation.
    # Best approach: load the IHC segmentation at a low scale level, binarize it,
    # dilate it and use this as mask. It can be mapped back to the full resolution
    # with `elf.wrapper.ResizedVolume`.

    skip_masking = False

    # Key of the binary mask in mask.zarr; kept separate from the `mask_key` argument.
    output_mask_key = "mask"
    output_file = os.path.join(output_folder, "mask.zarr")

    if os.path.exists(output_file) and output_mask_key in zarr.open(output_file, "r"):
        skip_masking = True

    if not skip_masking:
        mask_ = read_image_data(mask_path, mask_key)
        new_mask = np.zeros(mask_.shape)
        new_mask[mask_ != 0] = 1
        arr_bin = binary_dilation(new_mask, structure=np.ones((9, 9, 9))).astype(int)

        with zarr.open(output_file, mode="w") as f_out:
            f_out.create_dataset(output_mask_key, data=arr_bin, compression="gzip")

    # 2.) Run inference and detection of maxima.
    # This can be taken from 'scripts/synapse_marker_detection/run_prediction.py'
    # (and the run prediction script should then be refactored).

    block_shape = (64, 256, 256)
    halo = (16, 64, 64)

    # Skip the prediction if it already exists; it is saved in output_folder/predictions.zarr.
    skip_prediction = False
    output_path = os.path.join(output_folder, "predictions.zarr")
    prediction_key = "prediction"
    if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
        skip_prediction = True

    if not skip_prediction:
        prediction_impl(
            input_path, input_key, output_folder, model_path,
            scale=None, block_shape=block_shape, halo=halo,
            apply_postprocessing=False, output_channels=1,
        )

    detection_path = os.path.join(output_folder, "synapse_detection.tsv")
    if not os.path.exists(detection_path):
        input_ = zarr.open(output_path, "r")[prediction_key]
        detections = find_local_maxima(
            input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16,
        )
        # Save the result in MoBIE-compatible format.
        detections = np.concatenate(
            [np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1
        )
        detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
        detections.to_csv(detection_path, index=False, sep="\t")

    else:
        with open(detection_path, "r") as f:
            detections = pd.read_csv(f, sep="\t")

    # 3.) Map the detections to IHCs and filter them based on a distance criterion,
    # using the function 'map_and_filter_detections' from above.
    input_ = read_image_data(mask_path, input_key)

    detections_filtered = map_and_filter_detections(
        segmentation=input_,
        detections=detections,
        max_distance=max_distance,
        resolution=resolution,
    )

    # 4.) Add the filtered detections to MoBIE.
    # IMPORTANT: scale the coordinates with the resolution here.
    detections_filtered["distance_to_ihc"] *= resolution
    detections_filtered["x"] *= resolution
    detections_filtered["y"] *= resolution
    detections_filtered["z"] *= resolution
    detection_path = os.path.join(output_folder, "synapse_detection_filtered.tsv")
    detections_filtered.to_csv(detection_path, index=False, sep="\t")
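
For orientation, here is a minimal usage sketch of the new workflow entry point. It is not part of the commit: the import path follows the run script added below, and all paths and keys are hypothetical placeholders. With the defaults max_distance=20 and resolution=0.38, the mapping step uses a halo of 2 + ceil(20 / 0.38) = 55 pixels per axis, so all distances up to the threshold are measured.

# Hypothetical usage sketch; paths and keys are placeholders, not values from this commit.
from flamingo_tools.segmentation.synapse_detection import marker_detection

marker_detection(
    input_path="/data/cochlea/synapse_channel.ome.zarr",  # placeholder synapse image channel
    input_key="s0",                                       # assumed full-resolution scale level
    mask_path="/data/cochlea/ihc_segmentation.ome.zarr",  # placeholder IHC segmentation
    output_folder="/data/cochlea/synapse_output",         # receives mask.zarr, predictions.zarr, and the TSV tables
    model_path="/models/synapse_detection.pt",            # placeholder detection model
    mask_key="s4",                                        # undersampled scale level used for masking
    max_distance=20,                                      # maximal marker-to-IHC distance
    resolution=0.38,                                      # voxel size in micrometers
)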

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 9 additions & 3 deletions
@@ -77,15 +77,21 @@ def prediction_impl(
     else:
         model = torch.load(model_path, weights_only=False)
 
+    input_ = read_image_data(input_path, input_key)
+    chunks = getattr(input_, "chunks", (64, 64, 64))
     mask_path = os.path.join(output_folder, "mask.zarr")
+
     if os.path.exists(mask_path):
         image_mask = z5py.File(mask_path, "r")["mask"]
+        # resize mask
+        image_shape = input_.shape
+        mask_shape = image_mask.shape
+        if image_shape != mask_shape:
+            image_mask = ResizedVolume(image_mask, image_shape, order=0)
+
     else:
         image_mask = None
 
-    input_ = read_image_data(input_path, input_key)
-    chunks = getattr(input_, "chunks", (64, 64, 64))
-
     if scale is None or np.isclose(scale, 1):
         original_shape = None
     else:
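
The resizing added above relies on `elf.wrapper.ResizedVolume`, which wraps the coarse mask so it can be read at the full image shape with nearest-neighbor interpolation (order=0). A minimal sketch of the idea, with made-up shapes (not values from this commit):

# Minimal sketch of the mask-resizing idea; shapes are made-up placeholders.
import numpy as np
from elf.wrapper import ResizedVolume

image_shape = (512, 2048, 2048)                        # hypothetical full-resolution image shape
coarse_mask = np.zeros((32, 128, 128), dtype="uint8")  # hypothetical low-resolution binary mask
coarse_mask[8:24, 32:96, 32:96] = 1

# Blocks read from the wrapper are resized on the fly to the requested shape.
image_mask = ResizedVolume(coarse_mask, image_shape, order=0)
block = image_mask[:64, :256, :256]
print(block.shape)  # (64, 256, 256)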
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import argparse

import flamingo_tools.s3_utils as s3_utils
from flamingo_tools.segmentation.synapse_detection import marker_detection


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", required=True, help="Input path to image channel for synapse detection.")
    parser.add_argument("-o", "--output_folder", required=True, help="Path to output folder.")
    parser.add_argument("-s", "--mask", required=True, help="Path to IHC segmentation used for masking.")
    parser.add_argument("-m", "--model", required=True, help="Path to synapse detection model.")
    parser.add_argument("-k", "--input_key", default=None,
                        help="Input key for image data and mask data for marker detection.")
    parser.add_argument("-d", "--max_distance", type=float, default=20,
                        help="The maximal distance for a valid match of synapse markers to IHCs.")

    parser.add_argument("--s3", action="store_true", help="Use S3 bucket.")
    parser.add_argument("--s3_credentials", type=str, default=None,
                        help="Input file containing S3 credentials. "
                             "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
    parser.add_argument("--s3_bucket_name", type=str, default=None,
                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
    parser.add_argument("--s3_service_endpoint", type=str, default=None,
                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")

    args = parser.parse_args()

    if args.s3:
        input_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name,
                                              service_endpoint=args.s3_service_endpoint,
                                              credential_file=args.s3_credentials)

        mask_path, fs = s3_utils.get_s3_path(args.mask, bucket_name=args.s3_bucket_name,
                                             service_endpoint=args.s3_service_endpoint,
                                             credential_file=args.s3_credentials)
    else:
        input_path = args.input
        mask_path = args.mask

    marker_detection(input_path=input_path, input_key=args.input_key, mask_path=mask_path,
                     output_folder=args.output_folder, model_path=args.model,
                     max_distance=args.max_distance)


if __name__ == "__main__":
    main()
