77
88"""
99
10-
10+ from collections import deque
1111from concurrent .futures import (
1212 as_completed ,
1313 ProcessPoolExecutor ,
3131 util
3232)
3333
34- MERGE_DIST_THRESHOLD = 100
34+ MERGE_DIST_THRESHOLD = 200
3535MIN_CNT = 40
3636
3737
@@ -112,15 +112,19 @@ def __init__(
112112 self .output_dir = output_dir
113113 self .preexisting_merges = preexisting_merges
114114
115- # Load Data
116- print ("\n (1) Load Data " )
115+ # Load ground truth
116+ print ("\n (1) Load Ground Truth " )
117117 assert type (valid_labels ) is set if valid_labels else True
118- self .label_mask = pred_labels
119118 self .valid_labels = valid_labels
120119 self .init_label_map (connections_path )
121120 self .init_graphs (gt_pointer )
121+
122+ print ("\n (2) Load Prediction" )
123+ self .label_mask = pred_labels
122124 if fragments_pointer :
123125 self .load_fragments (fragments_pointer )
126+ else :
127+ self .fragment_graphs = None
124128
125129 # Initialize writer
126130 self .save_projections = save_projections
@@ -160,7 +164,7 @@ def init_graphs(self, paths):
160164
161165 Parameters
162166 ----------
163- paths : list [str]
167+ paths : List [str]
164168 List of paths to swc files which correspond to neurons in the
165169 ground truth.
166170
@@ -170,18 +174,35 @@ def init_graphs(self, paths):
170174
171175 """
172176 # Build graphs
173- self . graphs = swc_util .Reader ().load (paths )
174- self .fragment_graphs = None
177+ swc_dicts = swc_util .Reader ().load (paths )
178+ self .graphs = self . build_graphs ( swc_dicts )
175179
176180 # Label nodes
177181 self .key_to_label_to_nodes = dict () # {id: {label: nodes}}
178182 for key in tqdm (self .graphs , desc = "Labeling Graphs" ):
179- self .set_node_labels (key )
183+ self .label_graphs (key )
180184 self .key_to_label_to_nodes [key ] = gutil .init_label_to_nodes (
181185 self .graphs [key ]
182186 )
183187
184- def set_node_labels (self , key , batch_size = 128 ):
188+ def build_graphs (self , swc_dicts ):
189+ graphs = dict ()
190+ with ProcessPoolExecutor () as executor :
191+ # Assign processes
192+ processes = list ()
193+ for swc_dict in swc_dicts :
194+ processes .append (
195+ executor .submit (gutil .to_graph , swc_dict )
196+ )
197+
198+ # Store results
199+ pbar = tqdm (total = len (processes ), desc = "Build Graphs" )
200+ for process in as_completed (processes ):
201+ graphs .update (process .result ())
202+ pbar .update (1 )
203+ return graphs
204+
205+ def label_graphs (self , key , batch_size = 128 ):
185206 """
186207 Iterates over nodes in "graph" and stores the corresponding label from
187208 predicted segmentation mask (i.e. "self.label_mask") as a node-level
@@ -238,7 +259,7 @@ def get_patch_labels(self, key, nodes):
238259 # Get bounding box
239260 bbox = {"min" : [np .inf , np .inf , np .inf ], "max" : [0 , 0 , 0 ]}
240261 for i in nodes :
241- voxel = deepcopy (self .graphs [key ].nodes [ i ][ "voxel" ])
262+ voxel = deepcopy (self .graphs [key ].graph [ "voxel" ][ i ])
242263 for idx in range (3 ):
243264 if voxel [idx ] < bbox ["min" ][idx ]:
244265 bbox ["min" ][idx ] = voxel [idx ]
@@ -359,20 +380,20 @@ def load_fragments(self, fragments_pointer):
359380 Dictionary that maps an swc id to the fragment graph.
360381
361382 """
362- # Read fragments
363- reader = swc_util .Reader (anisotropy = self .anisotropy , min_size = 40 )
364- fragment_graphs = reader .load (fragments_pointer )
365- self . fragment_ids = set ( fragment_graphs . keys ())
366-
367- # Filter fragments
368- self . fragment_graphs = dict ()
369- for label in self . get_all_node_labels () :
370- if label in fragment_graphs :
371- self . fragment_graphs [ label ] = fragment_graphs [ label ]
372- else :
373- self . fragment_graphs [ label ] = nx . Graph (
374- filename = f" { label } .swc" , run_length = 0 , n_edges = 1
375- )
383+ # Read SWC files
384+ reader = swc_util .Reader (anisotropy = self .anisotropy )
385+ swc_dicts = deque ( reader .load (fragments_pointer ) )
386+
387+ # Filter SWC files
388+ filtered_swc_dicts = list ()
389+ labels = self . get_all_node_labels ()
390+ while len ( swc_dicts ) > 0 :
391+ swc_dict = swc_dicts . popleft ()
392+ swc_id = int ( swc_dict [ "swc_id" ])
393+ if swc_id in labels :
394+ swc_dict [ "swc_id" ] = swc_id
395+ filtered_swc_dicts . append ( swc_dict )
396+ self . fragment_graphs = self . build_graphs ( filtered_swc_dicts )
376397 print ("# Fragments:" , len (self .fragment_graphs ))
377398
378399 def init_zip_writer (self ):
@@ -416,7 +437,7 @@ def run(self):
416437 ...
417438
418439 """
419- print ("\n (2 ) Evaluation" )
440+ print ("\n (3 ) Evaluation" )
420441
421442 # Split evaluation
422443 self .detect_splits ()
@@ -564,16 +585,14 @@ def detect_merges(self):
564585 pbar = tqdm (total = len (self .graphs ), desc = "Count Merges:" )
565586 for key , graph in self .graphs .items ():
566587 if graph .number_of_nodes () > 0 :
567- kdtree = KDTree (gutil . to_array ( graph ) )
588+ kdtree = KDTree (graph . graph [ "voxel" ] )
568589 self .count_merges (key , kdtree )
569590 pbar .update (1 )
570591
571592 # Process merges
572- pbar = tqdm (total = len (self .graphs ), desc = "Compute Percent Merged:" )
573593 for (key_1 , key_2 ), label in self .find_label_intersections ():
574594 self .process_merge (key_1 , label , - 1 )
575595 self .process_merge (key_2 , label , - 1 )
576- pbar .update (1 )
577596
578597 for key , label , xyz in self .merged_labels :
579598 self .process_merge (key , label , xyz , update_merged_labels = False )
@@ -610,8 +629,8 @@ def count_merges(self, key, kdtree):
610629
611630 # Check if fragment is a merge mistake
612631 for label in labels :
613- rl = self .fragment_graphs [ label ]. graph [ "run_length" ]
614- self .is_fragment_merge (key , label , kdtree )
632+ if label in self .fragment_graphs :
633+ self .is_fragment_merge (key , label , kdtree )
615634
616635 def is_fragment_merge (self , key , label , kdtree ):
617636 """
@@ -634,7 +653,7 @@ def is_fragment_merge(self, key, label, kdtree):
634653 None
635654
636655 """
637- for voxel in gutil . to_array ( self .fragment_graphs [label ])[:: 2 ]:
656+ for voxel in self .fragment_graphs [label ]. graph [ "voxel" ]:
638657 if kdtree .query (voxel )[0 ] > MERGE_DIST_THRESHOLD :
639658 # Check whether to get inverse of label
640659 if self .inverse_label_map :
@@ -643,10 +662,10 @@ def is_fragment_merge(self, key, label, kdtree):
643662 equivalent_label = label
644663
645664 # Record merge mistake
646- xyz = img_util .to_physical (voxel )
665+ xyz = img_util .to_physical (voxel , self . anisotropy )
647666 self .merge_cnt [key ] += 1
648667 self .merged_labels .add ((key , equivalent_label , tuple (xyz )))
649- if self .save_projections :
668+ if self .save_projections and label in self . fragment_graphs :
650669 swc_util .to_zipped_swc (
651670 self .zip_writer [key ], self .fragment_graphs [label ]
652671 )
@@ -768,13 +787,13 @@ def get_merged_label(self, label):
768787 Returns:
769788 -------
770789 str or list
771- The first matching label found in "self.fragment_ids " or the
772- original associated labels from "inverse_label_map" if no matches
773- are found.
790+ The first matching label found in "self.fragment_graphs.keys() " or
791+ the original associated labels from "inverse_label_map" if no
792+ matches are found.
774793
775794 """
776795 for l in self .inverse_label_map [label ]:
777- if l in self .fragment_ids :
796+ if l in self .fragment_graphs . keys () :
778797 return l
779798 return self .inverse_label_map [label ]
780799
@@ -988,8 +1007,8 @@ def list_metrics(self):
9881007
9891008 # -- util --
9901009 def dist (self , key , i , j ):
991- xyz_i = self .graphs [key ].nodes [ i ][ "voxel" ]
992- xyz_j = self .graphs [key ].nodes [ j ][ "voxel" ]
1010+ xyz_i = self .graphs [key ].graph [ "voxel" ][ i ]
1011+ xyz_j = self .graphs [key ].graph [ "voxel" ][ j ]
9931012 return distance .euclidean (xyz_i , xyz_j )
9941013
9951014 def init_counter (self ):
@@ -1010,7 +1029,7 @@ def init_counter(self):
10101029 return {key : 0 for key in self .graphs }
10111030
10121031 def to_local_voxels (self , key , i , offset ):
1013- voxel = np .array (self .graphs [key ].nodes [ i ][ "voxel" ])
1032+ voxel = np .array (self .graphs [key ].graph [ "voxel" ][ i ])
10141033 offset = np .array (offset )
10151034 return tuple (voxel - offset )
10161035
0 commit comments