@@ -66,7 +66,7 @@ def __init__(
6666 localize_merge = False ,
6767 preexisting_merges = None ,
6868 save_merges = False ,
69- save_projections = False ,
69+ save_fragments = False ,
7070 valid_labels = None ,
7171 ):
7272 """
@@ -103,7 +103,7 @@ def __init__(
103103 save_merges: bool, optional
104104 Indication of whether to save fragments with a merge mistake. The
105105 default is None.
106- save_projections : bool, optional
106+ save_fragments : bool, optional
107107 Indication of whether to save fragments that project onto each
108108 ground truth skeleton. The default is False.
109109 valid_labels : set[int], optional
@@ -123,7 +123,7 @@ def __init__(
123123 self .output_dir = output_dir
124124 self .preexisting_merges = preexisting_merges
125125 self .save_merges = save_merges
126- self .save_projections = save_projections
126+ self .save_fragments = save_fragments
127127
128128 # Label handler
129129 self .label_handler = gutil .LabelHandler (
@@ -135,14 +135,8 @@ def __init__(
135135 self .load_groundtruth (gt_pointer )
136136 self .load_fragments (fragments_pointer )
137137
138- # Initialize writer
139- if self .save_merges :
140- self .init_zip_writer ()
141-
142- # Initialize fragment projections directory
143- if self .save_projections :
144- self .projections_dir = os .path .join (output_dir , "projections" )
145- util .mkdir (self .projections_dir )
138+ # Initialize writers
139+ self .init_zip_writers ()
146140
147141 # --- Load Data ---
148142 def load_groundtruth (self , swc_pointer ):
@@ -346,9 +340,9 @@ def get_node_labels(self, key, inverse_bool=False):
346340 else :
347341 return self .graphs [key ].get_labels ()
348342
349- def init_zip_writer (self ):
343+ def init_zip_writers (self ):
350344 """
351- Initializes "self.zip_writer " attribute by setting up a directory for
345+ Initializes "self.merge_writer " attribute by setting up a directory for
352346 output files and creating ZIP files for each graph in "self.graphs".
353347
354348 Parameters
@@ -360,16 +354,31 @@ def init_zip_writer(self):
360354 None
361355
362356 """
363- # Initialize output directory
364- merged_fragments_dir = os .path .join (self .output_dir , "merged_fragments" )
365- util .mkdir (merged_fragments_dir )
366-
367- # Save intial graphs
368- self .zip_writer = dict ()
369- for key in self .graphs .keys ():
370- zip_path = f"{ merged_fragments_dir } /{ key } .zip"
371- self .zip_writer [key ] = ZipFile (zip_path , "w" )
372- self .graphs [key ].to_zipped_swc (self .zip_writer [key ])
357+ # Merged fragments zip writer
358+ if self .save_merges :
359+ # Initialize directory
360+ merges_dir = os .path .join (self .output_dir , "merged_fragments" )
361+ util .mkdir (merged_fragments_dir )
362+
363+ # Initialize zip writer
364+ self .merge_writer = dict ()
365+ for key in self .graphs .keys ():
366+ zip_path = f"{ merged_fragments_dir } /{ key } .zip"
367+ self .merge_writer [key ] = ZipFile (zip_path , "w" )
368+ self .graphs [key ].to_zipped_swc (self .merge_writer [key ])
369+
370+ # Fragments zip writer
371+ if self .save_fragments :
372+ # Initialize direction
373+ fragments_dir = os .path .join (self .output_dir , "fragments" )
374+ util .mkdir (fragments_dir )
375+
376+ # Initialize zip writer
377+ self .fragment_writer = dict ()
378+ for key in self .graphs .keys ():
379+ zip_path = f"{ fragments_dir } /{ key } .zip"
380+ self .fragment_writer [key ] = ZipFile (zip_path , "w" )
381+ self .graphs [key ].to_zipped_swc (self .fragment_writer [key ])
373382
374383 # -- Main Routine --
375384 def run (self ):
@@ -539,22 +548,13 @@ def count_merges(self, key, kdtree):
539548 None
540549
541550 """
542- # Initialize zip writer
543- if self .save_projections :
544- zip_path = os .path .join (self .projections_dir , key + ".zip" )
545- zip_writer = ZipFile (zip_path , "w" )
546- #self.graphs[key].to_zipped_swc(zip_writer)
547-
548551 # Iterate over fragments that intersect with GT skeleton
549552 for label in self .get_node_labels (key ):
550553 nodes = self .graphs [key ].nodes_with_label (label )
551554 if len (nodes ) > 40 :
552555 for label in self .label_handler .get_class (label ):
553556 if label in self .fragment_ids :
554557 self .is_fragment_merge (key , label , kdtree )
555- if self .save_projections :
556- fragment_graph = self .find_graph_from_label (label )[0 ]
557- fragment_graph .to_zipped_swc (zip_writer )
558558
559559 def is_fragment_merge (self , key , label , kdtree ):
560560 """
@@ -580,6 +580,7 @@ def is_fragment_merge(self, key, label, kdtree):
580580 """
581581 # Search graphs
582582 for fragment_graph in self .find_graph_from_label (label ):
583+ # Search for merge
583584 max_dist = 0
584585 min_dist = np .inf
585586 for voxel in fragment_graph .voxels :
@@ -601,11 +602,15 @@ def is_fragment_merge(self, key, label, kdtree):
601602
602603 # Save merged fragment (if applicable)
603604 if self .save_merges :
604- fragment_graph .to_zipped_swc (self .zip_writer [key ])
605+ fragment_graph .to_zipped_swc (self .merge_writer [key ])
605606 if self .localize_merge :
606607 self .find_merge_site (key , fragment_graph , kdtree )
607608 break
608609
610+ # Save fragment (if applicable)
611+ if self .save_fragments and min_dist < 3 :
612+ fragment_graph .to_zipped_swc (self .fragment_writer [key ])
613+
609614 def adjust_metrics (self , key ):
610615 """
611616 Adjusts the metrics of the graph associated with the given key by
0 commit comments