1212 ProcessPoolExecutor ,
1313 ThreadPoolExecutor ,
1414)
15- from copy import deepcopy
1615from scipy .spatial import distance , KDTree
17- from time import time
1816from tqdm import tqdm
1917from zipfile import ZipFile
2018
@@ -118,7 +116,7 @@ def __init__(
118116
119117 # Load data
120118 self .label_mask = pred_labels
121- self .load_groundtruth (gt_pointer )
119+ self .load_groundtruth (gt_pointer , valid_labels )
122120 self .load_fragments (fragments_pointer )
123121
124122 # Initialize writer
@@ -127,7 +125,7 @@ def __init__(
127125 self .init_zip_writer ()
128126
129127 # --- Load Data ---
130- def load_groundtruth (self , swc_pointer ):
128+ def load_groundtruth (self , swc_pointer , valid_labels ):
131129 """
132130 Initializes "self.graphs" by iterating over "paths" which corresponds
133131 to neurons in the ground truth.
@@ -149,16 +147,13 @@ def load_groundtruth(self, swc_pointer):
149147 anisotropy = self .anisotropy ,
150148 label_mask = self .label_mask ,
151149 use_anisotropy = False ,
150+ valid_labels = valid_labels ,
152151 )
153152 self .graphs = graph_builder .run (swc_pointer )
154153
155154 # Label nodes
156- self .key_to_label_to_nodes = dict () # {id: {label: nodes}}
157155 for key in tqdm (self .graphs , desc = "Labeling Graphs" ):
158156 self .label_graphs (key )
159- self .key_to_label_to_nodes [key ] = gutil .init_label_to_nodes (
160- self .graphs [key ]
161- )
162157
163158 def load_fragments (self , swc_pointer ):
164159 print ("\n (2) Load Fragments" )
@@ -220,10 +215,12 @@ def label_graphs(self, key, batch_size=128):
220215 threads .append (executor .submit (self .get_patch_labels , key , batch ))
221216
222217 # Process results
218+ n_nodes = self .graphs [key ].number_of_nodes ()
219+ self .graphs [key ].graph ["label" ] = np .zeros ((n_nodes ), dtype = int )
223220 for thread in as_completed (threads ):
224221 node_to_label = thread .result ()
225222 for i , label in node_to_label .items ():
226- self .graphs [key ].nodes [ i ]. update ({ "label" : label })
223+ self .graphs [key ].graph [ "label" ][ i ] = label
227224
228225 def get_patch_labels (self , key , nodes ):
229226 # Get bounding box
@@ -287,11 +284,11 @@ def get_node_labels(self, key, inverse_bool=False):
287284 """
288285 if inverse_bool :
289286 output = set ()
290- for l in self .key_to_label_to_nodes [key ].keys ():
287+ for l in self .graphs [key ].get_labels ():
291288 output = output .union (self .label_handler .inverse_mapping [l ])
292289 return output
293290 else :
294- return set ( self .key_to_label_to_nodes [key ].keys () )
291+ return self .graphs [key ].get_labels ( )
295292
296293 def init_zip_writer (self ):
297294 """
@@ -372,9 +369,9 @@ def adjust_metrics(self, key):
372369 """
373370 for label in self .preexisting_merges :
374371 label = self .label_map [label ] if self .label_map else label
375- if label in self .key_to_label_to_nodes [key ].keys ():
372+ if label in self .graphs [key ].get_labels ():
376373 # Extract subgraph
377- nodes = deepcopy ( self .key_to_label_to_nodes [key ][ label ] )
374+ nodes = self .graphs [key ]. nodes_with_label ( label )
378375 subgraph = self .graphs [key ].subgraph (nodes )
379376
380377 # Adjust metrics
@@ -385,7 +382,6 @@ def adjust_metrics(self, key):
385382
386383 # Update graph
387384 self .graphs [key ].remove_nodes_from (nodes )
388- del self .key_to_label_to_nodes [key ][label ]
389385
390386 # -- Split Detection --
391387 def detect_splits (self ):
@@ -402,7 +398,6 @@ def detect_splits(self):
402398 None
403399
404400 """
405- t0 = time ()
406401 pbar = tqdm (total = len (self .graphs ), desc = "Split Detection" )
407402 with ProcessPoolExecutor () as executor :
408403 # Assign processes
@@ -420,17 +415,10 @@ def detect_splits(self):
420415 self .split_percent = dict ()
421416 for process in as_completed (processes ):
422417 key , graph , split_percent = process .result ()
423- self .graphs [key ] = gutil .delete_nodes (graph , 0 )
424- self .key_to_label_to_nodes [key ] = gutil .init_label_to_nodes (
425- self .graphs [key ]
426- )
418+ self .graphs [key ] = gutil .remove_nodes (graph , 0 )
427419 self .split_percent [key ] = split_percent
428420 pbar .update (1 )
429421
430- # Report runtime
431- t , unit = util .time_writer (time () - t0 )
432- print (f"Runtime: { round (t , 2 )} { unit } \n " )
433-
434422 def quantify_splits (self ):
435423 """
436424 Counts the number of splits, number of omit edges, and percent of omit
@@ -449,9 +437,11 @@ def quantify_splits(self):
449437 self .omit_cnts = dict ()
450438 self .omit_percent = dict ()
451439 for key in self .graphs :
440+ # Get counts
452441 n_pred_edges = self .graphs [key ].number_of_edges ()
453442 n_target_edges = self .graphs [key ].graph ["n_edges" ]
454443
444+ # Compute stats
455445 self .split_cnt [key ] = gutil .count_splits (self .graphs [key ])
456446 self .omit_cnts [key ] = n_target_edges - n_pred_edges
457447 self .omit_percent [key ] = 1 - n_pred_edges / n_target_edges
@@ -517,7 +507,8 @@ def count_merges(self, key, kdtree):
517507
518508 """
519509 for label in self .get_node_labels (key ):
520- if len (self .key_to_label_to_nodes [key ][label ]) > MIN_CNT :
510+ nodes = self .graphs [key ].nodes_with_label (label )
511+ if len (nodes ) > MIN_CNT :
521512 for label in self .label_handler .get_class (label ):
522513 if label in self .fragment_graphs :
523514 self .is_fragment_merge (key , label , kdtree )
@@ -581,8 +572,8 @@ def find_label_intersections(self):
581572 keys = frozenset ((key_1 , key_2 ))
582573 if key_1 != key_2 and keys not in visited :
583574 visited .add (keys )
584- labels_1 = self .get_node_labels ( key_1 )
585- labels_2 = self .get_node_labels ( key_2 )
575+ labels_1 = set ( self .graphs [ key_1 ]. get_labels () )
576+ labels_2 = set ( self .graphs [ key_2 ]. get_labels () )
586577 for label in labels_1 .intersection (labels_2 ):
587578 label_intersections .add ((keys , label ))
588579 return label_intersections
@@ -605,15 +596,14 @@ def process_merge(self, key, label, xyz, update_merged_labels=True):
605596 None
606597
607598 """
608- if label in self .key_to_label_to_nodes [key ]:
599+ if label in self .graphs [key ]. get_labels () :
609600 # Compute metrics
610- nodes = list ( self .key_to_label_to_nodes [key ][ label ] )
601+ nodes = self .graphs [key ]. nodes_with_label ( label )
611602 subgraph = self .graphs [key ].subgraph (nodes )
612603 self .merged_edges_cnt [key ] += subgraph .number_of_edges ()
613604
614605 # Update self
615606 self .graphs [key ].remove_nodes_from (nodes )
616- del self .key_to_label_to_nodes [key ][label ]
617607 if update_merged_labels :
618608 self .merged_labels .add ((key , label , - 1 ))
619609
0 commit comments