@@ -129,6 +129,8 @@ def __init__(
129129 row_names = list (self .graphs .keys ())
130130 row_names .sort ()
131131 col_names = [
132+ "SWC Name" ,
133+ "SWC Run Length" ,
132134 "# Splits" ,
133135 "# Merges" ,
134136 "Split Rate" ,
@@ -139,9 +141,9 @@ def __init__(
139141 "Edge Accuracy" ,
140142 "ERL" ,
141143 "Normalized ERL" ,
142- "GT Run Length" ,
143144 ]
144145 self .metrics = pd .DataFrame (index = row_names , columns = col_names )
146+ self .metrics ["SWC Name" ] = self .metrics .index
145147
146148 # --- Load Data ---
147149 def load_groundtruth (self , swc_pointer , label_mask ):
@@ -293,13 +295,13 @@ def run(self):
293295 path = f"{ self .output_dir } /{ prefix } results.csv"
294296 if self .fragment_graphs is None :
295297 self .metrics = self .metrics .drop ("# Merges" , axis = 1 )
296- self .metrics .to_csv (path , index = True )
298+ self .metrics .to_csv (path , index = False )
297299
298300 # Report results
299301 path = os .path .join (self .output_dir , f"{ prefix } results-overview.txt" )
300302 util .update_txt (path , "Average Results..." )
301303 for column_name in self .metrics .columns :
302- if column_name != "GT Run Length" :
304+ if column_name != "SWC Run Length" and column_name != "SWC Name " :
303305 avg = self .compute_weighted_avg (column_name )
304306 util .update_txt (path , f" { column_name } : { avg :.4f} " )
305307
@@ -345,7 +347,7 @@ def detect_splits(self):
345347 self .metrics .at [key , "Split Rate" ] = rl / max (n_splits , 1 )
346348 self .metrics .loc [key , "% Split Edges" ] = round (p_split , 2 )
347349 self .metrics .at [key , "% Omit Edges" ] = round (p_omit , 2 )
348- self .metrics .loc [key , "GT Run Length" ] = round (gt_rl , 2 )
350+ self .metrics .loc [key , "SWC Run Length" ] = round (gt_rl , 2 )
349351
350352 if self .verbose :
351353 pbar .update (1 )
@@ -532,9 +534,10 @@ def process_merge_sites(self):
532534 idx_mask = self .merge_sites ["GroundTruth_ID" ] == key
533535 n_merges = int (idx_mask .sum ())
534536 rl = np .sum (self .graphs [key ].run_lengths ())
537+ merge_rate = rl / n_merges if n_merges > 0 else np .nan
535538
536539 self .metrics .loc [key , "# Merges" ] = n_merges
537- self .metrics .loc [key , "Merge Rate" ] = rl / max ( n_merges , 1 )
540+ self .metrics .loc [key , "Merge Rate" ] = merge_rate
538541
539542 # Save results
540543 path = os .path .join (self .output_dir , "merge_sites.csv" )
@@ -626,8 +629,20 @@ def compute_erl(self):
626629 self .metrics .loc [key , "Normalized ERL" ] = n_erl
627630
628631 def compute_weighted_avg (self , column_name ):
629- wgt = self .metrics ["GT Run Length" ]
630- return (self .metrics [column_name ] * wgt ).sum () / wgt .sum ()
632+ # Extract values
633+ values = self .metrics [column_name ]
634+ weights = self .metrics ["SWC Run Length" ]
635+
636+ # Ignore NaNs
637+ mask = values .notna () & weights .notna ()
638+ values = values [mask ]
639+ weights = weights [mask ]
640+
641+ # Compute weighted mean
642+ if weights .sum () == 0 :
643+ return float ("nan" )
644+ else :
645+ return (values * weights ).sum () / weights .sum ()
631646
632647 # -- Helpers --
633648 def branch_search (self , graph , kdtree , root , radius = 100 ):
0 commit comments