@@ -66,7 +66,6 @@ def __init__(
6666 anisotropy = (1.0 , 1.0 , 1.0 ),
6767 connections_path = None ,
6868 fragments_pointer = None ,
69- localize_merges = False ,
7069 preexisting_merges = None ,
7170 save_merges = False ,
7271 save_fragments = False ,
@@ -97,9 +96,6 @@ def __init__(
9796 "swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
9897 applied to these SWC files and (2) these SWC files are required
9998 for counting merges. The default is None.
100- localize_merges : bool, optional
101- Indication of whether to search for the approximate location of a
102- merge. The default is False.
10399 preexisting_merges : List[int], optional
104100 List of segment IDs that are known to contain a merge mistake. The
105101 default is None.
@@ -122,7 +118,7 @@ def __init__(
122118 # Instance attributes
123119 self .anisotropy = anisotropy
124120 self .connections_path = connections_path
125- self .localize_merges = localize_merges
121+ self .merge_sites = list ()
126122 self .output_dir = output_dir
127123 self .preexisting_merges = preexisting_merges
128124 self .save_merges = save_merges
@@ -373,10 +369,9 @@ def init_writers(self):
373369 self .graphs [key ].to_zipped_swc (self .fragment_writer [key ])
374370
375371 # Merged fragments writer
376- if self .save_merges or self . localize_merges :
372+ if self .save_merges :
377373 zip_path = os .path .join (self .output_dir , "merged_fragments.zip" )
378374 self .merge_writer = ZipFile (zip_path , "a" )
379- self .merge_sites = list ()
380375
381376 # -- Main Routine --
382377 def run (self ):
@@ -410,12 +405,10 @@ def run(self):
410405 path = f"{ self .output_dir } /{ prefix } results.xls"
411406 util .save_results (path , full_results )
412407
413- # Save merge sites (if applicable)
414- if self .localize_merges :
415- df = pd .DataFrame (self .merge_sites )
416- df .to_csv (
417- os .path .join (self .output_dir , "merge_sites.csv" ), index = False
418- )
408+ # Save merge sites
409+ df = pd .DataFrame (self .merge_sites )
410+ path = os .path .join (self .output_dir , "merge_sites.csv" )
411+ df .to_csv (path , index = False )
419412
420413 # Report results overview
421414 path = os .path .join (self .output_dir , f"{ prefix } results-overview.txt" )
@@ -554,7 +547,7 @@ def count_merges(self, key, kdtree):
554547
555548 """
556549 # Iterate over fragments that intersect with GT skeleton
557- for label in self .get_node_labels (key ):
550+ for label in tqdm ( self .get_node_labels (key ), desc = "Merge Search" ):
558551 nodes = self .graphs [key ].nodes_with_label (label )
559552 if len (nodes ) > 40 :
560553 for label in self .label_handler .get_class (label ):
@@ -583,46 +576,64 @@ def is_fragment_merge(self, key, label, kdtree):
583576 None
584577
585578 """
586- # Search graphs
587579 for fragment_graph in self .find_graph_from_label (label ):
588- # Search for merge
589- max_dist = 0
590- min_dist = np .inf
591- for voxel in fragment_graph .voxels :
580+ if fragment_graph .run_length < 10 ** 6 :
581+ # Search for merge
582+ visited = set ()
583+ for leaf in gutil .get_leafs (fragment_graph ):
584+ # Check if leaf is far from ground truth
585+ voxel = fragment_graph .voxels [leaf ]
586+ gt_voxel = util .kdtree_query (kdtree , voxel )
587+ if self .physical_dist (gt_voxel , voxel ) > 50 :
588+ has_merge , visited = self .find_merge_site (
589+ key , kdtree , fragment_graph , leaf , visited
590+ )
591+ if has_merge :
592+ break
593+
594+ # Save fragment (if applicable)
595+ if self .save_fragments :
596+ for node in fragment_graph .nodes :
597+ voxel = fragment_graph .voxels [node ]
598+ gt_voxel = util .kdtree_query (kdtree , voxel )
599+ if self .physical_dist (gt_voxel , voxel ) < 3 :
600+ write_graph (fragment_graph , self .fragment_writer [key ])
601+ break
602+
603+ def find_merge_site (self , key , kdtree , fragment_graph , source , visited ):
604+ for _ , node in nx .dfs_edges (fragment_graph , source = source ):
605+ if node not in visited :
592606 # Find closest point in ground truth
607+ visited .add (node )
608+ voxel = fragment_graph .voxels [node ]
593609 gt_voxel = util .kdtree_query (kdtree , voxel )
594-
595- # Compute projection distance
596- dist = self .physical_dist (gt_voxel , voxel )
597- min_dist = min (dist , min_dist )
598- max_dist = max (dist , max_dist )
599-
600- # Check if distances imply merge mistake
601- if max_dist > 100 and min_dist < 3 :
610+ if self .physical_dist (gt_voxel , voxel ) < 2 :
602611 # Log merge mistake
603- equiv_label = self . label_handler . get ( label )
612+ segment_id = util . get_segment_id ( fragment_graph . filename )
604613 xyz = img_util .to_physical (voxel , self .anisotropy )
605614 self .merge_cnt [key ] += 1
606- self .merged_labels .add ((key , equiv_label , tuple (xyz )))
615+ self .merged_labels .add ((key , segment_id , xyz ))
616+ self .merge_sites .append (
617+ {
618+ "Segment_ID" : segment_id ,
619+ "Voxel" : voxel ,
620+ "XYZ" : xyz ,
621+ }
622+ )
607623
608624 # Save merged fragment (if applicable)
609625 if self .save_merges :
610- fragment_graph .to_zipped_swc (self .merge_writer )
611- if f"{ key } .swc" not in self .merge_writer .namelist ():
612- self .gt_graphs [key ].to_zipped_swc (
613- self .merge_writer
614- )
626+ # Save graphs
627+ write_graph (self .gt_graphs [key ], self .merge_writer )
628+ write_graph (fragment_graph , self .merge_writer )
615629
616- # Find approximate merge site
617- if self .localize_merges :
618- self .find_merge_site (key , fragment_graph , kdtree )
619-
620- break
621-
622- # Save fragment (if applicable)
623- if self .save_fragments and min_dist < 3 :
624- if fragment_graph .filename not in self .merge_writer .namelist ():
625- fragment_graph .to_zipped_swc (self .fragment_writer [key ])
630+ # Save merge site
631+ merge_cnt = np .sum (list (self .merge_cnt .values ()))
632+ swc_util .to_zipped_point (
633+ self .merge_writer , f"merge-{ merge_cnt } .swc" , xyz
634+ )
635+ return True , visited
636+ return False , visited
626637
627638 def adjust_metrics (self , key ):
628639 """
@@ -718,45 +729,6 @@ def process_merge(self, key, label, xyz, update_merged_labels=True):
718729 if update_merged_labels :
719730 self .merged_labels .add ((key , label , - 1 ))
720731
721- def find_merge_site (self , key , fragment_graph , kdtree ):
722- visited = set ()
723- hit = False
724- for i , voxel in enumerate (fragment_graph .voxels ):
725- # Find closest point in ground truth
726- visited .add (i )
727- gt_voxel = util .kdtree_query (kdtree , voxel )
728-
729- # Compute projection distance
730- if self .physical_dist (gt_voxel , voxel ) > 100 :
731- for _ , j in nx .dfs_edges (fragment_graph , source = i ):
732- visited .add (j )
733- voxel_j = fragment_graph .voxels [j ]
734- gt_voxel = util .kdtree_query (kdtree , voxel_j )
735- if self .physical_dist (gt_voxel , voxel_j ) < 2 :
736- # Save merge swc
737- hit = True
738- merge_cnt = np .sum (list (self .merge_cnt .values ()))
739- filename = f"merge-{ merge_cnt } .swc"
740- xyz = img_util .to_physical (voxel_j , self .anisotropy )
741- swc_util .to_zipped_point (
742- self .merge_writer , filename , xyz
743- )
744-
745- # Save merge in list
746- segment_id = util .get_segment_id (fragment_graph .filename )
747- self .merge_sites .append (
748- {
749- "Segment_ID" : segment_id ,
750- "Voxel" : voxel_j ,
751- "XYZ" : xyz ,
752- }
753- )
754- break
755-
756- # Check whether to continue
757- if hit :
758- break
759-
760732 def quantify_merges (self ):
761733 """
762734 Computes the percentage of merged edges for each graph.
@@ -775,30 +747,6 @@ def quantify_merges(self):
775747 n_edges = self .graphs [key ].graph ["n_edges" ]
776748 self .merged_percent [key ] = self .merged_edges_cnt [key ] / n_edges
777749
778- def save_merged_labels (self ):
779- """
780- Saves merged labels and their corresponding coordinates to a text
781- file.
782-
783- Parameters
784- ----------
785- None
786-
787- Returns
788- -------
789- None
790-
791- """
792- # Save detected merges
793- prefix = "corrected_" if self .connections_path else ""
794- filename = f"merged_{ prefix } segment_ids.txt"
795- with open (os .path .join (self .output_dir , filename ), "w" ) as f :
796- f .write (f" Label - Physical Coordinate\n " )
797- for _ , label , xyz in self .merged_labels :
798- if self .label_handler .use_mapping ():
799- label = self .get_merged_label (label )
800- f .write (f" { label } - { xyz } \n " )
801-
802750 def get_merged_label (self , label ):
803751 """
804752 Retrieves the label present in the corrected fragments that
@@ -1127,3 +1075,8 @@ def generate_result(keys, stats):
11271075
11281076 """
11291077 return [stats [key ] for key in keys ]
1078+
1079+
1080+ def write_graph (graph , writer ):
1081+ if graph .filename not in writer .namelist ():
1082+ graph .to_zipped_swc (writer )
0 commit comments