1111import  pandas  as  pd 
1212from  tqdm  import  tqdm 
1313import  numpy  as  np 
14+ from  scipy .ndimage  import  label 
15+ from  skimage .measure  import  regionprops 
1416
1517from  membrain_seg .segmentation .skeletonize  import  skeletonization 
1618from  membrain_seg .benchmark .metrics  import  masked_surface_dice 
1719
1820
def load_segmentation(file_path, key):
    """Load a dataset from an HDF5 file.

    Parameters
    ----------
    file_path : str
        Path to the HDF5 file.
    key : str
        Internal dataset key (e.g. "raw" or a predictions/labels key).

    Returns
    -------
    numpy.ndarray
        The full dataset read into memory ([:] forces an eager copy).
    """
    with h5py.File(file_path, "r") as f:
        data = f[key][:]
    return data
2425
2526
2627def  evaluate_surface_dice (pred , gt , raw , check ):
27-     """Skeletonize predictions and GT, compute surface dice.""" 
2828    gt_skeleton  =  skeletonization (gt  ==  1 , batch_size = 100000 )
2929    pred_skeleton  =  skeletonization (pred , batch_size = 100000 )
3030    mask  =  gt  !=  2 
@@ -33,10 +33,10 @@ def evaluate_surface_dice(pred, gt, raw, check):
3333        import  napari 
3434        v  =  napari .Viewer ()
3535        v .add_image (raw )
36-         v .add_labels (gt , name =   f "gt" )
37-         v .add_labels (gt_skeleton .astype (np .uint16 ), name =   f "gt_skeleton" )
38-         v .add_labels (pred , name =   f "pred" )
39-         v .add_labels (pred_skeleton .astype (np .uint16 ), name =   f "pred_skeleton" )
36+         v .add_labels (gt , name = "gt" )
37+         v .add_labels (gt_skeleton .astype (np .uint16 ), name = "gt_skeleton" )
38+         v .add_labels (pred , name = "pred" )
39+         v .add_labels (pred_skeleton .astype (np .uint16 ), name = "pred_skeleton" )
4040        napari .run ()
4141
4242    surf_dice , confusion_dict  =  masked_surface_dice (
@@ -45,28 +45,80 @@ def evaluate_surface_dice(pred, gt, raw, check):
4545    return  surf_dice , confusion_dict 
4646
4747
48- def  process_file (pred_path , gt_path , seg_key , gt_key , check ):
49-     """Process a single prediction/GT file pair.""" 
48+ def  process_file (pred_path , gt_path , seg_key , gt_key , check ,
49+                  min_bb_shape = (32 , 384 , 384 ), min_thinning_size = 2500 ,
50+                  global_eval = False ):
5051    try :
5152        pred  =  load_segmentation (pred_path , seg_key )
5253        gt  =  load_segmentation (gt_path , gt_key )
5354        raw  =  load_segmentation (gt_path , "raw" )
54-         surf_dice , confusion  =  evaluate_surface_dice (pred , gt , raw , check )
5555
56-         result  =  {
57-             "tomo_name" : os .path .basename (pred_path ),
58-             "surface_dice" : surf_dice ,
59-             ** confusion ,
60-         }
61-         return  result 
56+         if  global_eval :
57+             gt_bin  =  (gt  ==  1 ).astype (np .uint8 )
58+             pred_bin  =  pred .astype (np .uint8 )
59+ 
60+             dice , confusion  =  evaluate_surface_dice (pred_bin , gt_bin , raw , check )
61+             return  [{
62+                 "tomo_name" : os .path .basename (pred_path ),
63+                 "gt_component_id" : - 1 ,  # -1 indicates global eval 
64+                 "surface_dice" : dice ,
65+                 ** confusion 
66+             }]
67+ 
68+         labeled_gt , _  =  label (gt  ==  1 )
69+         props  =  regionprops (labeled_gt )
70+         results  =  []
71+ 
72+         for  prop  in  props :
73+             if  prop .area  <  min_thinning_size :
74+                 continue 
75+ 
76+             comp_id  =  prop .label 
77+             bbox_start  =  prop .bbox [:3 ]
78+             bbox_end  =  prop .bbox [3 :]
79+             bbox  =  tuple (slice (start , stop ) for  start , stop  in  zip (bbox_start , bbox_end ))
80+ 
81+             pad_width  =  [
82+                 max (min_shape  -  (sl .stop  -  sl .start ), 0 ) //  2 
83+                 for  sl , min_shape  in  zip (bbox , min_bb_shape )
84+             ]
85+ 
86+             expanded_bbox  =  tuple (
87+                 slice (
88+                     max (sl .start  -  pw , 0 ),
89+                     min (sl .stop  +  pw , dim )
90+                 )
91+                 for  sl , pw , dim  in  zip (bbox , pad_width , gt .shape )
92+             )
93+ 
94+             gt_crop  =  (labeled_gt [expanded_bbox ] ==  comp_id ).astype (np .uint8 )
95+             pred_crop  =  pred [expanded_bbox ].astype (np .uint8 )
96+             raw_crop  =  raw [expanded_bbox ]
97+ 
98+             try :
99+                 dice , confusion  =  evaluate_surface_dice (pred_crop , gt_crop , raw_crop , check )
100+             except  Exception  as  e :
101+                 print (f"Error computing Dice for GT component { comp_id }   in { pred_path }  : { e }  " )
102+                 continue 
103+ 
104+             result  =  {
105+                 "tomo_name" : os .path .basename (pred_path ),
106+                 "gt_component_id" : comp_id ,
107+                 "surface_dice" : dice ,
108+                 ** confusion 
109+             }
110+             results .append (result )
111+ 
112+         return  results 
62113
63114    except  Exception  as  e :
64115        print (f"Error processing { pred_path }  : { e }  " )
65-         return  None 
116+         return  [] 
66117
67118
68- def  collect_results (input_folder , gt_folder , version , check = False ):
69-     """Loop through prediction files and compute metrics.""" 
119+ def  collect_results (input_folder , gt_folder , version , check = False ,
120+                     min_bb_shape = (32 , 384 , 384 ), min_thinning_size = 2500 ,
121+                     global_eval = False ):
70122    results  =  []
71123    seg_key  =  f"predictions/az/seg_v{ version }  " 
72124    gt_key  =  "/labels/az_merged" 
@@ -83,29 +135,32 @@ def collect_results(input_folder, gt_folder, version, check=False):
83135            print (f"Warning: Ground truth file not found for { fname }  " )
84136            continue 
85137
86-         result  =  process_file (pred_path , gt_path , seg_key , gt_key , check )
87-         if  result :
88-             result ["input_folder" ] =  input_folder_name 
89-             results .append (result )
138+         file_results  =  process_file (
139+             pred_path , gt_path , seg_key , gt_key , check ,
140+             min_bb_shape = min_bb_shape ,
141+             min_thinning_size = min_thinning_size ,
142+             global_eval = global_eval 
143+         )
144+ 
145+         for  res  in  file_results :
146+             res ["input_folder" ] =  input_folder_name 
147+             results .append (res )
90148
91149    return  results 
92150
93151
94152def  save_results (results , output_file ):
95-     """Append results to an Excel file, updating rows with matching tomo_name and input_folder.""" 
96153    new_df  =  pd .DataFrame (results )
97154
98155    if  os .path .exists (output_file ):
99156        existing_df  =  pd .read_excel (output_file )
100157
101-         # Drop rows where tomo_name and input_folder match any in new_df 
102158        combined_df  =  existing_df [
103-             ~ existing_df .set_index (["tomo_name" , "input_folder" ]).index .isin (
104-                 new_df .set_index (["tomo_name" , "input_folder" ]).index 
159+             ~ existing_df .set_index (["tomo_name" , "input_folder" ,  "gt_component_id" ]).index .isin (
160+                 new_df .set_index (["tomo_name" , "input_folder" ,  "gt_component_id" ]).index 
105161            )
106162        ]
107163
108-         # Append new data and reset index 
109164        final_df  =  pd .concat ([combined_df , new_df ], ignore_index = True )
110165    else :
111166        final_df  =  new_df 
@@ -114,20 +169,34 @@ def save_results(results, output_file):
114169    print (f"Results saved to { output_file }  " )
115170
116171
117- 
118172def  main ():
119-     parser  =  argparse .ArgumentParser (description = "Compute surface dice for AZ segmentations." )
173+     parser  =  argparse .ArgumentParser (description = "Compute surface dice per GT component or globally  for AZ segmentations." )
120174    parser .add_argument ("--input_folder" , "-i" , required = True , help = "Folder with predicted segmentations (.h5)" )
121175    parser .add_argument ("--gt_folder" , "-gt" , required = True , help = "Folder with ground truth segmentations (.h5)" )
122176    parser .add_argument ("--version" , "-v" , required = True , help = "Version string used in prediction key" )
123-     parser .add_argument ("--check" , action = "store_true" , help = "Version string used in prediction key" )
177+     parser .add_argument ("--check" , action = "store_true" , help = "Visualize intermediate outputs in Napari" )
178+     parser .add_argument ("--global_eval" , action = "store_true" , help = "If set, compute global surface dice instead of per-component" )
124179
125180    args  =  parser .parse_args ()
126181
127-     output_file  =  f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{ args .version }  _surface_dice.xlsx" 
128-     results  =  collect_results (args .input_folder , args .gt_folder , args .version , args .check )
182+     min_bb_shape  =  (32 , 384 , 384 )
183+     min_thinning_size  =  2500 
184+ 
185+     suffix  =  "global"  if  args .global_eval  else  "per_gt_component" 
186+     output_file  =  f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{ args .version }  _surface_dice_{ suffix }  .xlsx" 
187+ 
188+     results  =  collect_results (
189+         args .input_folder ,
190+         args .gt_folder ,
191+         args .version ,
192+         args .check ,
193+         min_bb_shape = min_bb_shape ,
194+         min_thinning_size = min_thinning_size ,
195+         global_eval = args .global_eval 
196+     )
197+ 
129198    save_results (results , output_file )
130199
131200
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()