88
99"""
1010
11- from segmentation_skeleton_metrics .skeleton_metric import SkeletonMetric
11+ from zipfile import ZipFile
12+
13+ import numpy as np
14+ import os
15+ import pandas as pd
16+
17+ from segmentation_skeleton_metrics .skeleton_metrics import (
18+ MergeCountMetric ,
19+ MergeRateMetric ,
20+ MergedEdgePercentMetric ,
21+ OmitEdgePercentMetric ,
22+ SplitEdgePercentMetric ,
23+ SplitCountMetric ,
24+ SplitRateMetric ,
25+ EdgeAccuracyMetric ,
26+ ERLMetric ,
27+ NormalizedERLMetric
28+ )
1229from segmentation_skeleton_metrics .data_handling .graph_loading import (
1330 DataLoader , LabelHandler
1431)
32+ from segmentation_skeleton_metrics .utils import util
1533
1634
1735def evaluate (
@@ -21,6 +39,7 @@ def evaluate(
2139 anisotropy = (1.0 , 1.0 , 1.0 ),
2240 connections_path = None ,
2341 fragments_pointer = None ,
42+ results_filename = "results" ,
2443 save_merges = False ,
2544 save_fragments = False ,
2645 use_anisotropy = False ,
@@ -56,27 +75,117 @@ def evaluate(
5675 anisotropy = anisotropy ,
5776 use_anisotropy = use_anisotropy ,
5877 verbose = verbose
59- )
78+ )
6079 gt_graphs = dataloader .load_groundtruth (gt_pointer , label_mask )
6180 fragment_graphs = dataloader .load_fragments (fragments_pointer , gt_graphs )
6281
63- # Evaluator
64- skeleton_metric = SkeletonMetric (
65- gt_graphs ,
66- fragment_graphs ,
67- label_handler ,
68- anisotropy = anisotropy ,
69- output_dir = output_dir ,
70- save_merges = save_merges ,
71- save_fragments = save_fragments ,
72- )
73- skeleton_metric .run ()
82+ # Run evaluation
83+ evaluator = Evaluator (output_dir , results_filename , verbose )
84+ evaluator .run (gt_graphs , fragment_graphs )
7485
75- # Compute metrics
86+ # Optional saves
87+ if save_merges :
88+ evaluator .save_merge_results ()
7689
77- # Report results
90+ if save_fragments and fragment_graphs :
91+ evaluator .save_fragments (gt_graphs , fragment_graphs )
7892
7993
# --- Evaluator ---
class Evaluator:
    """Computes skeleton-based segmentation metrics and writes results.

    Core metrics (split/merge counts, edge percentages, ERL) are computed
    directly from the graphs; derived metrics (rates, edge accuracy,
    normalized ERL) are computed from the core results afterwards.
    """

    def __init__(self, output_dir, results_filename, verbose=True):
        """
        Parameters
        ----------
        output_dir : str
            Directory that the results CSV (and optional SWC zips) are
            written to.
        results_filename : str
            Filename (without extension) of the results CSV.
        verbose : bool, optional
            Whether metric computation prints progress. Default is True.
        """
        # Instance attributes
        self.output_dir = output_dir
        self.results_filename = results_filename
        self.verbose = verbose

        # Graphs are stashed by run() so the save_* writers can be called
        # without arguments (evaluate() calls save_merge_results() argless).
        self.gt_graphs = None
        self.fragment_graphs = None

        # Set core metrics
        self.metrics = {
            "# Splits": SplitCountMetric(verbose=verbose),
            "# Merges": MergeCountMetric(verbose=verbose),
            "% Split Edges": SplitEdgePercentMetric(verbose=verbose),
            "% Omit Edges": OmitEdgePercentMetric(verbose=verbose),
            "% Merged Edges": MergedEdgePercentMetric(verbose=verbose),
            "ERL": ERLMetric(verbose=verbose)
        }

        # Set derived metrics
        self.derived_metrics = {
            "Normalized ERL": NormalizedERLMetric(verbose=verbose),
            "Edge Accuracy": EdgeAccuracyMetric(verbose=verbose),
            "Split Rate": SplitRateMetric(verbose=verbose),
            "Merge Rate": MergeRateMetric(verbose=verbose),
        }

    # --- Properties ---
    @property
    def merge_sites(self):
        # Merge sites are accumulated by the merge-count metric during
        # run(); delegate so writers have a single source of truth.
        return self.metrics["# Merges"].merge_sites

    @property
    def fragments_with_merge(self):
        # NOTE(review): assumes MergeCountMetric exposes this attribute,
        # analogous to merge_sites -- TODO confirm.
        return self.metrics["# Merges"].fragments_with_merge

    # --- Core Routines ---
    def run(self, gt_graphs, fragment_graphs=None):
        """Compute all metrics, write the results CSV, and report a summary.

        Parameters
        ----------
        gt_graphs : dict
            Ground-truth skeleton graphs keyed by ID.
        fragment_graphs : dict, optional
            Fragment graphs; required for merge-based metrics, which are
            skipped when this is None/empty.
        """
        # Stash graphs so save_merge_results/save_fragments can be called
        # later without re-passing them.
        self.gt_graphs = gt_graphs
        self.fragment_graphs = fragment_graphs

        # Printout step
        if self.verbose:
            print("\n(3) Compute Metrics")

        # Compute core metrics. "# Merges" is the only core metric that
        # needs fragment graphs; it is skipped when none were loaded.
        results = self.init_results(gt_graphs)
        for name, metric in self.metrics.items():
            if name == "# Merges":
                if fragment_graphs:
                    results[name] = metric.compute(gt_graphs, fragment_graphs)
            else:
                results.update(metric.compute(gt_graphs))

        # Compute derived metrics. "Merge Rate" depends on merge counts,
        # so it too requires fragment graphs. (The original if/elif had
        # two identical branches; collapsed into one condition.)
        for name, metric in self.derived_metrics.items():
            if name != "Merge Rate" or fragment_graphs:
                results[name] = metric.compute(gt_graphs, results)

        # Save report.
        # NOTE(review): index=False drops the graph IDs stored in the row
        # index, so rows in the CSV are anonymous -- confirm intended.
        path = f"{self.output_dir}/{self.results_filename}.csv"
        results.to_csv(path, index=False)
        self.report_summary(results)

    def init_results(self, gt_graphs):
        """Return a DataFrame (graph ID x metric name) initialized with NaN.

        Rows are sorted graph IDs; columns are core metrics followed by
        derived metrics.
        """
        cols = list(self.metrics.keys()) + list(self.derived_metrics.keys())
        return pd.DataFrame(np.nan, index=sorted(gt_graphs.keys()), columns=cols)

    def report_summary(self, results):
        # TODO: summarize "results" (e.g. print aggregate statistics).
        pass

    # --- Writers ---
    def save_fragments(self, gt_graphs=None, fragment_graphs=None):
        # Accepts the graphs evaluate() passes; falls back to the state
        # stashed by run(). (Original took no arguments and raised a
        # TypeError when called from evaluate().)
        # TODO: not implemented yet.
        pass

    def save_merge_results(self, gt_graphs=None, fragment_graphs=None,
                           output_dir=None):
        """Write merged fragments as zipped SWCs plus a CSV of merge sites.

        All parameters default to the state stashed by run() / __init__ so
        this can be called with no arguments, as evaluate() does.
        """
        gt_graphs = self.gt_graphs if gt_graphs is None else gt_graphs
        if fragment_graphs is None:
            fragment_graphs = self.fragment_graphs
        output_dir = self.output_dir if output_dir is None else output_dir

        # Initialize a writer; remove any stale archive first since the
        # zip is opened in append mode.
        zip_path = os.path.join(output_dir, "merged_fragments.zip")
        util.rm_file(zip_path)
        with ZipFile(zip_path, "a") as zip_writer:
            # Save SWC files
            self.save_merge_sites(zip_writer)
            self.save_skeletons_with_merge(
                gt_graphs, fragment_graphs, zip_writer
            )

        # Save CSV file
        path = os.path.join(output_dir, "merge_sites.csv")
        self.merge_sites.to_csv(path, index=True)

    def save_merge_sites(self, zip_writer):
        """Write each detected merge site as a zipped single-point entry."""
        merge_sites = self.merge_sites
        for i in range(len(merge_sites)):
            filename = merge_sites.index[i]
            xyz = merge_sites["World"].iloc[i]
            util.to_zipped_point(zip_writer, filename, xyz)

    def save_skeletons_with_merge(
        self, gt_graphs, fragment_graphs, zip_writer
    ):
        """Write ground-truth skeletons and fragments involved in merges."""
        # Save ground truth skeletons
        for key in self.merge_sites["GroundTruth_ID"].unique():
            gt_graphs[key].to_zipped_swc(zip_writer)

        # Save fragments
        for key in self.fragments_with_merge:
            fragment_graphs[key].to_zipped_swc(zip_writer)
0 commit comments