Skip to content

Commit 431c748

Browse files
Update segmentation CLI
1 parent 4d4d552 commit 431c748

File tree

5 files changed

+101
-45
lines changed

5 files changed

+101
-45
lines changed

flamingo_tools/segmentation/cli.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def run_segmentation():
2020
"""
2121
parser = argparse.ArgumentParser(description="")
2222
parser.add_argument("-i", "--input_path", required=True, help="The path to the input data.")
23-
parser.add_argument("-k", "--input_key", required=True, help="The key to the input data.")
23+
parser.add_argument("-k", "--input_key", help="The key to the input data.")
2424
parser.add_argument("-o", "--output_folder", required=True)
2525
parser.add_argument("-m", "--model_type", required=True)
2626
parser.add_argument("-c", "--checkpoint_path")
@@ -43,14 +43,20 @@ def run_detection():
4343
"""private
4444
"""
4545
parser = argparse.ArgumentParser()
46+
parser.add_argument("-i", "--input_path", required=True, help="The path to the input data.")
47+
parser.add_argument("-k", "--input_key", help="The key to the input data.")
48+
parser.add_argument("-o", "--output_folder", required=True)
4649
parser.add_argument("-m", "--model_type", default="Synapses")
50+
parser.add_argument("--mask_path")
51+
parser.add_argument("--mask_key")
52+
parser.add_argument("-c", "--checkpoint_path")
4753
args = parser.parse_args()
4854
detection_models = ["Synapses"]
4955
if args.model_type not in detection_models:
5056
raise ValueError
5157
model_path = _get_model_path(args.model_type, args.checkpoint_path)
52-
# TODO
5358
marker_detection(
5459
input_path=args.input_path, input_key=args.input_key,
5560
output_folder=args.output_folder, model_path=model_path,
61+
mask_path=args.mask_path, mask_input_key=args.mask_key,
5662
)

flamingo_tools/segmentation/synapse_detection.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,10 @@ def run_prediction(
119119
def marker_detection(
120120
input_path: str,
121121
input_key: str,
122-
mask_path: str,
122+
mask_path: Optional[str],
123123
output_folder: str,
124124
model_path: str,
125-
mask_input_key: str = "s4",
125+
mask_input_key: Optional[str] = "s4",
126126
max_distance: float = 20,
127127
resolution: float = 0.38,
128128
):
@@ -143,13 +143,12 @@ def marker_detection(
143143
# Best approach: load IHC segmentation at a low scale level, binarize it,
144144
# dilate it and use this as mask. It can be mapped back to the full resolution
145145
# with `elf.wrapper.ResizedVolume`.
146-
147146
skip_masking = False
148147

149148
mask_preprocess_key = "mask"
150149
output_file = os.path.join(output_folder, "mask.zarr")
151150

152-
if os.path.exists(output_file) and mask_preprocess_key in zarr.open(output_file, "r"):
151+
if mask_path is None or (os.path.exists(output_file) and mask_preprocess_key in zarr.open(output_file, "r")):
153152
skip_masking = True
154153

155154
if not skip_masking:
@@ -162,11 +161,6 @@ def marker_detection(
162161
f_out.create_dataset(mask_preprocess_key, data=arr_bin, compression="gzip")
163162

164163
# 2.) Run inference and detection of maxima.
165-
# This can be taken from 'scripts/synapse_marker_detection/run_prediction.py'
166-
# (And the run prediction script should then be refactored).
167-
168-
block_shape = (64, 256, 256)
169-
halo = (16, 64, 64)
170164

171165
# Skip existing prediction, which is saved in output_folder/predictions.zarr
172166
skip_prediction = False
@@ -183,11 +177,12 @@ def marker_detection(
183177
if not skip_prediction:
184178
prediction_impl(
185179
input_path, input_key, output_folder, model_path,
186-
scale=None, block_shape=block_shape, halo=halo,
187-
apply_postprocessing=False, output_channels=1,
180+
scale=None, apply_postprocessing=False, output_channels=1,
181+
block_shape=None, halo=None,
188182
)
189183

190184
if not os.path.exists(detection_path):
185+
block_shape = (64, 256, 256)
191186
input_ = zarr.open(output_path, "r")[prediction_key]
192187
detections = find_local_maxima(
193188
input_, block_shape=block_shape, min_distance=2, threshold_abs=0.5, verbose=True, n_threads=16,
@@ -200,25 +195,25 @@ def marker_detection(
200195
detections.to_csv(detection_path, index=False, sep="\t")
201196

202197
else:
203-
with open(detection_path, 'r') as f:
198+
with open(detection_path, "r") as f:
204199
detections = pd.read_csv(f, sep="\t")
205200

206201
# 3.) Map the detections to IHC and filter them based on a distance criterion.
207202
# Use the function 'map_and_filter_detections' from above.
208-
input_ = read_image_data(mask_path, input_key)
209-
210-
detections_filtered = map_and_filter_detections(
211-
segmentation=input_,
212-
detections=detections,
213-
max_distance=max_distance,
214-
resolution=resolution,
215-
)
203+
if mask_path is not None:
204+
input_ = read_image_data(mask_path, input_key)
205+
detections_filtered = map_and_filter_detections(
206+
segmentation=input_,
207+
detections=detections,
208+
max_distance=max_distance,
209+
resolution=resolution,
210+
)
216211

217-
# 4.) Add the filtered detections to MoBIE.
218-
# IMPORTANT scale the coordinates with the resolution here.
219-
detections_filtered["distance_to_ihc"] *= resolution
220-
detections_filtered["x"] *= resolution
221-
detections_filtered["y"] *= resolution
222-
detections_filtered["z"] *= resolution
223-
detection_path = os.path.join(output_folder, "synapse_detection_filtered.tsv")
224-
detections_filtered.to_csv(detection_path, index=False, sep="\t")
212+
# 4.) Add the filtered detections to MoBIE.
213+
# IMPORTANT scale the coordinates with the resolution here.
214+
detections_filtered["distance_to_ihc"] *= resolution
215+
detections_filtered["x"] *= resolution
216+
detections_filtered["y"] *= resolution
217+
detections_filtered["z"] *= resolution
218+
detection_path = os.path.join(output_folder, "synapse_detection_filtered.tsv")
219+
detections_filtered.to_csv(detection_path, index=False, sep="\t")

flamingo_tools/segmentation/unet_prediction.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,21 @@ def ndim(self):
5555
return self._volume.ndim - 1
5656

5757

58+
def _get_device_and_tiling(block_shape, halo, input_):
59+
have_cuda = torch.cuda.is_available()
60+
if block_shape is None:
61+
block_shape = (128, 128, 128) if have_cuda else getattr(input_, "chunks", (64, 64, 64))
62+
if halo is None:
63+
halo = (16, 32, 32)
64+
if have_cuda:
65+
print("Predict with GPU")
66+
gpu_ids = [0]
67+
else:
68+
print("Predict with CPU")
69+
gpu_ids = ["cpu"]
70+
return gpu_ids, block_shape, halo
71+
72+
5873
def prediction_impl(
5974
input_path,
6075
input_key,
@@ -109,19 +124,6 @@ def prediction_impl(
109124
input_ = ResizedVolume(input_, shape=new_shape, order=3)
110125
image_mask = ResizedVolume(image_mask, new_shape, order=0)
111126

112-
have_cuda = torch.cuda.is_available()
113-
114-
if block_shape is None:
115-
block_shape = (128, 128, 128) if have_cuda else input_.chunks
116-
if halo is None:
117-
halo = (16, 32, 32)
118-
if have_cuda:
119-
print("Predict with GPU")
120-
gpu_ids = [0]
121-
else:
122-
print("Predict with CPU")
123-
gpu_ids = ["cpu"]
124-
125127
if mean is None or std is None:
126128
# Compute the global mean and standard deviation.
127129
n_threads = min(16, mp.cpu_count())
@@ -156,6 +158,7 @@ def postprocess(x):
156158

157159
shape = input_.shape
158160
ndim = len(shape)
161+
gpu_ids, block_shape, halo = _get_device_and_tiling(block_shape, halo, input_)
159162

160163
blocking = nt.blocking([0] * ndim, shape, block_shape)
161164
n_blocks = blocking.numberOfBlocks

test/test_cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ def setUp(self):
1919
def tearDown(self):
2020
rmtree(self.folder)
2121

22-
def test_convert_flamingo(self):
22+
def test_convert_data(self):
2323
out_path = os.path.join(self.folder, "converted_data.n5")
24-
cmd = ["flamingo_tools.convert_flamingo", "-i", self.folder, "-o", out_path, "--metadata_pattern", ""]
24+
cmd = ["flamingo_tools.convert_data", "-i", self.folder, "-o", out_path, "--metadata_pattern", ""]
2525
run(cmd)
2626

2727
self.assertTrue(os.path.exists(out_path))

test/test_segmentation/test_cli.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import os
2+
import subprocess
3+
import tempfile
4+
import unittest
5+
6+
import imageio.v3 as imageio
7+
import numpy as np
8+
import z5py
9+
10+
11+
class TestSegmentationCLI(unittest.TestCase):
12+
shape = (64, 128, 128)
13+
14+
def _create_data(self, tmp_dir):
15+
data = np.random.randint(0, 255, size=self.shape)
16+
path = os.path.join(tmp_dir, "data.tif")
17+
imageio.imwrite(path, data)
18+
return path
19+
20+
def test_run_segmentation(self):
21+
with tempfile.TemporaryDirectory() as tmp_dir:
22+
data_path = self._create_data(tmp_dir)
23+
output_folder = os.path.join(tmp_dir, "output")
24+
25+
subprocess.run([
26+
"flamingo_tools.run_segmentation",
27+
"-i", data_path, "-o", output_folder, "-m", "SGN", "--min_size", "0"
28+
])
29+
30+
expected_path = os.path.join(output_folder, "segmentation.zarr")
31+
expected_key = "segmentation"
32+
33+
self.assertTrue(os.path.exists(expected_path))
34+
with z5py.File(expected_path, "r") as f:
35+
self.assertTrue(expected_key in f)
36+
self.assertEqual(f[expected_key].shape, self.shape)
37+
38+
def test_run_detection(self):
39+
with tempfile.TemporaryDirectory() as tmp_dir:
40+
data_path = self._create_data(tmp_dir)
41+
output_folder = os.path.join(tmp_dir, "output")
42+
43+
subprocess.run([
44+
"flamingo_tools.run_detection", "-i", data_path, "-o", output_folder
45+
])
46+
47+
expected_path = os.path.join(output_folder, "synapse_detection.tsv")
48+
self.assertTrue(os.path.exists(expected_path))
49+
50+
51+
if __name__ == "__main__":
52+
unittest.main()

0 commit comments

Comments
 (0)