"""Compute surface dice scores for AZ segmentations, either globally or
per connected ground-truth component."""

import argparse
import os

import h5py
import numpy as np
import pandas as pd
from scipy.ndimage import label
from skimage.measure import regionprops
from tqdm import tqdm

from membrain_seg.segmentation.skeletonize import skeletonization
from membrain_seg.benchmark.metrics import masked_surface_dice


def load_segmentation(file_path, key):
    """Load a dataset from an HDF5 file."""
    with h5py.File(file_path, "r") as f:
        data = f[key][:]
    return data


def evaluate_surface_dice(pred, gt, raw, check):
    """Skeletonize the prediction and GT and compute the masked surface dice."""
    gt_skeleton = skeletonization(gt == 1, batch_size=100000)
    pred_skeleton = skeletonization(pred, batch_size=100000)
    mask = gt != 2
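    # NOTE (assumption): label 2 in the GT appears to mark ignore regions;
    # the mask above excludes those voxels from the surface-dice computation.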

    if check:
        import napari

        v = napari.Viewer()
        v.add_image(raw)
        v.add_labels(gt, name="gt")
        v.add_labels(gt_skeleton.astype(np.uint16), name="gt_skeleton")
        v.add_labels(pred, name="pred")
        v.add_labels(pred_skeleton.astype(np.uint16), name="pred_skeleton")
        napari.run()

    # NOTE (assumption): the argument order of this call; it passes the
    # skeletons, segmentations, and evaluation mask defined above.
    surf_dice, confusion_dict = masked_surface_dice(
        pred_skeleton, gt_skeleton, pred, gt, mask
    )
    return surf_dice, confusion_dict


def process_file(pred_path, gt_path, seg_key, gt_key, check,
                 min_bb_shape=(32, 384, 384), min_thinning_size=2500,
                 global_eval=False):
    """Compute surface dice scores for a single prediction/GT file pair.

    Returns a list with one result dict for a global evaluation, or one
    dict per sufficiently large connected GT component otherwise.
    """
    try:
        pred = load_segmentation(pred_path, seg_key)
        gt = load_segmentation(gt_path, gt_key)
        raw = load_segmentation(gt_path, "raw")

        if global_eval:
            gt_bin = (gt == 1).astype(np.uint8)
            pred_bin = pred.astype(np.uint8)

            dice, confusion = evaluate_surface_dice(pred_bin, gt_bin, raw, check)
            return [{
                "tomo_name": os.path.basename(pred_path),
                "gt_component_id": -1,  # -1 indicates a global evaluation
                "surface_dice": dice,
                **confusion,
            }]

        # Per-component evaluation: label the connected components of the GT
        # and score each one that is large enough.
        labeled_gt, _ = label(gt == 1)
        props = regionprops(labeled_gt)
        results = []

        for prop in props:
            # Skip components below the minimum size threshold.
            if prop.area < min_thinning_size:
                continue

            comp_id = prop.label
            bbox_start = prop.bbox[:3]
            bbox_end = prop.bbox[3:]
            bbox = tuple(slice(start, stop) for start, stop in zip(bbox_start, bbox_end))

            # Grow the bounding box to at least min_bb_shape per axis,
            # clipped to the volume boundaries.
            pad_width = [
                max(min_shape - (sl.stop - sl.start), 0) // 2
                for sl, min_shape in zip(bbox, min_bb_shape)
            ]
            expanded_bbox = tuple(
                slice(max(sl.start - pw, 0), min(sl.stop + pw, dim))
                for sl, pw, dim in zip(bbox, pad_width, gt.shape)
            )
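            # Example: a component spanning 10 slices along z with
            # min_bb_shape[0] == 32 gets pad_width (32 - 10) // 2 == 11,
            # i.e. 11 slices added on each side (less at volume borders).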

            gt_crop = (labeled_gt[expanded_bbox] == comp_id).astype(np.uint8)
            pred_crop = pred[expanded_bbox].astype(np.uint8)
            raw_crop = raw[expanded_bbox]

            try:
                dice, confusion = evaluate_surface_dice(pred_crop, gt_crop, raw_crop, check)
            except Exception as e:
                print(f"Error computing dice for GT component {comp_id} in {pred_path}: {e}")
                continue

            results.append({
                "tomo_name": os.path.basename(pred_path),
                "gt_component_id": comp_id,
                "surface_dice": dice,
                **confusion,
            })

        return results

    except Exception as e:
        print(f"Error processing {pred_path}: {e}")
        return []

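# Each result dict pairs the identifying fields with the confusion counts
# returned by masked_surface_dice; an illustrative (made-up) entry:
#   {"tomo_name": "tomo_01.h5", "gt_component_id": 3, "surface_dice": 0.87, ...}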

def collect_results(input_folder, gt_folder, version, check=False,
                    min_bb_shape=(32, 384, 384), min_thinning_size=2500,
                    global_eval=False):
    """Loop through the prediction files and compute metrics for each one."""
    results = []
    seg_key = f"predictions/az/seg_v{version}"
    gt_key = "/labels/az_merged"

    # NOTE (assumption): the file-discovery loop is not shown in the source;
    # this is a minimal reconstruction consistent with the variables used below.
    input_folder_name = os.path.basename(os.path.normpath(input_folder))
    for fname in tqdm(sorted(os.listdir(input_folder))):
        if not fname.endswith(".h5"):
            continue
        pred_path = os.path.join(input_folder, fname)
        gt_path = os.path.join(gt_folder, fname)

        if not os.path.exists(gt_path):
            print(f"Warning: Ground truth file not found for {fname}")
            continue

        file_results = process_file(
            pred_path, gt_path, seg_key, gt_key, check,
            min_bb_shape=min_bb_shape,
            min_thinning_size=min_thinning_size,
            global_eval=global_eval,
        )

        for res in file_results:
            res["input_folder"] = input_folder_name
            results.append(res)

    return results


def save_results(results, output_file):
    """Append results to an Excel file, replacing rows whose tomo_name,
    input_folder, and gt_component_id match a new entry."""
    new_df = pd.DataFrame(results)

    if os.path.exists(output_file):
        existing_df = pd.read_excel(output_file)
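        # Reading and writing .xlsx with pandas assumes an Excel engine
        # such as openpyxl is installed (an assumption about the environment).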

        # Drop existing rows that are superseded by entries in new_df.
        keys = ["tomo_name", "input_folder", "gt_component_id"]
        combined_df = existing_df[
            ~existing_df.set_index(keys).index.isin(new_df.set_index(keys).index)
        ]

        final_df = pd.concat([combined_df, new_df], ignore_index=True)
    else:
        final_df = new_df

    final_df.to_excel(output_file, index=False)
    print(f"Results saved to {output_file}")


def main():
    parser = argparse.ArgumentParser(
        description="Compute the surface dice per GT component or globally for AZ segmentations."
    )
    parser.add_argument("--input_folder", "-i", required=True, help="Folder with predicted segmentations (.h5)")
    parser.add_argument("--gt_folder", "-gt", required=True, help="Folder with ground truth segmentations (.h5)")
    parser.add_argument("--version", "-v", required=True, help="Version string used in the prediction key")
    parser.add_argument("--check", action="store_true", help="Visualize intermediate outputs in napari")
    parser.add_argument("--global_eval", action="store_true",
                        help="If set, compute the global surface dice instead of per-component scores")

    args = parser.parse_args()

    min_bb_shape = (32, 384, 384)
    min_thinning_size = 2500

    suffix = "global" if args.global_eval else "per_gt_component"
    output_file = (
        "/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/"
        f"v{args.version}_surface_dice_{suffix}.xlsx"
    )

    results = collect_results(
        args.input_folder,
        args.gt_folder,
        args.version,
        args.check,
        min_bb_shape=min_bb_shape,
        min_thinning_size=min_thinning_size,
        global_eval=args.global_eval,
    )
    save_results(results, output_file)


if __name__ == "__main__":
    main()
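# Example invocation (script name and paths are illustrative):
#   python surface_dice_eval.py -i /preds/v3 -gt /ground_truth -v 3
# Add --global_eval for a single whole-tomogram score instead of
# per-component results.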