@@ -63,8 +63,10 @@ def __init__(
6363 anisotropy = (1.0 , 1.0 , 1.0 ),
6464 connections_path = None ,
6565 fragments_pointer = None ,
66+ localize_merge = False ,
6667 preexisting_merges = None ,
6768 save_merges = False ,
69+ save_projections = False ,
6870 valid_labels = None ,
6971 ):
7072 """
@@ -92,12 +94,18 @@ def __init__(
9294 "swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
9395 applied to these SWC files and (2) these SWC files are required
9496 for counting merges. The default is None.
97+ localize_merge : bool, optional
98+ Indication of whether to search for the approximate location of a
99+ merge. The default is False.
95100 preexisting_merges : List[int], optional
96101 List of segment IDs that are known to contain a merge mistake. The
97102 default is None.
98103 save_merges: bool, optional
99104 Indication of whether to save fragments with a merge mistake. The
100105 default is None.
106+ save_projections : bool, optional
107+ Indication of whether to save fragments that project onto each
108+ ground truth skeleton. The default is False.
101109 valid_labels : set[int], optional
102110 Segment IDs that can be assigned to nodes. This argument accounts
103111 for segments that were been removed due to some type of filtering.
@@ -111,9 +119,11 @@ def __init__(
111119 # Instance attributes
112120 self .anisotropy = anisotropy
113121 self .connections_path = connections_path
122+ self .localize_merge = localize_merge
114123 self .output_dir = output_dir
115124 self .preexisting_merges = preexisting_merges
116125 self .save_merges = save_merges
126+ self .save_projections = save_projections
117127
118128 # Label handler
119129 self .label_handler = gutil .LabelHandler (
@@ -129,6 +139,11 @@ def __init__(
129139 if self .save_merges :
130140 self .init_zip_writer ()
131141
142+ # Initialize fragment projections directory
143+ if self .save_projections :
144+ self .projections_dir = os .path .join (output_dir , "projections" )
145+ util .mkdir (self .projections_dir )
146+
132147 # --- Load Data ---
133148 def load_groundtruth (self , swc_pointer ):
134149 """
@@ -346,13 +361,13 @@ def init_zip_writer(self):
346361
347362 """
348363 # Initialize output directory
349- projections_dir = os .path .join (self .output_dir , "projections " )
350- util .mkdir (projections_dir )
364+ merged_fragments_dir = os .path .join (self .output_dir , "merged_fragments " )
365+ util .mkdir (merged_fragments_dir )
351366
352367 # Save intial graphs
353368 self .zip_writer = dict ()
354369 for key in self .graphs .keys ():
355- zip_path = f"{ projections_dir } /{ key } .zip"
370+ zip_path = f"{ merged_fragments_dir } /{ key } .zip"
356371 self .zip_writer [key ] = ZipFile (zip_path , "w" )
357372 self .graphs [key ].to_zipped_swc (self .zip_writer [key ])
358373
@@ -524,12 +539,21 @@ def count_merges(self, key, kdtree):
524539 None
525540
526541 """
542+ # Initialize zip writer
543+ if self .save_projections :
544+ zip_path = os .path .join (self .projections_dir , key + ".zip" )
545+ zip_writer = ZipFile (zip_path , "w" )
546+
547+ # Iterate over fragments that intersect with GT skeleton
527548 for label in self .get_node_labels (key ):
528549 nodes = self .graphs [key ].nodes_with_label (label )
529550 if len (nodes ) > 40 :
530551 for label in self .label_handler .get_class (label ):
531552 if label in self .fragment_ids :
532553 self .is_fragment_merge (key , label , kdtree )
554+ if self .save_projections :
555+ fragment_graph = self .find_graph_from_label (label )[0 ]
556+ fragment_graph .to_zipped_swc (zip_writer )
533557
534558 def is_fragment_merge (self , key , label , kdtree ):
535559 """
@@ -553,32 +577,33 @@ def is_fragment_merge(self, key, label, kdtree):
553577 None
554578
555579 """
556- fragment_graph = self .find_graph_from_label (label )
557-
558- max_dist = 0
559- min_dist = np .inf
560-
561- for voxel in fragment_graph .voxels :
562- # Find closest point in ground truth
563- gt_voxel = util .kdtree_query (kdtree , voxel )
564-
565- # Compute projection distance
566- dist = self .physical_dist (gt_voxel , voxel )
567- min_dist = min (dist , min_dist )
568- max_dist = max (dist , max_dist )
569-
570- # Check if distances imply merge mistake
571- if max_dist > 100 and min_dist < 3 :
572- # Log merge mistake
573- equiv_label = self .label_handler .get (label )
574- xyz = img_util .to_physical (voxel , self .anisotropy )
575- self .merge_cnt [key ] += 1
576- self .merged_labels .add ((key , equiv_label , tuple (xyz )))
577-
578- # Save merged fragment (if applicable)
579- if self .save_merges :
580- fragment_graph .to_zipped_swc (self .zip_writer [key ])
581- break
580+ # Search graphs
581+ for fragment_graph in self .find_graph_from_label (label ):
582+ max_dist = 0
583+ min_dist = np .inf
584+ for voxel in fragment_graph .voxels :
585+ # Find closest point in ground truth
586+ gt_voxel = util .kdtree_query (kdtree , voxel )
587+
588+ # Compute projection distance
589+ dist = self .physical_dist (gt_voxel , voxel )
590+ min_dist = min (dist , min_dist )
591+ max_dist = max (dist , max_dist )
592+
593+ # Check if distances imply merge mistake
594+ if max_dist > 100 and min_dist < 3 :
595+ # Log merge mistake
596+ equiv_label = self .label_handler .get (label )
597+ xyz = img_util .to_physical (voxel , self .anisotropy )
598+ self .merge_cnt [key ] += 1
599+ self .merged_labels .add ((key , equiv_label , tuple (xyz )))
600+
601+ # Save merged fragment (if applicable)
602+ if self .save_merges :
603+ fragment_graph .to_zipped_swc (self .zip_writer [key ])
604+ if self .localize_merge :
605+ self .find_merge_site (key , fragment_graph , kdtree )
606+ break
582607
583608 def adjust_metrics (self , key ):
584609 """
@@ -674,6 +699,29 @@ def process_merge(self, key, label, xyz, update_merged_labels=True):
674699 if update_merged_labels :
675700 self .merged_labels .add ((key , label , - 1 ))
676701
702+ def find_merge_site (self , key , fragment_graph , kdtree ):
703+ visited = set ()
704+ hit = False
705+ for i , voxel in enumerate (fragment_graph .voxels ):
706+ # Find closest point in ground truth
707+ visited .add (i )
708+ gt_voxel = util .kdtree_query (kdtree , voxel )
709+
710+ # Compute projection distance
711+ if self .physical_dist (gt_voxel , voxel ) > 100 :
712+ for _ , j in nx .dfs_edges (fragment_graph , source = i ):
713+ visited .add (j )
714+ voxel_j = fragment_graph .voxels [j ]
715+ gt_voxel = util .kdtree_query (kdtree , voxel_j )
716+ if self .physical_dist (gt_voxel , voxel_j ) < 2 :
717+ hit = True
718+ print ("Approximate Site:" , img_util .to_physical (voxel_j , self .anisotropy ))
719+ break
720+
721+ # Check whether to continue
722+ if hit :
723+ break
724+
677725 def quantify_merges (self ):
678726 """
679727 Computes the percentage of merged edges for each graph.
@@ -966,10 +1014,11 @@ def list_metrics(self):
9661014
9671015 # -- Helpers --
9681016 def find_graph_from_label (self , label ):
1017+ graphs = list ()
9691018 for key in self .fragment_graphs :
9701019 if label == util .get_segment_id (key ):
971- return self .fragment_graphs [key ]
972- return None
1020+ graphs . append ( self .fragment_graphs [key ])
1021+ return graphs
9731022
9741023 def physical_dist (self , voxel_1 , voxel_2 ):
9751024 """
0 commit comments