88"""
99
1010
11- from concurrent .futures import ThreadPoolExecutor , as_completed
11+ from concurrent .futures import (
12+ as_completed ,
13+ ProcessPoolExecutor ,
14+ ThreadPoolExecutor ,
15+ )
1216from copy import deepcopy
13- from scipy .spatial import KDTree
17+ from scipy .spatial import distance , KDTree
1418from time import time
1519from tqdm import tqdm
1620from zipfile import ZipFile
1721
1822import networkx as nx
1923import numpy as np
2024import os
21- import tensorstore as ts
2225
23- from segmentation_skeleton_metrics import split_detection
26+ from segmentation_skeleton_metrics import graph_segmentation_alignment as gsa
2427from segmentation_skeleton_metrics .utils import (
2528 graph_util as gutil ,
29+ img_util ,
2630 swc_util ,
2731 util
2832)
@@ -108,8 +112,8 @@ def __init__(
108112 self .output_dir = output_dir
109113 self .preexisting_merges = preexisting_merges
110114
111- # Load Labels, Graphs, Fragments
112- print ("\n (1) Initializations " )
115+ # Load Data
116+ print ("\n (1) Load Data " )
113117 assert type (valid_labels ) is set if valid_labels else True
114118 self .label_mask = pred_labels
115119 self .valid_labels = valid_labels
@@ -177,7 +181,7 @@ def init_graphs(self, paths):
177181 self .graphs [key ]
178182 )
179183
180- def set_node_labels (self , key ):
184+ def set_node_labels (self , key , batch_size = 128 ):
181185 """
182186 Iterates over nodes in "graph" and stores the corresponding label from
183187 predicted segmentation mask (i.e. "self.label_mask") as a node-level
@@ -195,17 +199,62 @@ def set_node_labels(self, key):
195199 """
196200 with ThreadPoolExecutor () as executor :
197201 # Assign threads
198- threads = []
199- for i in self .graphs [key ].nodes :
200- voxel = tuple (self .graphs [key ].nodes [i ]["voxel" ])
201- threads .append (executor .submit (self .get_label , i , voxel ))
202+ batch = set ()
203+ threads = list ()
204+ visited = set ()
205+ for i , j in nx .dfs_edges (self .graphs [key ]):
206+ # Check for new batch
207+ if len (batch ) == 0 :
208+ root = i
209+ batch .add (i )
210+ visited .add (i )
211+
212+ # Check whether to submit batch
213+ is_node_far = self .dist (key , root , j ) > 128
214+ is_batch_full = len (batch ) >= batch_size
215+ if is_node_far or is_batch_full :
216+ threads .append (
217+ executor .submit (self .get_patch_labels , key , batch )
218+ )
219+ batch = set ()
202220
203- # Store label
204- for thread in as_completed (threads ):
205- i , label = thread .result ()
206- self .graphs [key ].nodes [i ].update ({"label" : label })
221+ # Visit j
222+ if j not in visited :
223+ batch .add (j )
224+ visited .add (j )
225+ if len (batch ) == 1 :
226+ root = j
227+
228+ # Submit last thread
229+ threads .append (executor .submit (self .get_patch_labels , key , batch ))
207230
208- def get_label (self , i , voxel ):
231+ # Process results
232+ for thread in as_completed (threads ):
233+ node_to_label = thread .result ()
234+ for i , label in node_to_label .items ():
235+ self .graphs [key ].nodes [i ].update ({"label" : label })
236+
237+ def get_patch_labels (self , key , nodes ):
238+ # Get bounding box
239+ bbox = {"min" : [np .inf , np .inf , np .inf ], "max" : [0 , 0 , 0 ]}
240+ for i in nodes :
241+ voxel = deepcopy (self .graphs [key ].nodes [i ]["voxel" ])
242+ for idx in range (3 ):
243+ if voxel [idx ] < bbox ["min" ][idx ]:
244+ bbox ["min" ][idx ] = voxel [idx ]
245+ if voxel [idx ] >= bbox ["max" ][idx ]:
246+ bbox ["max" ][idx ] = voxel [idx ] + 1
247+
248+ # Read labels
249+ label_patch = self .label_mask .read_with_bbox (bbox )
250+ node_to_label = dict ()
251+ for i in nodes :
252+ voxel = self .to_local_voxels (key , i , bbox ["min" ])
253+ label = self .adjust_label (label_patch [voxel ])
254+ node_to_label [i ] = label
255+ return node_to_label
256+
257+ def adjust_label (self , label ):
209258 """
210259 Gets label of voxel in "self.label_mask".
211260
@@ -222,18 +271,11 @@ def get_label(self, i, voxel):
222271 Label of voxel.
223272
224273 """
225- # Read label
226- if isinstance (self .label_mask , ts .TensorStore ):
227- label = int (self .label_mask [voxel ].read ().result ())
228- else :
229- label = self .label_mask [voxel ]
230-
231- # Check whether to update label
232274 if self .label_map :
233275 label = self .get_equivalent_label (label )
234276 elif self .valid_labels :
235277 label = 0 if label not in self .valid_labels else label
236- return i , label
278+ return label
237279
238280 def get_equivalent_label (self , label ):
239281 """
@@ -443,15 +485,29 @@ def detect_splits(self):
443485
444486 """
445487 t0 = time ()
446- for key , graph in tqdm (self .graphs .items (), desc = "Split Detection:" ):
447- # Detection
448- graph = split_detection .run (graph , self .graphs [key ])
488+ pbar = tqdm (total = len (self .graphs ), desc = "Split Detection:" )
489+ with ProcessPoolExecutor () as executor :
490+ # Assign processes
491+ processes = list ()
492+ for key , graph in self .graphs .items ():
493+ processes .append (
494+ executor .submit (
495+ gsa .correct_graph_misalignments ,
496+ key ,
497+ graph ,
498+ )
499+ )
449500
450- # Update graph by removing omits (i.e. nodes labeled 0)
451- self .graphs [key ] = gutil .delete_nodes (graph , 0 )
452- self .key_to_label_to_nodes [key ] = gutil .init_label_to_nodes (
453- self .graphs [key ]
454- )
501+ # Store results
502+ self .split_percent = dict ()
503+ for process in as_completed (processes ):
504+ key , graph , split_percent = process .result ()
505+ self .graphs [key ] = gutil .delete_nodes (graph , 0 )
506+ self .key_to_label_to_nodes [key ] = gutil .init_label_to_nodes (
507+ self .graphs [key ]
508+ )
509+ self .split_percent [key ] = split_percent
510+ pbar .update (1 )
455511
456512 # Report runtime
457513 t , unit = util .time_writer (time () - t0 )
@@ -505,15 +561,19 @@ def detect_merges(self):
505561
506562 # Count total merges
507563 if self .fragment_graphs :
564+ pbar = tqdm (total = len (self .graphs ), desc = "Count Merges:" )
508565 for key , graph in self .graphs .items ():
509566 if graph .number_of_nodes () > 0 :
510567 kdtree = KDTree (gutil .to_array (graph ))
511568 self .count_merges (key , kdtree )
569+ pbar .update (1 )
512570
513571 # Process merges
572+ pbar = tqdm (total = len (self .graphs ), desc = "Compute Percent Merged:" )
514573 for (key_1 , key_2 ), label in self .find_label_intersections ():
515574 self .process_merge (key_1 , label , - 1 )
516575 self .process_merge (key_2 , label , - 1 )
576+ pbar .update (1 )
517577
518578 for key , label , xyz in self .merged_labels :
519579 self .process_merge (key , label , xyz , update_merged_labels = False )
@@ -583,7 +643,7 @@ def is_fragment_merge(self, key, label, kdtree):
583643 equivalent_label = label
584644
585645 # Record merge mistake
586- xyz = util .to_physical (voxel )
646+ xyz = img_util .to_physical (voxel )
587647 self .merge_cnt [key ] += 1
588648 self .merged_labels .add ((key , equivalent_label , tuple (xyz )))
589649 if self .save_projections :
@@ -776,6 +836,7 @@ def generate_full_results(self):
776836 "# splits" : generate_result (keys , self .split_cnt ),
777837 "# merges" : generate_result (keys , self .merge_cnt ),
778838 "% omit" : generate_result (keys , self .omit_percent ),
839+ "% split" : generate_result (keys , self .split_percent ),
779840 "% merged" : generate_result (keys , self .merged_percent ),
780841 "edge accuracy" : generate_result (keys , self .edge_accuracy ),
781842 "erl" : generate_result (keys , self .erl ),
@@ -801,6 +862,7 @@ def generate_avg_results(self):
801862 "# splits" : self .avg_result (self .split_cnt ),
802863 "# merges" : self .avg_result (self .merge_cnt ),
803864 "% omit" : self .avg_result (self .omit_percent ),
865+ "% split" : self .avg_result (self .split_percent ),
804866 "% merged" : self .avg_result (self .merged_percent ),
805867 "edge accuracy" : self .avg_result (self .edge_accuracy ),
806868 "erl" : self .avg_result (self .erl ),
@@ -925,6 +987,11 @@ def list_metrics(self):
925987 return metrics
926988
927989 # -- util --
990+ def dist (self , key , i , j ):
991+ xyz_i = self .graphs [key ].nodes [i ]["voxel" ]
992+ xyz_j = self .graphs [key ].nodes [j ]["voxel" ]
993+ return distance .euclidean (xyz_i , xyz_j )
994+
928995 def init_counter (self ):
929996 """
930997 Initializes a dictionary that is used to count some type of mistake
@@ -942,6 +1009,11 @@ def init_counter(self):
9421009 """
9431010 return {key : 0 for key in self .graphs }
9441011
1012+ def to_local_voxels (self , key , i , offset ):
1013+ voxel = np .array (self .graphs [key ].nodes [i ]["voxel" ])
1014+ offset = np .array (offset )
1015+ return tuple (voxel - offset )
1016+
9451017
9461018# -- util --
9471019def find_sites (graphs , get_labels ):
0 commit comments