2323from segmentation_skeleton_metrics .utils import (
2424 graph_util as gutil ,
2525 img_util ,
26- swc_util ,
2726 util
2827)
2928
30- MIN_CNT = 40
31-
3229
3330class SkeletonMetric :
3431 """
@@ -155,10 +152,8 @@ def load_groundtruth(self, swc_pointer):
155152 def load_fragments (self , swc_pointer ):
156153 print ("\n (2) Load Fragments" )
157154 if swc_pointer :
158- coords_only = False #not self.save_projections
159155 graph_builder = gutil .GraphBuilder (
160156 anisotropy = self .anisotropy ,
161- coords_only = coords_only ,
162157 selected_ids = self .get_all_node_labels (),
163158 use_anisotropy = True ,
164159 )
@@ -172,7 +167,7 @@ def set_fragment_ids(self):
172167 for key in self .fragment_graphs :
173168 self .fragment_ids .add (util .get_segment_id (key ))
174169
175- def label_graphs (self , key , batch_size = 64 ):
170+ def label_graphs (self , key , batch_size = 128 ):
176171 """
177172 Iterates over nodes in "graph" and stores the corresponding label from
178173 predicted segmentation mask (i.e. "self.label_mask") as a node-level
@@ -201,7 +196,7 @@ def label_graphs(self, key, batch_size=64):
201196 visited .add (i )
202197
203198 # Check whether to submit batch
204- is_node_far = self .graphs [key ].dist (root , j ) > 128
199+ is_node_far = self .graphs [key ].dist (root , j ) > batch_size
205200 is_batch_full = len (batch ) >= batch_size
206201 if is_node_far or is_batch_full :
207202 threads .append (
@@ -306,9 +301,7 @@ def init_zip_writer(self):
306301 self .zip_writer = dict ()
307302 for key in self .graphs .keys ():
308303 self .zip_writer [key ] = ZipFile (f"{ output_dir } /{ key } .zip" , "w" )
309- swc_util .to_zipped_swc (
310- self .zip_writer [key ], self .graphs [key ],
311- )
304+ self .graphs [key ].to_zipped_swc (self .zip_writer [key ])
312305
313306 # -- Main Routine --
314307 def run (self ):
@@ -331,11 +324,6 @@ def run(self):
331324 self .detect_splits ()
332325 self .quantify_splits ()
333326
334- # Check for prexisting merges
335- if self .preexisting_merges :
336- for key in self .graphs :
337- self .adjust_metrics (key )
338-
339327 # Merge evaluation
340328 self .detect_merges ()
341329 self .quantify_merges ()
@@ -344,39 +332,6 @@ def run(self):
344332 full_results , avg_results = self .compile_results ()
345333 return full_results , avg_results
346334
347- def adjust_metrics (self , key ):
348- """
349- Adjusts the metrics of the graph associated with the given key by
350- removing nodes corresponding to known merges and their corresponding
351- subgraphs. Updates the total number of edges and run lengths in the
352- graph.
353-
354- Parameters
355- ----------
356- key : str
357- Identifier for the graph to adjust.
358-
359- Returns
360- -------
361- None
362-
363- """
364- for label in self .preexisting_merges :
365- label = self .label_map [label ] if self .label_map else label
366- if label in self .graphs [key ].get_labels ():
367- # Extract subgraph
368- nodes = self .graphs [key ].nodes_with_label (label )
369- subgraph = self .graphs [key ].subgraph (nodes )
370-
371- # Adjust metrics
372- n_edges = subgraph .number_of_edges ()
373- rls = gutil .compute_run_lengths (subgraph )
374- self .graphs [key ].graph ["run_length" ] -= np .sum (rls )
375- self .graphs [key ].graph ["n_edges" ] -= n_edges
376-
377- # Update graph
378- self .graphs [key ].remove_nodes_from (nodes )
379-
380335 # -- Split Detection --
381336 def detect_splits (self ):
382337 """
@@ -393,7 +348,7 @@ def detect_splits(self):
393348
394349 """
395350 pbar = tqdm (total = len (self .graphs ), desc = "Split Detection" )
396- with ProcessPoolExecutor () as executor :
351+ with ProcessPoolExecutor (max_workers = 8 ) as executor :
397352 # Assign processes
398353 processes = list ()
399354 for key , graph in self .graphs .items ():
@@ -470,7 +425,12 @@ def detect_merges(self):
470425 self .count_merges (key , kdtree )
471426 pbar .update (1 )
472427
473- # Process merges
428+ # Adjust metrics (if applicable)
429+ if self .preexisting_merges :
430+ for key in self .graphs :
431+ self .adjust_metrics (key )
432+
433+ # Find graphs with common node labels
474434 for (key_1 , key_2 ), label in self .find_label_intersections ():
475435 self .process_merge (key_1 , label , - 1 )
476436 self .process_merge (key_2 , label , - 1 )
@@ -502,7 +462,7 @@ def count_merges(self, key, kdtree):
502462 """
503463 for label in self .get_node_labels (key ):
504464 nodes = self .graphs [key ].nodes_with_label (label )
505- if len (nodes ) > MIN_CNT :
465+ if len (nodes ) > 50 :
506466 for label in self .label_handler .get_class (label ):
507467 if label in self .fragment_ids :
508468 self .is_fragment_merge (key , label , kdtree )
@@ -539,16 +499,45 @@ def is_fragment_merge(self, key, label, kdtree):
539499 self .merged_labels .add ((key , equiv_label , tuple (xyz )))
540500
541501 # Save merged fragment (if applicable)
542- if self .save_projections and label in self .fragment_graphs :
543- swc_util .to_zipped_swc (
544- self .zip_writer [key ], self .fragment_graphs [label ]
545- )
502+ if self .save_projections :
503+ fragment_graph .to_zipped_swc (self .zip_writer [key ])
546504 break
547505
548- def find_graph_from_label (self , label ):
549- for key in self .fragment_graphs :
550- if label == util .get_segment_id (key ):
551- return self .fragment_graphs [key ]
506+ def adjust_metrics (self , key ):
507+ """
508+ Adjusts the metrics of the graph associated with the given key by
509+ removing nodes corresponding to known merges and their corresponding
510+ subgraphs. Updates the total number of edges and run lengths in the
511+ graph.
512+
513+ Parameters
514+ ----------
515+ key : str
516+ Identifier for the graph to adjust.
517+
518+ Returns
519+ -------
520+ None
521+
522+ """
523+ visited = set ()
524+ for label in self .preexisting_merges :
525+ label = self .label_handler .mapping [label ]
526+ if label in self .graphs [key ].get_labels ():
527+ if label not in visited and label != 0 :
528+ # Get component with label
529+ nodes = self .graphs [key ].nodes_with_label (label )
530+ root = util .sample_once (list (nodes ))
531+
532+ # Adjust metrics
533+ rl = self .graphs [key ].run_length_from (root )
534+ self .graphs [key ].run_length -= np .sum (rl )
535+ self .graphs [key ].graph ["n_edges" ] -= len (nodes ) - 1
536+
537+ # Update graph
538+ self .graphs [key ].remove_nodes_from (nodes )
539+ visited .add (label )
540+ print ("# nodes deleted:" , len (nodes ))
552541
553542 def find_label_intersections (self ):
554543 """
@@ -673,7 +662,7 @@ def get_merged_label(self, label):
673662 for l in self .label_handler .get_class (label ):
674663 if l in self .fragment_graphs .keys ():
675664 return l
676- return self .inverse_label_map [label ]
665+ return self .label_handler . inverse_mapping [label ]
677666
678667 # -- Compute Metrics --
679668 def compile_results (self ):
@@ -866,7 +855,13 @@ def list_metrics(self):
866855 ]
867856 return metrics
868857
869- # -- util --
858+ # -- Helpers --
859+ def find_graph_from_label (self , label ):
860+ for key in self .fragment_graphs :
861+ if label == util .get_segment_id (key ):
862+ return self .fragment_graphs [key ]
863+ return None
864+
870865 def physical_dist (self , voxel_1 , voxel_2 ):
871866 xyz_1 = img_util .to_physical (voxel_1 , self .anisotropy )
872867 xyz_2 = img_util .to_physical (voxel_2 , self .anisotropy )
@@ -896,40 +891,6 @@ def to_local_voxels(self, key, i, offset):
896891
897892
898893# -- util --
899- def find_sites (graphs , get_labels ):
900- """
901- Detects merges between ground truth graphs which are considered to be
902- potential merge sites.
903-
904- Parameters
905- ----------
906- graphs : dict
907- Dictionary where the keys are graph ids and values are graphs.
908- get_labels : func
909- Gets the label of a node in "graphs".
910-
911- Returns
912- -------
913- merge_ids : set[tuple]
914- Set of tuples containing a tuple of graph ids and common label between
915- the graphs.
916-
917- """
918- merge_ids = set ()
919- visited = set ()
920- for key_1 in graphs :
921- for key_2 in graphs :
922- keys = frozenset ((key_1 , key_2 ))
923- if key_1 != key_2 and keys not in visited :
924- visited .add (keys )
925- intersection = get_labels (key_1 ).intersection (
926- get_labels (key_2 )
927- )
928- for label in intersection :
929- merge_ids .add ((keys , label ))
930- return merge_ids
931-
932-
933894def generate_result (keys , stats ):
934895 """
935896 Reorders items in "stats" with respect to the order defined by "keys".
0 commit comments