1313from concurrent .futures import (
1414 as_completed ,
1515 ProcessPoolExecutor ,
16- ThreadPoolExecutor ,
1716)
1817from copy import deepcopy
1918from scipy .spatial import distance , KDTree
@@ -119,8 +118,7 @@ def __init__(
119118 )
120119
121120 # Load data
122- self .label_mask = label_mask
123- self .load_groundtruth (gt_pointer )
121+ self .load_groundtruth (gt_pointer , label_mask )
124122 self .load_fragments (fragments_pointer )
125123
126124 # Initialize metrics
@@ -144,14 +142,16 @@ def __init__(
144142 self .metrics = pd .DataFrame (index = row_names , columns = col_names )
145143
146144 # --- Load Data ---
147- def load_groundtruth (self , swc_pointer ):
145+ def load_groundtruth (self , swc_pointer , label_mask ):
148146 """
149147 Loads ground truth graphs and initializes the "graphs" attribute.
150148
151149 Parameters
152150 ----------
153151 swc_pointer : Any
154152 Pointer to ground truth SWC files.
153+ label_mask : ImageReader
154+ Predicted segmentation mask.
155155
156156 Returns
157157 -------
@@ -160,19 +160,16 @@ def load_groundtruth(self, swc_pointer):
160160 """
161161 # Build graphs
162162 print ("\n (1) Load Ground Truth" )
163- graph_builder = gutil .GraphBuilder (
163+ graph_loader = gutil .GraphLoader (
164164 anisotropy = self .anisotropy ,
165165 is_groundtruth = True ,
166- label_mask = self .label_mask ,
166+ label_handler = self .label_handler ,
167+ label_mask = label_mask ,
167168 use_anisotropy = False ,
168169 )
169- self .graphs = graph_builder .run (swc_pointer )
170+ self .graphs = graph_loader .run (swc_pointer )
170171 self .gt_graphs = deepcopy (self .graphs )
171172
172- # Label nodes
173- for key in tqdm (self .graphs , desc = "Labeling Graphs" ):
174- self .label_graphs (key )
175-
176173 def load_fragments (self , swc_pointer ):
177174 """
178175 Loads fragments generated from the segmentation and initializes the
@@ -190,13 +187,13 @@ def load_fragments(self, swc_pointer):
190187 """
191188 print ("\n (2) Load Fragments" )
192189 if swc_pointer :
193- graph_builder = gutil .GraphBuilder (
190+ graph_loader = gutil .GraphLoader (
194191 anisotropy = self .anisotropy ,
195192 is_groundtruth = False ,
196193 selected_ids = self .get_all_node_labels (),
197194 use_anisotropy = self .use_anisotropy ,
198195 )
199- self .fragment_graphs = graph_builder .run (swc_pointer )
196+ self .fragment_graphs = graph_loader .run (swc_pointer )
200197 self .set_fragment_ids ()
201198 else :
202199 self .fragment_graphs = None
@@ -219,86 +216,6 @@ def set_fragment_ids(self):
219216 for key in self .fragment_graphs :
220217 self .fragment_ids .add (util .get_segment_id (key ))
221218
222- def label_graphs (self , key ):
223- """
224- Iterates over nodes in "graph" and stores the corresponding label from
225- "self.label_mask") as a node-level attribute called "labels".
226-
227- Parameters
228- ----------
229- key : str
230- Unique identifier of graph to be labeled.
231-
232- Returns
233- -------
234- None
235-
236- """
237- with ThreadPoolExecutor () as executor :
238- # Assign threads
239- batch = set ()
240- threads = list ()
241- visited = set ()
242- for i , j in nx .dfs_edges (self .graphs [key ]):
243- # Check if starting new batch
244- if len (batch ) == 0 :
245- root = i
246- batch .add (i )
247- visited .add (i )
248-
249- # Check whether to submit batch
250- is_node_far = self .graphs [key ].dist (root , j ) > 128
251- is_batch_full = len (batch ) >= 128
252- if is_node_far or is_batch_full :
253- threads .append (
254- executor .submit (self .get_patch_labels , key , batch )
255- )
256- batch = set ()
257-
258- # Visit j
259- if j not in visited :
260- batch .add (j )
261- visited .add (j )
262- if len (batch ) == 1 :
263- root = j
264-
265- # Submit last batch
266- threads .append (executor .submit (self .get_patch_labels , key , batch ))
267-
268- # Store results
269- self .graphs [key ].init_labels ()
270- for thread in as_completed (threads ):
271- node_to_label = thread .result ()
272- for i , label in node_to_label .items ():
273- self .graphs [key ].labels [i ] = label
274-
275- def get_patch_labels (self , key , nodes ):
276- """
277- Gets the segment labels for a given set of nodes within a specified
278- patch of the label mask.
279-
280- Parameters
281- ----------
282- key : str
283- Unique identifier of graph to be labeled.
284- nodes : List[int]
285- Node IDs for which the labels are to be retrieved.
286-
287- Returns
288- -------
289- dict
290- A dictionary that maps node IDs to their respective labels.
291-
292- """
293- bbox = self .graphs [key ].get_bbox (nodes )
294- label_patch = self .label_mask .read_with_bbox (bbox )
295- node_to_label = dict ()
296- for i in nodes :
297- voxel = self .to_local_voxels (key , i , bbox ["min" ])
298- label = self .label_handler .get (label_patch [voxel ])
299- node_to_label [i ] = label
300- return node_to_label
301-
302219 def get_all_node_labels (self ):
303220 """
304221 Gets the set of unique node labels from all graphs in "self.graphs".
@@ -407,7 +324,8 @@ def run(self):
407324 # Save results
408325 prefix = "corrected-" if self .connections_path else ""
409326 path = f"{ self .output_dir } /{ prefix } results.csv"
410- self .metrics .fillna (0 )
327+ if self .fragment_graphs is None :
328+ self .metrics = self .metrics .drop ("# Merges" , axis = 1 )
411329 self .metrics .to_csv (path , index = True )
412330
413331 # Report results
@@ -419,10 +337,12 @@ def run(self):
419337 util .update_txt (path , f" { column_name } : { avg :.4f} " )
420338
421339 n_splits = self .metrics ["# Splits" ].sum ()
422- n_merges = self .metrics ["# Merges" ].sum ()
423340 util .update_txt (path , "\n Total Results..." )
424341 util .update_txt (path , " # Splits: " + str (n_splits ))
425- util .update_txt (path , " # Merges: " + str (n_merges ))
342+
343+ if self .fragment_graphs is not None :
344+ n_merges = self .metrics ["# Merges" ].sum ()
345+ util .update_txt (path , " # Merges: " + str (n_merges ))
426346
427347 # -- Split Detection --
428348 def detect_splits (self ):
0 commit comments