11import os
22from glob import glob
3+ from pathlib import Path
34
45import h5py
56import imageio .v3 as imageio
67import napari
8+ import numpy as np
79import pandas as pd
810import zarr
911
1416from train_synapse_detection import get_paths
1517from tqdm import tqdm
1618
19+ # INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops"
20+ INPUT_ROOT = "./data/test_crops"
1721OUTPUT_ROOT = "./predictions"
22+ DETECTION_OUT_ROOT = "./detections"
1823
1924
2025def run_prediction (val_image ):
@@ -40,19 +45,22 @@ def require_prediction(image_data, output_path):
4045def run_postprocessing (pred ):
4146 # print("Running local max ...")
4247 # coords = blob_dog(pred)
43- coords = peak_local_max (pred , min_distance = 2 , threshold_abs = 0.2 )
48+ coords = peak_local_max (pred , min_distance = 2 , threshold_abs = 0.5 )
4449 # print("... done")
4550 return coords
4651
4752
48- def visualize_results (image_data , pred , coords = None , val_coords = None ):
53+ def visualize_results (image_data , pred , coords = None , val_coords = None , title = None ):
4954 v = napari .Viewer ()
5055 v .add_image (image_data )
56+ pred = pred .clip (0 , pred .max ())
5157 v .add_image (pred )
52- if coords is None :
53- v .add_points (coords , name = "predicted_synapses" )
54- if val_coords is None :
58+ if coords is not None :
59+ v .add_points (coords , name = "predicted_synapses" , face_color = "yellow" )
60+ if val_coords is not None :
5561 v .add_points (val_coords , face_color = "green" , name = "synapse_annotations" )
62+ if title is not None :
63+ v .title = title
5664 napari .run ()
5765
5866
@@ -68,9 +76,8 @@ def check_val_image():
6876 visualize_results (val_image , pred )
6977
7078
71- def check_new_images ():
72- input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops"
73- inputs = glob (os .path .join (input_root , "*.tif" ))
79+ def check_new_images (view = False , save_detection = False ):
80+ inputs = glob (os .path .join (INPUT_ROOT , "*.tif" ))
7481 output_folder = os .path .join (OUTPUT_ROOT , "new_crops" )
7582 os .makedirs (output_folder , exist_ok = True )
7683 for path in tqdm (inputs ):
@@ -80,13 +87,27 @@ def check_new_images():
8087 continue
8188 image_data = imageio .imread (path )
8289 output_path = os .path .join (output_folder , name .replace (".tif" , ".h5" ))
83- require_prediction (image_data , output_path )
90+ # if not os.path.exists(output_path):
91+ # continue
92+ pred = require_prediction (image_data , output_path )
93+ if view or save_detection :
94+ coords = run_postprocessing (pred )
95+ if view :
96+ print ("Number of synapses:" , len (coords ))
97+ visualize_results (image_data , pred , coords = coords , title = name )
98+ if save_detection :
99+ os .makedirs (DETECTION_OUT_ROOT , exist_ok = True )
100+ coords = np .concatenate ([np .arange (0 , len (coords ))[:, None ], coords ], axis = 1 )
101+ coords = pd .DataFrame (coords , columns = ["index" , "axis-0" , "axis-1" , "axis-2" ])
102+ fname = Path (path ).stem
103+ detection_save_path = os .path .join (DETECTION_OUT_ROOT , f"{ fname } .csv" )
104+ coords .to_csv (detection_save_path , index = False )
84105
85106
86107# TODO update to support post-processing and showing annotations for the val data
87108def main ():
88109 # check_val_image()
89- check_new_images ()
110+ check_new_images (view = False , save_detection = True )
90111
91112
92113if __name__ == "__main__" :
0 commit comments