1+ import argparse
2+ from glob import glob
3+ import os
4+
5+ # import h5py
6+ from elf .io import open_file
7+ # from tifffile import imread
8+ import pandas as pd
9+
10+ from elf .evaluation import matching , symmetric_best_dice_score
11+
12+
def evaluate(labels, vesicles):
    """Return [f1, precision, recall, SBD] for a segmentation vs. its labels.

    Both inputs must be label arrays of identical shape.
    """
    assert labels.shape == vesicles.shape
    match_stats = matching(vesicles, labels)
    sbd_score = symmetric_best_dice_score(vesicles, labels)
    return [match_stats[key] for key in ("f1", "precision", "recall")] + [sbd_score]
18+
19+
def summarize_eval(results):
    """Print a markdown table of per-dataset mean scores plus an overall "all" row.

    Args:
        results (pd.DataFrame): Evaluation results with columns
            "dataset", "f1-score", "precision", "recall" and "SBD score".
    """
    score_columns = ["f1-score", "precision", "recall", "SBD score"]
    summary = (
        results[["dataset"] + score_columns]
        .groupby("dataset")
        .mean()
        .reset_index("dataset")
    )
    total = results[score_columns].mean().values.tolist()
    # Bug fix: assigning via `summary.iloc[-1]` overwrote the last dataset's
    # row; indexing a new label with .loc appends the overall-mean row instead.
    summary.loc[len(summary)] = ["all"] + total
    # NOTE(review): to_markdown requires the optional 'tabulate' package.
    table = summary.to_markdown(index=False)
    print(table)
31+
32+
def evaluate_file(labels_path, seg_path, model_name, segment_key, anno_key, mask_key, output_folder):
    """Evaluate one segmentation against its ground-truth labels and append the scores to a CSV.

    Args:
        labels_path (str): Path to the ground-truth label file.
        seg_path (str): Path to the segmentation file.
        model_name (str): Model name; used in the result CSV file name.
        segment_key (str or None): Internal dataset key of the segmentation
            (for container formats such as HDF5); None means the file is read directly.
        anno_key (str or None): Internal dataset key of the annotations; None means direct read.
        mask_key (str or None): Unused here; kept for interface compatibility with the caller.
        output_folder (str): Folder the evaluation CSV is written to.
    """
    print(f"Evaluate labels \n{labels_path} and segmentations \n{seg_path}")
    labels = open_file(labels_path)
    seg = open_file(seg_path)
    if segment_key is not None:
        seg = seg[segment_key][:]
    if anno_key is not None:
        labels = labels[anno_key][:]
    if labels is None or seg is None:
        print("Could not find label file for", seg_path)
        print("Skipping...")
        # Bug fix: previously control fell through here and evaluate() crashed on None.
        return

    # evaluate the match of ground truth and vesicles
    scores = evaluate(labels, seg)

    # store results
    os.makedirs(output_folder, exist_ok=True)
    result_path = os.path.join(output_folder, f"evaluation_{model_name}.csv")
    print("Evaluation results are saved to:", result_path)

    # Append to an existing result table so repeated calls accumulate rows.
    if os.path.exists(result_path):
        results = pd.read_csv(result_path)
    else:
        results = None

    ds_name = os.path.basename(os.path.dirname(labels_path))
    tomo = os.path.basename(labels_path)
    res = pd.DataFrame(
        [[ds_name, tomo] + scores],
        columns=["dataset", "tomogram", "f1-score", "precision", "recall", "SBD score"],
    )
    results = res if results is None else pd.concat([results, res])
    results.to_csv(result_path, index=False)

    # print results
    summarize_eval(results)
72+
73+
def evaluate_folder(labels_path, segmentation_path, model_name, segment_key,
                    anno_key, mask_key, output_folder, ext=".tif"):
    """Evaluate every segmentation in a folder against its matching label file.

    Segmentations without a matching label file (matched by file-name stem)
    are reported and skipped.
    """
    print(f"Evaluating folder {segmentation_path}")
    print(f"Using labels stored in {labels_path}")

    label_paths = get_file_paths(labels_path, ext=ext)
    seg_paths = get_file_paths(segmentation_path, ext=ext)
    if label_paths is None or seg_paths is None:
        print("Could not find label file or segmentation file")
        return

    for seg_path in seg_paths:
        matched_label = find_label_file(seg_path, label_paths)
        if matched_label is None:
            print("Could not find label file for", seg_path)
            print("Skipping...")
            continue
        evaluate_file(matched_label, seg_path, model_name, segment_key, anno_key, mask_key, output_folder)
92+
93+
def get_file_paths(path, ext=".h5", reverse=False):
    """Return the file paths to process for ``path``.

    If ``path`` is itself a file with the given extension it is returned as a
    one-element list; otherwise ``path`` is treated as a folder and searched
    recursively for files ending in ``ext``.

    Args:
        path (str): A single file path or a folder to search.
        ext (str): File extension to match (including the leading dot).
        reverse (bool): Whether to sort the found paths in descending order.

    Returns:
        list: The sorted list of matching file paths.
    """
    # Bug fix: `ext in path` matched the extension anywhere in the path
    # (e.g. a folder named "h5_data"); only a true suffix identifies a file.
    if path.endswith(ext):
        return [path]
    pattern = os.path.join(path, "**", f"*{ext}")
    return sorted(glob(pattern, recursive=True), reverse=reverse)
100+
101+
def find_label_file(given_path: str, label_paths: list) -> str:
    """
    Find the corresponding label file for a given raw file.

    A label file matches when the raw file's base name (without extension)
    is contained in the label file's base name (without extension).

    Args:
        given_path (str): The path we want to find the label file for.
        label_paths (list): Candidate label file paths.

    Returns:
        str: The first matching label path, or None if no match is found.
    """
    stem = os.path.splitext(os.path.basename(given_path))[0]
    matches = (
        label_path for label_path in label_paths
        if stem in os.path.splitext(os.path.basename(label_path))[0]
    )
    return next(matches, None)
119+
120+
def main():
    """CLI entry point: evaluate a single segmentation file or a whole folder."""
    parser = argparse.ArgumentParser()
    parser.add_argument("-sp", "--segmentation_path", required=True,
                        help="Path to the segmentation file or folder.")
    parser.add_argument("-gp", "--groundtruth_path", required=True,
                        help="Path to the ground-truth file or folder.")
    parser.add_argument("-n", "--model_name", required=True,
                        help="Model name used in the result CSV file name.")
    # Bug fix: these two arguments previously passed the `default` keyword twice
    # in one call, which is a SyntaxError. Keep them optional with default None
    # (None means the files are read directly, without an internal dataset key).
    parser.add_argument("-sk", "--segmentation_key", default=None,
                        help="Internal dataset key of the segmentation, e.g. 'labels/new_cristae_seg'.")
    parser.add_argument("-gk", "--groundtruth_key", default=None,
                        help="Internal dataset key of the ground-truth, e.g. 'labels/cristae'.")
    parser.add_argument("-m", "--mask_key", default=None)
    parser.add_argument("-o", "--output_folder", required=True,
                        help="Folder the evaluation CSV is written to.")
    args = parser.parse_args()

    # A folder is evaluated file-by-file; a single file is evaluated directly.
    runner = evaluate_folder if os.path.isdir(args.segmentation_path) else evaluate_file
    runner(args.groundtruth_path, args.segmentation_path, args.model_name,
           args.segmentation_key, args.groundtruth_key, args.mask_key, args.output_folder)
145+
146+
147+ if __name__ == "__main__" :
148+ main ()