Skip to content

Commit fd39eb7

Browse files
Update model checking script
1 parent c3b5a26 commit fd39eb7

File tree

1 file changed

+19
-13
lines changed

1 file changed

+19
-13
lines changed

development/check_uploaded_models.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import zarr
99
from flamingo_tools.test_data import _sample_registry
1010

11-
view = False
11+
view = True
1212
data_dict = {
1313
"SGN": "PV",
1414
"IHC": "VGlut3",
@@ -18,7 +18,7 @@
1818
}
1919

2020

21-
def check_segmentation_model(model_name):
21+
def check_segmentation_model(model_name, checkpoint_path=None):
2222
output_folder = f"result_{model_name}"
2323
os.makedirs(output_folder, exist_ok=True)
2424
input_path = os.path.join(output_folder, f"{model_name}.tif")
@@ -28,9 +28,10 @@ def check_segmentation_model(model_name):
2828

2929
output_path = os.path.join(output_folder, "segmentation.zarr")
3030
if not os.path.exists(output_path):
31-
subprocess.run(
32-
["flamingo_tools.run_segmentation", "-i", input_path, "-o", output_folder, "-m", model_name]
33-
)
31+
cmd = ["flamingo_tools.run_segmentation", "-i", input_path, "-o", output_folder, "-m", model_name]
32+
if checkpoint_path is not None:
33+
cmd.extend(["-c", checkpoint_path])
34+
subprocess.run(cmd)
3435

3536
if view:
3637
segmentation = zarr.open(output_path)["segmentation"][:]
@@ -68,24 +69,29 @@ def check_detection_model():
6869
def main():
6970
# SGN segmentation:
7071
# - Prediction works well on the CPU.
71-
check_segmentation_model("SGN")
72+
# - Prediction works well on the GPU.
73+
# check_segmentation_model("SGN")
7274

7375
# IHC segmentation:
74-
# - Prediction does not work well on the CPU.
75-
check_segmentation_model("IHC")
76+
# - Prediction works well on the CPU.
77+
# - Prediction works well on the GPU.
78+
# check_segmentation_model("IHC")
7679

80+
# TODO: Update model.
7781
# SGN segmentation (lowres):
7882
# - Prediction does not work well on the CPU.
79-
check_segmentation_model("SGN-lowres")
83+
# - Prediction does not work well on the GPU.
84+
check_segmentation_model("SGN-lowres", checkpoint_path="SGN-lowres.pt")
8085

8186
# IHC segmentation (lowres):
82-
# - The prediction seems to work (on the CPU), but a lot of merges.
83-
# -> Update the segmentation params?
84-
check_segmentation_model("IHC-lowres")
87+
# - Prediction works well on the CPU.
88+
# - Prediction works well on the GPU.
89+
# check_segmentation_model("IHC-lowres")
8590

8691
# Synapse detection:
8792
# - Prediction works well on the CPU.
88-
check_detection_model()
93+
# - Prediction works well on the GPU.
94+
# check_detection_model()
8995

9096

9197
if __name__ == "__main__":

0 commit comments

Comments
 (0)