1+ import os
2+ from glob import glob
3+
14import h5py
5+ import imageio .v3 as imageio
26import napari
37import zarr
48
59from torch_em .util import load_model
610from torch_em .util .prediction import predict_with_halo
711from train_synapse_detection import get_paths
12+ from tqdm import tqdm
13+
14+ OUTPUT_ROOT = "./predictions"
815
916
1017def run_prediction (val_image ):
@@ -15,22 +22,56 @@ def run_prediction(val_image):
1522 return pred .squeeze ()
1623
1724
18- def main ():
19- val_paths , _ = get_paths ("val" )
20- val_image = zarr .open (val_paths [0 ])["raw" ][:]
21-
22- # pred = run_prediction(val_image)
23- # with h5py.File("pred.h5", "a") as f:
24- # f.create_dataset("pred", data=pred, compression="gzip")
25+ def require_prediction (image_data , output_path ):
26+ key = "prediction"
27+ if os .path .exists (output_path ):
28+ with h5py .File (output_path , "r" ) as f :
29+ pred = f [key ][:]
30+ else :
31+ pred = run_prediction (image_data )
32+ with h5py .File (output_path , "w" ) as f :
33+ f .create_dataset (key , data = pred , compression = "gzip" )
34+ return pred
2535
26- with h5py .File ("pred.h5" , "r" ) as f :
27- pred = f ["pred" ][:]
2836
37+ def visualize_results (image_data , pred ):
2938 v = napari .Viewer ()
30- v .add_image (val_image )
39+ v .add_image (image_data )
3140 v .add_image (pred )
3241 napari .run ()
3342
3443
44+ def check_val_image ():
45+ val_paths , _ = get_paths ("val" )
46+ val_path = val_paths [0 ]
47+ val_image = zarr .open (val_path )["raw" ][:]
48+
49+ os .makedirs (os .path .join (OUTPUT_ROOT , "val" ), exist_ok = True )
50+ output_path = os .path .join (OUTPUT_ROOT , "val" , os .path .basename (val_path ).replace (".zarr" , ".h5" ))
51+ pred = require_prediction (val_image , output_path )
52+
53+ visualize_results (val_image , pred )
54+
55+
56+ def check_new_images ():
57+ input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops"
58+ inputs = glob (os .path .join (input_root , "*.tif" ))
59+ output_folder = os .path .join (OUTPUT_ROOT , "new_crops" )
60+ os .makedirs (output_folder , exist_ok = True )
61+ for path in tqdm (inputs ):
62+ name = os .path .basename (path )
63+ if name == "M_AMD_58L_avgblendfused_RibB.tif" :
64+ continue
65+ image_data = imageio .imread (path )
66+ output_path = os .path .join (output_folder , name .replace (".tif" , ".h5" ))
67+ require_prediction (image_data , output_path )
68+
69+
70+ # TODO update to support post-processing and showing annotations for the val data
71+ def main ():
72+ # check_val_image()
73+ check_new_images ()
74+
75+
3576if __name__ == "__main__" :
3677 main ()
0 commit comments