Skip to content

Commit 9b2f579

Browse files
Add script for checking models
1 parent f2ea3e0 commit 9b2f579

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import os
2+
import subprocess
3+
from shutil import copyfile
4+
5+
import imageio.v3 as imageio
6+
import napari
7+
import pandas as pd
8+
import zarr
9+
from flamingo_tools.test_data import _sample_registry
10+
11+
data_dict = {
12+
"SGN": "PV",
13+
"IHC": "VGlut3",
14+
"SGN-lowres": "PV-lowres",
15+
"IHC-lowres": "MYO-lowres",
16+
"Synapses": "CTBP2",
17+
}
18+
19+
20+
def check_segmentation_model(model_name):
21+
output_folder = f"result_{model_name}"
22+
os.makedirs(output_folder, exist_ok=True)
23+
input_path = os.path.join(output_folder, f"{model_name}.tif")
24+
if not os.path.exists(input_path):
25+
data_path = _sample_registry().fetch(data_dict[model_name])
26+
copyfile(data_path, input_path)
27+
28+
subprocess.run(
29+
["flamingo_tools.run_segmentation", "-i", input_path, "-o", output_folder, "-m", model_name]
30+
)
31+
output_path = os.path.join(output_folder, "segmentation.zarr")
32+
segmentation = zarr.open(output_path)["segmentation"][:]
33+
34+
image = imageio.imread(input_path)
35+
v = napari.Viewer()
36+
v.add_image(image)
37+
v.add_labels(segmentation, name=f"{model_name}-segmentation")
38+
napari.run()
39+
40+
41+
def check_detection_model():
42+
model_name = "Synapses"
43+
output_folder = f"result_{model_name}"
44+
os.makedirs(output_folder, exist_ok=True)
45+
input_path = os.path.join(output_folder, f"{model_name}.tif")
46+
if not os.path.exists(input_path):
47+
data_path = _sample_registry().fetch(data_dict[model_name])
48+
copyfile(data_path, input_path)
49+
50+
subprocess.run(
51+
["flamingo_tools.run_detection", "-i", input_path, "-o", output_folder, "-m", model_name]
52+
)
53+
output_path = os.path.join(output_folder, "synapse_detection.tsv")
54+
prediction = pd.read_csv(output_path, sep="\t")[["z", "y", "x"]]
55+
56+
image = imageio.imread(input_path)
57+
v = napari.Viewer()
58+
v.add_image(image)
59+
v.add_points(prediction)
60+
napari.run()
61+
62+
63+
def main():
64+
# SGN segmentation:
65+
# - Prediction works well on the CPU.
66+
# check_segmentation_model("SGN")
67+
68+
# IHC segmentation:
69+
# - Prediction does not work well on the CPU.
70+
# check_segmentation_model("IHC")
71+
72+
# SGN segmentation (lowres):
73+
# - Prediction does not work well on the CPU.
74+
# check_segmentation_model("SGN-lowres")
75+
76+
# IHC segmentation (lowres):
77+
# - The prediction seems to work (on the CPU), but a lot of merges.
78+
# -> Update the segmentation params?
79+
# check_segmentation_model("IHC-lowres")
80+
81+
# Synapse detection:
82+
# - Prediction works well on the CPU.
83+
check_detection_model()
84+
85+
86+
if __name__ == "__main__":
87+
main()

0 commit comments

Comments
 (0)