88import zarr
99from flamingo_tools .test_data import _sample_registry
1010
11- view = False
11+ view = True
1212data_dict = {
1313 "SGN" : "PV" ,
1414 "IHC" : "VGlut3" ,
1818}
1919
2020
21- def check_segmentation_model (model_name ):
21+ def check_segmentation_model (model_name , checkpoint_path = None ):
2222 output_folder = f"result_{ model_name } "
2323 os .makedirs (output_folder , exist_ok = True )
2424 input_path = os .path .join (output_folder , f"{ model_name } .tif" )
@@ -28,9 +28,10 @@ def check_segmentation_model(model_name):
2828
2929 output_path = os .path .join (output_folder , "segmentation.zarr" )
3030 if not os .path .exists (output_path ):
31- subprocess .run (
32- ["flamingo_tools.run_segmentation" , "-i" , input_path , "-o" , output_folder , "-m" , model_name ]
33- )
31+ cmd = ["flamingo_tools.run_segmentation" , "-i" , input_path , "-o" , output_folder , "-m" , model_name ]
32+ if checkpoint_path is not None :
33+ cmd .extend (["-c" , checkpoint_path ])
34+ subprocess .run (cmd )
3435
3536 if view :
3637 segmentation = zarr .open (output_path )["segmentation" ][:]
@@ -68,24 +69,29 @@ def check_detection_model():
6869def main ():
6970 # SGN segmentation:
7071 # - Prediction works well on the CPU.
71- check_segmentation_model ("SGN" )
72+ # - Prediction works well on the GPU.
73+ # check_segmentation_model("SGN")
7274
7375 # IHC segmentation:
74- # - Prediction does not work well on the CPU.
75- check_segmentation_model ("IHC" )
76+ # - Prediction works well on the CPU.
77+ # - Prediction works well on the GPU.
78+ # check_segmentation_model("IHC")
7679
80+ # TODO: Update model.
7781 # SGN segmentation (lowres):
7882 # - Prediction does not work well on the CPU.
79- check_segmentation_model ("SGN-lowres" )
83+ # - Prediction does not work well on the GPU.
84+ check_segmentation_model ("SGN-lowres" , checkpoint_path = "SGN-lowres.pt" )
8085
8186 # IHC segmentation (lowres):
82- # - The prediction seems to work ( on the CPU), but a lot of merges .
83- # -> Update the segmentation params?
84- check_segmentation_model ("IHC-lowres" )
87+ # - Prediction works well on the CPU.
88+ # - Prediction works well on the GPU.
89+ # check_segmentation_model("IHC-lowres")
8590
8691 # Synapse detection:
8792 # - Prediction works well on the CPU.
88- check_detection_model ()
93+ # - Prediction works well on the GPU.
94+ # check_detection_model()
8995
9096
9197if __name__ == "__main__" :
0 commit comments