1+ import sys
2+ import os
3+
4+ # Add membrain-seg to Python path
5+ MEMBRAIN_SEG_PATH = "/user/muth9/u12095/membrain-seg/src"
6+ if MEMBRAIN_SEG_PATH not in sys .path :
7+ sys .path .insert (0 , MEMBRAIN_SEG_PATH )
8+
9+ import argparse
10+ import h5py
11+ import pandas as pd
12+ from tqdm import tqdm
13+ import numpy as np
14+
15+ from membrain_seg .segmentation .skeletonize import skeletonization
16+ from membrain_seg .benchmark .metrics import masked_surface_dice
17+
18+
19+ def load_segmentation (file_path , key ):
20+ """Load a dataset from an HDF5 file."""
21+ with h5py .File (file_path , "r" ) as f :
22+ data = f [key ][:]
23+ return data
24+
25+
26+ def evaluate_surface_dice (pred , gt , raw , check ):
27+ """Skeletonize predictions and GT, compute surface dice."""
28+ gt_skeleton = skeletonization (gt == 1 , batch_size = 100000 )
29+ pred_skeleton = skeletonization (pred , batch_size = 100000 )
30+ mask = gt != 2
31+
32+ if check :
33+ import napari
34+ v = napari .Viewer ()
35+ v .add_image (raw )
36+ v .add_labels (gt , name = f"gt" )
37+ v .add_labels (gt_skeleton .astype (np .uint16 ), name = f"gt_skeleton" )
38+ v .add_labels (pred , name = f"pred" )
39+ v .add_labels (pred_skeleton .astype (np .uint16 ), name = f"pred_skeleton" )
40+ napari .run ()
41+
42+ surf_dice , confusion_dict = masked_surface_dice (
43+ pred_skeleton , gt_skeleton , pred , gt , mask
44+ )
45+ return surf_dice , confusion_dict
46+
47+
48+ def process_file (pred_path , gt_path , seg_key , gt_key , check ):
49+ """Process a single prediction/GT file pair."""
50+ try :
51+ pred = load_segmentation (pred_path , seg_key )
52+ gt = load_segmentation (gt_path , gt_key )
53+ raw = load_segmentation (gt_path , "raw" )
54+ surf_dice , confusion = evaluate_surface_dice (pred , gt , raw , check )
55+
56+ result = {
57+ "tomo_name" : os .path .basename (pred_path ),
58+ "surface_dice" : surf_dice ,
59+ ** confusion ,
60+ }
61+ return result
62+
63+ except Exception as e :
64+ print (f"Error processing { pred_path } : { e } " )
65+ return None
66+
67+
68+ def collect_results (input_folder , gt_folder , version , check = False ):
69+ """Loop through prediction files and compute metrics."""
70+ results = []
71+ seg_key = f"predictions/az/seg_v{ version } "
72+ gt_key = "/labels/az_merged"
73+
74+ for fname in tqdm (os .listdir (input_folder ), desc = "Processing segmentations" ):
75+ if not fname .endswith (".h5" ):
76+ continue
77+
78+ pred_path = os .path .join (input_folder , fname )
79+ gt_path = os .path .join (gt_folder , fname )
80+
81+ if not os .path .exists (gt_path ):
82+ print (f"Warning: Ground truth file not found for { fname } " )
83+ continue
84+
85+ result = process_file (pred_path , gt_path , seg_key , gt_key , check )
86+ if result :
87+ results .append (result )
88+
89+ return results
90+
91+
92+ def save_results (results , output_file ):
93+ """Save results as an Excel file."""
94+ df = pd .DataFrame (results )
95+ df .to_excel (output_file , index = False )
96+ print (f"Results saved to { output_file } " )
97+
98+
99+ def main ():
100+ parser = argparse .ArgumentParser (description = "Compute surface dice for AZ segmentations." )
101+ parser .add_argument ("--input_folder" , "-i" , required = True , help = "Folder with predicted segmentations (.h5)" )
102+ parser .add_argument ("--gt_folder" , "-gt" , required = True , help = "Folder with ground truth segmentations (.h5)" )
103+ parser .add_argument ("--version" , "-v" , required = True , help = "Version string used in prediction key" )
104+ parser .add_argument ("--check" , action = "store_true" , help = "Version string used in prediction key" )
105+
106+ args = parser .parse_args ()
107+
108+ output_file = f"/user/muth9/u12095/synapse-net/scripts/cooper/revision/evaluation_results/v{ args .version } _surface_dice.xlsx"
109+ results = collect_results (args .input_folder , args .gt_folder , args .version , args .check )
110+ save_results (results , output_file )
111+
112+
113+ if __name__ == "__main__" :
114+ main ()
0 commit comments