@@ -66,7 +66,7 @@ def __init__(
6666 localize_merge = False ,
6767 preexisting_merges = None ,
6868 save_merges = False ,
69- save_projections = False ,
69+ save_fragments = False ,
7070 valid_labels = None ,
7171 ):
7272 """
@@ -103,7 +103,7 @@ def __init__(
103103 save_merges: bool, optional
104104 Indication of whether to save fragments with a merge mistake. The
105105 default is None.
106- save_projections : bool, optional
106+ save_fragments : bool, optional
107107 Indication of whether to save fragments that project onto each
108108 ground truth skeleton. The default is False.
109109 valid_labels : set[int], optional
@@ -123,7 +123,7 @@ def __init__(
123123 self .output_dir = output_dir
124124 self .preexisting_merges = preexisting_merges
125125 self .save_merges = save_merges
126- self .save_projections = save_projections
126+ self .save_fragments = save_fragments
127127
128128 # Label handler
129129 self .label_handler = gutil .LabelHandler (
@@ -135,9 +135,8 @@ def __init__(
135135 self .load_groundtruth (gt_pointer )
136136 self .load_fragments (fragments_pointer )
137137
138- # Initialize writer
139- if self .save_merges :
140- self .init_zip_writer ()
138+ # Initialize writers
139+ self .init_zip_writers ()
141140
142141 # Initialize fragment projections directory
143142 if self .save_projections :
@@ -346,9 +345,9 @@ def get_node_labels(self, key, inverse_bool=False):
346345 else :
347346 return self .graphs [key ].get_labels ()
348347
349- def init_zip_writer (self ):
348+ def init_zip_writers (self ):
350349 """
351- Initializes "self.zip_writer " attribute by setting up a directory for
350+ Initializes "self.merge_writer " attribute by setting up a directory for
352351 output files and creating ZIP files for each graph in "self.graphs".
353352
354353 Parameters
@@ -360,16 +359,31 @@ def init_zip_writer(self):
360359 None
361360
362361 """
363- # Initialize output directory
364- merged_fragments_dir = os .path .join (self .output_dir , "merged_fragments" )
365- util .mkdir (merged_fragments_dir )
366-
367- # Save intial graphs
368- self .zip_writer = dict ()
369- for key in self .graphs .keys ():
370- zip_path = f"{ merged_fragments_dir } /{ key } .zip"
371- self .zip_writer [key ] = ZipFile (zip_path , "w" )
372- self .graphs [key ].to_zipped_swc (self .zip_writer [key ])
362+ # Merged fragments zip writer
363+ if self .save_merges :
364+ # Initialize directory
365+ merges_dir = os .path .join (self .output_dir , "merged_fragments" )
366+ util .mkdir (merged_fragments_dir )
367+
368+ # Initialize zip writer
369+ self .merge_writer = dict ()
370+ for key in self .graphs .keys ():
371+ zip_path = f"{ merged_fragments_dir } /{ key } .zip"
372+ self .merge_writer [key ] = ZipFile (zip_path , "w" )
373+ self .graphs [key ].to_zipped_swc (self .merge_writer [key ])
374+
375+ # Fragments zip writer
376+ if self .save_fragments :
377+ # Initialize direction
378+ fragments_dir = os .path .join (self .output_dir , "fragments" )
379+ util .mkdir (fragments_dir )
380+
381+ # Initialize zip writer
382+ self .fragment_writer = dict ()
383+ for key in self .graphs .keys ():
384+ zip_path = f"{ fragments_dir } /{ key } .zip"
385+ self .fragment_writer [key ] = ZipFile (zip_path , "w" )
386+ self .graphs [key ].to_zipped_swc (self .fragment_writer [key ])
373387
374388 # -- Main Routine --
375389 def run (self ):
@@ -539,12 +553,6 @@ def count_merges(self, key, kdtree):
539553 None
540554
541555 """
542- # Initialize zip writer
543- if self .save_projections :
544- zip_path = os .path .join (self .projections_dir , key + ".zip" )
545- zip_writer = ZipFile (zip_path , "w" )
546- #self.graphs[key].to_zipped_swc(zip_writer)
547-
548556 # Iterate over fragments that intersect with GT skeleton
549557 for label in self .get_node_labels (key ):
550558 nodes = self .graphs [key ].nodes_with_label (label )
@@ -580,6 +588,7 @@ def is_fragment_merge(self, key, label, kdtree):
580588 """
581589 # Search graphs
582590 for fragment_graph in self .find_graph_from_label (label ):
591+ # Search for merge
583592 max_dist = 0
584593 min_dist = np .inf
585594 for voxel in fragment_graph .voxels :
@@ -601,11 +610,15 @@ def is_fragment_merge(self, key, label, kdtree):
601610
602611 # Save merged fragment (if applicable)
603612 if self .save_merges :
604- fragment_graph .to_zipped_swc (self .zip_writer [key ])
613+ fragment_graph .to_zipped_swc (self .merge_writer [key ])
605614 if self .localize_merge :
606615 self .find_merge_site (key , fragment_graph , kdtree )
607616 break
608617
618+ # Save fragment (if applicable)
619+ if self .save_fragments and min_dist < 3 :
620+ fragment_graph .to_zipped_swc (self .fragment_writer [key ])
621+
609622 def adjust_metrics (self , key ):
610623 """
611624 Adjusts the metrics of the graph associated with the given key by
0 commit comments