@@ -136,7 +136,7 @@ def __init__(
136136 self .load_fragments (fragments_pointer )
137137
138138 # Initialize metrics
139- util .mkdir (output_dir , delete = True )
139+ util .mkdir (output_dir )
140140 self .init_writers ()
141141 self .merge_sites = list ()
142142
@@ -602,6 +602,10 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
602602 voxel = fragment_graph .voxels [node ]
603603 gt_voxel = util .kdtree_query (kdtree , voxel )
604604 if self .physical_dist (gt_voxel , voxel ) < 3 :
605+ # Local search
606+ node = self .branch_search (fragment_graph , kdtree , node )
607+ voxel = fragment_graph .voxels [node ]
608+
605609 # Log merge mistake
606610 segment_id = util .get_segment_id (fragment_graph .filename )
607611 xyz = img_util .to_physical (voxel , self .anisotropy )
@@ -815,6 +819,48 @@ def compute_weighted_avg(self, column_name):
815819 return (self .metrics [column_name ] * wgt ).sum () / wgt .sum ()
816820
817821 # -- Helpers --
822+ def branch_search (self , graph , kdtree , root , radius = 70 ):
823+ """
824+ Searches for a branching node within distance "radius" from the given
825+ root node.
826+
827+ Parameters
828+ ----------
829+ graph : networkx.Graph
830+ Graph to be searched.
831+ kdtree : ...
832+ KDTree containing voxel coordinates from a ground truth tracing.
833+ root : int
834+ Root of search.
835+ radius : float, optional
836+ Distance to search from root. The default is 70.
837+
838+ Returns
839+ -------
840+ int
841+ Root node or closest branching node within distance "radius".
842+
843+ """
844+ queue = list ([(root , 0 )])
845+ visited = set ({root })
846+ while queue :
847+ # Visit node
848+ i , d_i = queue .pop ()
849+ voxel_i = graph .voxels [i ]
850+ if graph .degree [i ] > 2 :
851+ gt_voxel = util .kdtree_query (kdtree , voxel_i )
852+ if self .physical_dist (gt_voxel , voxel_i ) < 16 :
853+ return i
854+
855+ # Update queue
856+ for j in graph .neighbors (i ):
857+ voxel_j = graph .voxels [j ]
858+ d_j = d_i + self .physical_dist (voxel_i , voxel_j )
859+ if j not in visited and d_j < radius :
860+ queue .append ((j , d_j ))
861+ visited .add (j )
862+ return root
863+
818864 def find_graph_from_label (self , label ):
819865 graphs = list ()
820866 for key in self .fragment_graphs :
0 commit comments