Skip to content

Commit 3230fa2

Browse files
Updates to support large-scale prediction with the synapse detection model
1 parent cd0cc0e commit 3230fa2

File tree

2 files changed

+85
-9
lines changed

2 files changed

+85
-9
lines changed

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import nifty.tools as nt
1010
import vigra
11+
import tifffile
1112
import torch
1213
import z5py
1314

@@ -37,7 +38,10 @@ def ndim(self):
3738
return self._volume.ndim - 1
3839

3940

40-
def prediction_impl(input_path, input_key, output_folder, model_path, scale, block_shape, halo):
41+
def prediction_impl(
42+
input_path, input_key, output_folder, model_path, scale, block_shape, halo,
43+
output_channels=3, apply_postprocessing=True,
44+
):
4145
with warnings.catch_warnings():
4246
warnings.simplefilter("ignore")
4347
if os.path.isdir(model_path):
@@ -46,10 +50,16 @@ def prediction_impl(input_path, input_key, output_folder, model_path, scale, blo
4650
model = torch.load(model_path)
4751

4852
mask_path = os.path.join(output_folder, "mask.zarr")
49-
image_mask = z5py.File(mask_path, "r")["mask"]
53+
if os.path.exists(mask_path):
54+
image_mask = z5py.File(mask_path, "r")["mask"]
55+
else:
56+
image_mask = None
5057

5158
if input_key is None:
52-
input_ = imageio.imread(input_path)
59+
try:
60+
input_ = tifffile.memmap(input_path)
61+
except Exception:
62+
input_ = imageio.imread(input_path)
5363
else:
5464
input_ = open_file(input_path, "r")[input_key]
5565

@@ -93,17 +103,27 @@ def preprocess(raw):
93103
raw /= std
94104
return raw
95105

96-
# Smooth the distance prediction channel.
97-
def postprocess(x):
98-
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
99-
return x
106+
if apply_postprocessing:
107+
# Smooth the distance prediction channel.
108+
def postprocess(x):
109+
x[1] = vigra.filters.gaussianSmoothing(x[1], sigma=2.0)
110+
return x
111+
else:
112+
postprocess = None if output_channels > 1 else lambda x: x.squeeze()
113+
114+
if output_channels > 1:
115+
output_shape = (output_channels,) + input_.shape
116+
output_chunks = (1,) + block_shape
117+
else:
118+
output_shape = input_.shape
119+
output_chunks = block_shape
100120

101121
output_path = os.path.join(output_folder, "predictions.zarr")
102122
with open_file(output_path, "a") as f:
103123
output = f.require_dataset(
104124
"prediction",
105-
shape=(3,) + input_.shape,
106-
chunks=(1,) + block_shape,
125+
shape=output_shape,
126+
chunks=output_chunks,
107127
compression="gzip",
108128
dtype="float32",
109129
)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import argparse
2+
import os
3+
import sys
4+
5+
import pandas as pd
6+
import numpy as np
7+
import zarr
8+
9+
from elf.parallel.local_maxima import find_local_maxima
10+
11+
sys.path.append("../..")
12+
13+
14+
def main():
15+
from flamingo_tools.segmentation.unet_prediction import prediction_impl
16+
17+
parser = argparse.ArgumentParser()
18+
parser.add_argument("-i", "--input", required=True)
19+
parser.add_argument("-o", "--output_folder", required=True)
20+
parser.add_argument("-m", "--model", required=True)
21+
parser.add_argument("-k", "--input_key", default=None)
22+
args = parser.parse_args()
23+
24+
block_shape = (64, 256, 256)
25+
halo = (16, 64, 64)
26+
27+
# Skip existing prediction, which is saved in output_folder/predictions.zarr
28+
skip_prediction = False
29+
output_path = os.path.join(args.output_folder, "predictions.zarr")
30+
prediction_key = "prediction"
31+
if os.path.exists(output_path) and prediction_key in zarr.open(output_path, "r"):
32+
skip_prediction = True
33+
34+
if not skip_prediction:
35+
prediction_impl(
36+
args.input, args.input_key, args.output_folder, args.model,
37+
scale=None, block_shape=block_shape, halo=halo,
38+
apply_postprocessing=False, output_channels=1,
39+
)
40+
41+
detection_path = os.path.join(args.output_folder, "synapse_detection.tsv")
42+
if not os.path.exists(detection_path):
43+
input_ = zarr.open(output_path, "r")[prediction_key]
44+
detections = find_local_maxima(
45+
input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16,
46+
)
47+
# Save the result in mobie compatible format.
48+
detections = np.concatenate(
49+
[np.arange(1, len(detections) + 1)[:, None], detections[:, ::-1]], axis=1
50+
)
51+
detections = pd.DataFrame(detections, columns=["spot_id", "x", "y", "z"])
52+
detections.to_csv(detection_path, index=False, sep="\t")
53+
54+
55+
if __name__ == "__main__":
56+
main()

0 commit comments

Comments
 (0)