77
88"""
99
10- from collections import deque
1110from concurrent .futures import (
1211 as_completed ,
1312 ProcessPoolExecutor ,
@@ -112,11 +111,13 @@ def __init__(
112111 self .output_dir = output_dir
113112 self .preexisting_merges = preexisting_merges
114113
114+ # Label handler
115+ self .label_handler = gutil .LabelHandler (
116+ connections_path = connections_path , valid_labels = valid_labels
117+ )
118+
115119 # Load data
116- assert isinstance (valid_labels , set ) if valid_labels else True
117120 self .label_mask = pred_labels
118- self .valid_labels = valid_labels
119- self .init_label_map (connections_path )
120121 self .load_groundtruth (gt_pointer )
121122 self .load_fragments (fragments_pointer )
122123
@@ -126,31 +127,6 @@ def __init__(
126127 self .init_zip_writer ()
127128
128129 # --- Load Data ---
129- def init_label_map (self , path ):
130- """
131- Initializes a dictionary that maps a label to its equivalent label in
132- the case where "connections_path" is provided.
133-
134- Parameters
135- ----------
136- path : str
137- Path to a txt file containing pairs of segment ids of segments
138- that were merged into a single segment.
139-
140- Returns
141- -------
142- None
143-
144- """
145- if path :
146- assert self .valid_labels is not None , "Must provide valid labels!"
147- self .label_map , self .inverse_label_map = util .init_label_map (
148- path , self .valid_labels
149- )
150- else :
151- self .label_map = None
152- self .inverse_label_map = None
153-
154130 def load_groundtruth (self , swc_pointer ):
155131 """
156132 Initializes "self.graphs" by iterating over "paths" which corresponds
@@ -265,50 +241,10 @@ def get_patch_labels(self, key, nodes):
265241 node_to_label = dict ()
266242 for i in nodes :
267243 voxel = self .to_local_voxels (key , i , bbox ["min" ])
268- label = self .adjust_label (label_patch [voxel ])
244+ label = self .label_handler . get (label_patch [voxel ])
269245 node_to_label [i ] = label
270246 return node_to_label
271247
272- def adjust_label (self , label ):
273- """
274- Gets label of voxel in "self.label_mask".
275-
276- Parameters
277- ----------
278- i : int
279- Node ID.
280- voxel : numpy.ndarray
281- Image coordinate of voxel to be read.
282-
283- Returns
284- -------
285- int
286- Label of voxel.
287-
288- """
289- if self .label_map :
290- label = self .get_equivalent_label (label )
291- elif self .valid_labels :
292- label = 0 if label not in self .valid_labels else label
293- return label
294-
295- def get_equivalent_label (self , label ):
296- """
297- Gets the equivalence class label corresponding to "label".
298-
299- Parameters
300- ----------
301- label : int
302- Label to be checked.
303-
304- Returns
305- -------
306- label
307- Equivalence class label.
308-
309- """
310- return self .label_map [label ] if label in self .label_map else 0
311-
312248 def get_all_node_labels (self ):
313249 """
314250 Gets the a set of all unique labels from all graphs in "self.graphs".
@@ -324,7 +260,7 @@ def get_all_node_labels(self):
324260
325261 """
326262 all_labels = set ()
327- inverse_bool = True if self .inverse_label_map else False
263+ inverse_bool = self .label_handler . use_mapping ()
328264 for key in self .graphs :
329265 labels = self .get_node_labels (key , inverse_bool = inverse_bool )
330266 all_labels = all_labels .union (labels )
@@ -352,7 +288,7 @@ def get_node_labels(self, key, inverse_bool=False):
352288 if inverse_bool :
353289 output = set ()
354290 for l in self .key_to_label_to_nodes [key ].keys ():
355- output = output .union (self .inverse_label_map [l ])
291+ output = output .union (self .label_handler . inverse_mapping [l ])
356292 return output
357293 else :
358294 return set (self .key_to_label_to_nodes [key ].keys ())
@@ -404,7 +340,7 @@ def run(self):
404340 self .detect_splits ()
405341 self .quantify_splits ()
406342
407- # Check whether to delete prexisting merges
343+ # Check for prexisting merges
408344 if self .preexisting_merges :
409345 for key in self .graphs :
410346 self .adjust_metrics (key )
@@ -467,7 +403,7 @@ def detect_splits(self):
467403
468404 """
469405 t0 = time ()
470- pbar = tqdm (total = len (self .graphs ), desc = "Split Detection: " )
406+ pbar = tqdm (total = len (self .graphs ), desc = "Split Detection" )
471407 with ProcessPoolExecutor () as executor :
472408 # Assign processes
473409 processes = list ()
@@ -543,7 +479,7 @@ def detect_merges(self):
543479
544480 # Count total merges
545481 if self .fragment_graphs :
546- pbar = tqdm (total = len (self .graphs ), desc = "Count Merges: " )
482+ pbar = tqdm (total = len (self .graphs ), desc = "Merge Detection " )
547483 for key , graph in self .graphs .items ():
548484 if graph .number_of_nodes () > 0 :
549485 kdtree = KDTree (graph .graph ["voxel" ])
@@ -582,14 +518,7 @@ def count_merges(self, key, kdtree):
582518 """
583519 for label in self .get_node_labels (key ):
584520 if len (self .key_to_label_to_nodes [key ][label ]) > MIN_CNT :
585- # Check whether to compute label inverse
586- if self .inverse_label_map :
587- labels = deepcopy (self .inverse_label_map [label ])
588- else :
589- labels = [label ]
590-
591- # Check if fragment is a merge mistake
592- for label in labels :
521+ for label in self .label_handler .get_class (label ):
593522 if label in self .fragment_graphs :
594523 self .is_fragment_merge (key , label , kdtree )
595524
@@ -616,16 +545,13 @@ def is_fragment_merge(self, key, label, kdtree):
616545 """
617546 for voxel in self .fragment_graphs [label ].graph ["voxel" ]:
618547 if kdtree .query (voxel )[0 ] > MERGE_DIST_THRESHOLD :
619- # Check whether to get inverse of label
620- if self .inverse_label_map :
621- equivalent_label = self .label_map [label ]
622- else :
623- equivalent_label = label
624-
625- # Record merge mistake
548+ # Log merge mistake
549+ equiv_label = self .label_handler .get (label )
626550 xyz = img_util .to_physical (voxel , self .anisotropy )
627551 self .merge_cnt [key ] += 1
628- self .merged_labels .add ((key , equivalent_label , tuple (xyz )))
552+ self .merged_labels .add ((key , equiv_label , tuple (xyz )))
553+
554+ # Save merged fragment (if applicable)
629555 if self .save_projections and label in self .fragment_graphs :
630556 swc_util .to_zipped_swc (
631557 self .zip_writer [key ], self .fragment_graphs [label ]
@@ -729,7 +655,7 @@ def save_merged_labels(self):
729655 with open (os .path .join (self .output_dir , filename ), "w" ) as f :
730656 f .write (f" Label - xyz\n " )
731657 for _ , label , xyz in self .merged_labels :
732- if self .connections_path :
658+ if self .label_handler . use_mapping () :
733659 label = self .get_merged_label (label )
734660 f .write (f" { label } - { xyz } \n " )
735661
@@ -749,11 +675,11 @@ def get_merged_label(self, label):
749675 -------
750676 str or list
751677 The first matching label found in "self.fragment_graphs.keys()" or
752- the original associated labels from "inverse_label_map" if no
678+ the original associated labels from "inverse_label_map" if no
753679 matches are found.
754680
755681 """
756- for l in self .inverse_label_map [ label ] :
682+ for l in self .label_handler . get_class ( label ) :
757683 if l in self .fragment_graphs .keys ():
758684 return l
759685 return self .inverse_label_map [label ]
@@ -852,8 +778,8 @@ def generate_avg_results(self):
852778
853779 def avg_result (self , stats ):
854780 """
855- Averages the values computed across "self.graphs" for
856- a given metric stored in "stats".
781+ Averages the values computed across "self.graphs" for a given metric
782+ stored in "stats".
857783
858784 Parameters
859785 ----------
0 commit comments