@@ -136,7 +136,7 @@ def __init__(
136136 self .load_fragments (fragments_pointer )
137137
138138 # Initialize metrics
139- util .mkdir (output_dir , delete = True )
139+ util .mkdir (output_dir )
140140 self .init_writers ()
141141 self .merge_sites = list ()
142142
@@ -174,6 +174,7 @@ def load_groundtruth(self, swc_pointer):
174174 print ("\n (1) Load Ground Truth" )
175175 graph_builder = gutil .GraphBuilder (
176176 anisotropy = self .anisotropy ,
177+ is_groundtruth = True ,
177178 label_mask = self .label_mask ,
178179 use_anisotropy = False ,
179180 )
@@ -203,6 +204,7 @@ def load_fragments(self, swc_pointer):
203204 if swc_pointer :
204205 graph_builder = gutil .GraphBuilder (
205206 anisotropy = self .anisotropy ,
207+ is_groundtruth = False ,
206208 selected_ids = self .get_all_node_labels (),
207209 use_anisotropy = self .use_anisotropy ,
208210 )
@@ -464,12 +466,13 @@ def detect_splits(self):
464466 n_missing = n_before - n_after
465467 p_omit = 100 * (n_missing + n_split_edges ) / n_before
466468 p_split = 100 * n_split_edges / n_before
469+ gt_rl = graph .run_length
467470
468471 self .graphs [key ] = graph
469- self .metrics .at [key , "% Omit" ] = p_omit
472+ self .metrics .at [key , "% Omit" ] = round ( p_omit , 2 )
470473 self .metrics .at [key , "# Splits" ] = gutil .count_splits (graph )
471- self .metrics .loc [key , "% Split" ] = p_split
472- self .metrics .loc [key , "GT Run Length" ] = graph . run_length
474+ self .metrics .loc [key , "% Split" ] = round ( p_split , 2 )
475+ self .metrics .loc [key , "GT Run Length" ] = round ( gt_rl , 2 )
473476 pbar .update (1 )
474477
475478 # -- Merge Detection --
@@ -571,8 +574,8 @@ def is_fragment_merge(self, key, label, kdtree):
571574 for leaf in gutil .get_leafs (fragment_graph ):
572575 voxel = fragment_graph .voxels [leaf ]
573576 gt_voxel = util .kdtree_query (kdtree , voxel )
574- if self .physical_dist (gt_voxel , voxel ) > 50 :
575- visited = self .find_merge_site (
577+ if self .physical_dist (gt_voxel , voxel ) > 60 :
578+ self .find_merge_site (
576579 key , kdtree , fragment_graph , leaf , visited
577580 )
578581
@@ -599,28 +602,60 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
599602 voxel = fragment_graph .voxels [node ]
600603 gt_voxel = util .kdtree_query (kdtree , voxel )
601604 if self .physical_dist (gt_voxel , voxel ) < 3 :
602- # Log merge mistake
603- segment_id = util .get_segment_id (fragment_graph .filename )
604- xyz = img_util .to_physical (voxel , self .anisotropy )
605- self .merged_labels .add ((key , segment_id , xyz ))
606- self .merge_sites .append (
607- {
608- "Segment_ID" : segment_id ,
609- "GroundTruth_ID" : key ,
610- "Voxel" : tuple ([int (t ) for t in voxel ]),
611- "World" : tuple ([float (t ) for t in xyz ]),
612- }
613- )
605+ # Local search
606+ node = self .branch_search (fragment_graph , kdtree , node )
607+ voxel = fragment_graph .voxels [node ]
614608
615- # Save merged fragment (if applicable)
616- if self .save_merges :
617- gutil .write_graph (fragment_graph , self .merge_writer )
618- gutil .write_graph (
619- self .gt_graphs [key ], self .merge_writer
620- )
621- return visited
622- return visited
609+ # Log merge mistake
610+ if self .is_valid_merge (fragment_graph , kdtree , node ):
611+ filename = fragment_graph .filename
612+ segment_id = util .get_segment_id (filename )
613+ xyz = img_util .to_physical (voxel , self .anisotropy )
614+ self .merged_labels .add ((key , segment_id , xyz ))
615+ self .merge_sites .append (
616+ {
617+ "Segment_ID" : segment_id ,
618+ "GroundTruth_ID" : key ,
619+ "Voxel" : tuple ([int (t ) for t in voxel ]),
620+ "World" : tuple ([float (t ) for t in xyz ]),
621+ }
622+ )
623623
624+ # Save merged fragment (if applicable)
625+ if self .save_merges :
626+ gutil .write_graph (
627+ fragment_graph , self .merge_writer
628+ )
629+ gutil .write_graph (
630+ self .gt_graphs [key ], self .merge_writer
631+ )
632+ return
633+
634+ def is_valid_merge (self , graph , kdtree , root ):
635+ n_hits = 0
636+ queue = list ([(root , 0 )])
637+ visited = set ({root })
638+ while queue :
639+ # Visit node
640+ i , d_i = queue .pop ()
641+ voxel_i = graph .voxels [i ]
642+ gt_voxel = util .kdtree_query (kdtree , voxel_i )
643+ if self .physical_dist (gt_voxel , voxel_i ) < 5 :
644+ n_hits += 1
645+
646+ # Check whether to break
647+ if n_hits > 16 :
648+ break
649+
650+ # Update queue
651+ for j in graph .neighbors (i ):
652+ voxel_j = graph .voxels [j ]
653+ d_j = d_i + self .physical_dist (voxel_i , voxel_j )
654+ if j not in visited and d_j < 30 :
655+ queue .append ((j , d_j ))
656+ visited .add (j )
657+ return True if n_hits > 16 else False
658+
624659 def process_merge_sites (self ):
625660 if self .merge_sites :
626661 # Remove duplicates
@@ -632,10 +667,13 @@ def process_merge_sites(self):
632667
633668 # Save merge sites
634669 if self .save_merges :
670+ row_names = list ()
635671 for i in range (len (self .merge_sites )):
636672 filename = f"merge-{ i + 1 } .swc"
637673 xyz = self .merge_sites .iloc [i ]["World" ]
638674 swc_util .to_zipped_point (self .merge_writer , filename , xyz )
675+ row_names .append (filename )
676+ self .merge_sites .index = row_names
639677 self .merge_writer .close ()
640678
641679 # Update counter
@@ -645,7 +683,7 @@ def process_merge_sites(self):
645683
646684 # Save results
647685 path = os .path .join (self .output_dir , "merge_sites.csv" )
648- self .merge_sites .to_csv (path , index = False )
686+ self .merge_sites .to_csv (path , index = True )
649687
650688
651689 def adjust_metrics (self , key ):
@@ -757,7 +795,7 @@ def quantify_merges(self):
757795 """
758796 for key in self .graphs :
759797 p = self .n_merged_edges [key ] / self .graphs [key ].graph ["n_edges" ]
760- self .metrics .loc [key , "% Merged" ] = 100 * p
798+ self .metrics .loc [key , "% Merged" ] = round ( 100 * p , 2 )
761799
762800 # -- Compute Metrics --
763801 def compute_edge_accuracy (self ):
@@ -776,7 +814,8 @@ def compute_edge_accuracy(self):
776814 for key in self .graphs :
777815 p_omit = self .metrics .loc [key , "% Omit" ]
778816 p_merged = self .metrics .loc [key , "% Merged" ]
779- self .metrics .loc [key , "Edge Accuracy" ] = 100 - p_omit - p_merged
817+ edge_accuracy = round (100 - p_omit - p_merged , 2 )
818+ self .metrics .loc [key , "Edge Accuracy" ] = edge_accuracy
780819
781820 def compute_erl (self ):
782821 """
@@ -799,14 +838,57 @@ def compute_erl(self):
799838 wgt = run_lengths / max (np .sum (run_lengths ), 1 )
800839
801840 erl = np .sum (wgt * run_lengths )
802- self .metrics .loc [key , "ERL" ] = erl
803- self .metrics .loc [key , "Normalized ERL" ] = erl / max (run_length , 1 )
841+ n_erl = round (erl / max (run_length , 1 ), 4 )
842+ self .metrics .loc [key , "ERL" ] = round (erl , 2 )
843+ self .metrics .loc [key , "Normalized ERL" ] = n_erl
804844
805845 def compute_weighted_avg (self , column_name ):
806846 wgt = self .metrics ["GT Run Length" ]
807847 return (self .metrics [column_name ] * wgt ).sum () / wgt .sum ()
808848
809849 # -- Helpers --
850+ def branch_search (self , graph , kdtree , root , radius = 70 ):
851+ """
852+ Searches for a branching node within distance "radius" from the given
853+ root node.
854+
855+ Parameters
856+ ----------
857+ graph : networkx.Graph
858+ Graph to be searched.
859+ kdtree : ...
860+ KDTree containing voxel coordinates from a ground truth tracing.
861+ root : int
862+ Root of search.
863+ radius : float, optional
864+ Distance to search from root. The default is 70.
865+
866+ Returns
867+ -------
868+ int
869+ Root node or closest branching node within distance "radius".
870+
871+ """
872+ queue = list ([(root , 0 )])
873+ visited = set ({root })
874+ while queue :
875+ # Visit node
876+ i , d_i = queue .pop ()
877+ voxel_i = graph .voxels [i ]
878+ if graph .degree [i ] > 2 :
879+ gt_voxel = util .kdtree_query (kdtree , voxel_i )
880+ if self .physical_dist (gt_voxel , voxel_i ) < 16 :
881+ return i
882+
883+ # Update queue
884+ for j in graph .neighbors (i ):
885+ voxel_j = graph .voxels [j ]
886+ d_j = d_i + self .physical_dist (voxel_i , voxel_j )
887+ if j not in visited and d_j < radius :
888+ queue .append ((j , d_j ))
889+ visited .add (j )
890+ return root
891+
810892 def find_graph_from_label (self , label ):
811893 graphs = list ()
812894 for key in self .fragment_graphs :
0 commit comments