88import zarr
99from flamingo_tools .test_data import _sample_registry
1010
11+ view = False
1112data_dict = {
1213 "SGN" : "PV" ,
1314 "IHC" : "VGlut3" ,
@@ -25,17 +26,19 @@ def check_segmentation_model(model_name):
2526 data_path = _sample_registry ().fetch (data_dict [model_name ])
2627 copyfile (data_path , input_path )
2728
28- subprocess .run (
29- ["flamingo_tools.run_segmentation" , "-i" , input_path , "-o" , output_folder , "-m" , model_name ]
30- )
3129 output_path = os .path .join (output_folder , "segmentation.zarr" )
32- segmentation = zarr .open (output_path )["segmentation" ][:]
30+ if not os .path .exists (output_path ):
31+ subprocess .run (
32+ ["flamingo_tools.run_segmentation" , "-i" , input_path , "-o" , output_folder , "-m" , model_name ]
33+ )
3334
34- image = imageio .imread (input_path )
35- v = napari .Viewer ()
36- v .add_image (image )
37- v .add_labels (segmentation , name = f"{ model_name } -segmentation" )
38- napari .run ()
35+ if view :
36+ segmentation = zarr .open (output_path )["segmentation" ][:]
37+ image = imageio .imread (input_path )
38+ v = napari .Viewer ()
39+ v .add_image (image )
40+ v .add_labels (segmentation , name = f"{ model_name } -segmentation" )
41+ napari .run ()
3942
4043
4144def check_detection_model ():
@@ -47,36 +50,38 @@ def check_detection_model():
4750 data_path = _sample_registry ().fetch (data_dict [model_name ])
4851 copyfile (data_path , input_path )
4952
50- subprocess .run (
51- ["flamingo_tools.run_detection" , "-i" , input_path , "-o" , output_folder , "-m" , model_name ]
52- )
5353 output_path = os .path .join (output_folder , "synapse_detection.tsv" )
54- prediction = pd .read_csv (output_path , sep = "\t " )[["z" , "y" , "x" ]]
54+ if not os .path .exists (output_path ):
55+ subprocess .run (
56+ ["flamingo_tools.run_detection" , "-i" , input_path , "-o" , output_folder , "-m" , model_name ]
57+ )
5558
56- image = imageio .imread (input_path )
57- v = napari .Viewer ()
58- v .add_image (image )
59- v .add_points (prediction )
60- napari .run ()
59+ if view :
60+ prediction = pd .read_csv (output_path , sep = "\t " )[["z" , "y" , "x" ]]
61+ image = imageio .imread (input_path )
62+ v = napari .Viewer ()
63+ v .add_image (image )
64+ v .add_points (prediction )
65+ napari .run ()
6166
6267
6368def main ():
6469 # SGN segmentation:
6570 # - Prediction works well on the CPU.
66- # check_segmentation_model("SGN")
71+ check_segmentation_model ("SGN" )
6772
6873 # IHC segmentation:
6974 # - Prediction does not work well on the CPU.
70- # check_segmentation_model("IHC")
75+ check_segmentation_model ("IHC" )
7176
7277 # SGN segmentation (lowres):
7378 # - Prediction does not work well on the CPU.
74- # check_segmentation_model("SGN-lowres")
79+ check_segmentation_model ("SGN-lowres" )
7580
7681 # IHC segmentation (lowres):
7782 # - The prediction seems to work (on the CPU), but a lot of merges.
7883 # -> Update the segmentation params?
79- # check_segmentation_model("IHC-lowres")
84+ check_segmentation_model ("IHC-lowres" )
8085
8186 # Synapse detection:
8287 # - Prediction works well on the CPU.
0 commit comments