Skip to content

Commit 21f0138

Browse files
Update model check script
1 parent 9b2f579 commit 21f0138

File tree

1 file changed

+27
-22
lines changed

1 file changed

+27
-22
lines changed

development/check_uploaded_models.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import zarr
99
from flamingo_tools.test_data import _sample_registry
1010

11+
view = False
1112
data_dict = {
1213
"SGN": "PV",
1314
"IHC": "VGlut3",
@@ -25,17 +26,19 @@ def check_segmentation_model(model_name):
2526
data_path = _sample_registry().fetch(data_dict[model_name])
2627
copyfile(data_path, input_path)
2728

28-
subprocess.run(
29-
["flamingo_tools.run_segmentation", "-i", input_path, "-o", output_folder, "-m", model_name]
30-
)
3129
output_path = os.path.join(output_folder, "segmentation.zarr")
32-
segmentation = zarr.open(output_path)["segmentation"][:]
30+
if not os.path.exists(output_path):
31+
subprocess.run(
32+
["flamingo_tools.run_segmentation", "-i", input_path, "-o", output_folder, "-m", model_name]
33+
)
3334

34-
image = imageio.imread(input_path)
35-
v = napari.Viewer()
36-
v.add_image(image)
37-
v.add_labels(segmentation, name=f"{model_name}-segmentation")
38-
napari.run()
35+
if view:
36+
segmentation = zarr.open(output_path)["segmentation"][:]
37+
image = imageio.imread(input_path)
38+
v = napari.Viewer()
39+
v.add_image(image)
40+
v.add_labels(segmentation, name=f"{model_name}-segmentation")
41+
napari.run()
3942

4043

4144
def check_detection_model():
@@ -47,36 +50,38 @@ def check_detection_model():
4750
data_path = _sample_registry().fetch(data_dict[model_name])
4851
copyfile(data_path, input_path)
4952

50-
subprocess.run(
51-
["flamingo_tools.run_detection", "-i", input_path, "-o", output_folder, "-m", model_name]
52-
)
5353
output_path = os.path.join(output_folder, "synapse_detection.tsv")
54-
prediction = pd.read_csv(output_path, sep="\t")[["z", "y", "x"]]
54+
if not os.path.exists(output_path):
55+
subprocess.run(
56+
["flamingo_tools.run_detection", "-i", input_path, "-o", output_folder, "-m", model_name]
57+
)
5558

56-
image = imageio.imread(input_path)
57-
v = napari.Viewer()
58-
v.add_image(image)
59-
v.add_points(prediction)
60-
napari.run()
59+
if view:
60+
prediction = pd.read_csv(output_path, sep="\t")[["z", "y", "x"]]
61+
image = imageio.imread(input_path)
62+
v = napari.Viewer()
63+
v.add_image(image)
64+
v.add_points(prediction)
65+
napari.run()
6166

6267

6368
def main():
6469
# SGN segmentation:
6570
# - Prediction works well on the CPU.
66-
# check_segmentation_model("SGN")
71+
check_segmentation_model("SGN")
6772

6873
# IHC segmentation:
6974
# - Prediction does not work well on the CPU.
70-
# check_segmentation_model("IHC")
75+
check_segmentation_model("IHC")
7176

7277
# SGN segmentation (lowres):
7378
# - Prediction does not work well on the CPU.
74-
# check_segmentation_model("SGN-lowres")
79+
check_segmentation_model("SGN-lowres")
7580

7681
# IHC segmentation (lowres):
7782
# - The prediction seems to work (on the CPU), but a lot of merges.
7883
# -> Update the segmentation params?
79-
# check_segmentation_model("IHC-lowres")
84+
check_segmentation_model("IHC-lowres")
8085

8186
# Synapse detection:
8287
# - Prediction works well on the CPU.

0 commit comments

Comments
 (0)