@@ -412,6 +412,8 @@ def is_zero_misalignment(
                 if utils.check_edge(dfs_edges, (j, k)):
                     queue.append(k)
                     dfs_edges = remove_edge(dfs_edges, (j, k))
+                elif k == nb:
+                    queue.append(k)
 
         # Upd zero nodes
         if len(collision_labels) == 1 and not black_hole:
@@ -426,16 +428,15 @@ def is_nonzero_misalignment(
         # Initialize
         origin_label = pred_graph.nodes[nb]["pred_id"]
         hit_label = pred_graph.nodes[root]["pred_id"]
-        parent = nb
 
         # Search
-        queue = [root]
+        queue = [(nb, root)]
         visited = set([nb])
         while len(queue) > 0:
-            j = queue.pop(0)
+            parent, j = queue.pop(0)
             label_j = pred_graph.nodes[j]["pred_id"]
             visited.add(j)
-            if label_j == origin_label:
+            if label_j == origin_label and len(queue) == 0:
                 # misalignment
                 pred_graph = gutils.upd_labels(
                     pred_graph, visited, origin_label
@@ -444,20 +445,19 @@ def is_nonzero_misalignment(
             elif label_j == hit_label:
                 # continue search
                 nbs = list(target_graph.neighbors(j))
-                nbs.remove(parent)
-                if len(nbs) == 1:
-                    parent = j
-                    queue.append(nbs[0])
-                    dfs_edges = remove_edge(dfs_edges, (j, nbs[0]))
-                else:
-                    pred_graph = gutils.remove_edge(pred_graph, nb, root)
-                    return dfs_edges, pred_graph
+                for k in [k for k in nbs if k not in visited]:
+                    queue.append((j, k))
+                    dfs_edges = remove_edge(dfs_edges, (j, k))
             else:
                 # left hit label
                 dfs_edges.insert(0, (parent, j))
                 pred_graph = gutils.remove_edge(pred_graph, nb, root)
                 return dfs_edges, pred_graph
 
+        # End of search
+        pred_graph = gutils.remove_edge(pred_graph, nb, root)
+        return dfs_edges, pred_graph
+
     def quantify_splits(self):
         """
         Counts the number of splits, number of omit edges, and percent of omit
@@ -483,6 +483,7 @@ def quantify_splits(self):
             self.split_cnts[swc_id] = n_splits
             self.omit_cnts[swc_id] = n_target_edges - n_pred_edges
             self.omit_percents[swc_id] = 1 - n_pred_edges / n_target_edges
+            print(swc_id, n_pred_edges / n_target_edges)
 
     def detect_merges(self):
         """
@@ -519,7 +520,7 @@ def detect_merges(self):
                 for label in intersection:
                     sites, dist = self.localize(swc_id_1, swc_id_2, label)
                     xyz = utils.get_midpoint(sites[0], sites[1])
-                    if True:  # dist > 20 and not self.near_bdd(xyz):
+                    if dist > 20 and not self.near_bdd(xyz):
                         # Write site to swc
                         if self.write_to_swc:
                             self.save_swc(sites[0], sites[1], "merge")
@@ -648,13 +649,8 @@ def compile_results(self):
         self.compute_erl()
 
         # Summarize results
-        swc_ids, results = self.generate_report()
-        avg_results = dict([(k, np.mean(v)) for k, v in results.items()])
-
-        # Adjust certain stats
-        n_detected_neurons = np.sum(np.array(results["% omit edges"]) < 1)
-        avg_results["# splits"] = np.sum(results["# splits"]) / n_detected_neurons
-        avg_results["# merges"] = avg_results["# merges"] / 2
+        swc_ids, results = self.generate_full_results()
+        avg_results = self.generate_avg_results()
 
         # Reformat full results
         full_results = dict()
@@ -665,7 +661,7 @@ def compile_results(self):
 
         return full_results, avg_results
 
-    def generate_report(self):
+    def generate_full_results(self):
         """
         Generates a report by creating a list of the results for each metric.
         Each item in this list corresponds to a graph in "self.pred_graphs"
@@ -689,14 +685,33 @@ def generate_report(self):
         stats = {
             "# splits": generate_result(swc_ids, self.split_cnts),
             "# merges": generate_result(swc_ids, self.merge_cnts),
-            "% omit edges": generate_result(swc_ids, self.omit_percents),
-            "% merged edges": generate_result(swc_ids, self.merged_percents),
+            "% omit": generate_result(swc_ids, self.omit_percents),
+            "% merged": generate_result(swc_ids, self.merged_percents),
             "edge accuracy": generate_result(swc_ids, self.edge_accuracy),
             "erl": generate_result(swc_ids, self.erl),
             "normalized erl": generate_result(swc_ids, self.normalized_erl),
         }
         return swc_ids, stats
 
+    def generate_avg_results(self):
+        avg_stats = {
+            "# splits": self.avg_result(self.split_cnts),
+            "# merges": self.avg_result(self.merge_cnts),
+            "% omit": self.avg_result(self.omit_percents),
+            "% merged": self.avg_result(self.merged_percents),
+            "edge accuracy": self.avg_result(self.edge_accuracy),
+            "erl": self.avg_result(self.erl),
+            "normalized erl": self.avg_result(self.normalized_erl),
+        }
+        return avg_stats
+
+    def avg_result(self, stats):
+        result = []
+        for swc_id, wgt in self.wgts.items():
+            if self.omit_percents[swc_id] < 1:
+                result.append(wgt * stats[swc_id])
+        return np.sum(result)
+
     def compute_edge_accuracy(self):
         """
         Computes the edge accuracy of each pred_graph.
@@ -731,17 +746,25 @@ def compute_erl(self):
         """
         self.erl = dict()
         self.normalized_erl = dict()
+        self.wgts = dict()
+        total_path_length = 0
         for swc_id in self.target_graphs.keys():
             pred_graph = self.pred_graphs[swc_id]
             target_graph = self.target_graphs[swc_id]
 
             path_length = gutils.compute_path_length(target_graph)
-            path_lengths = gutils.compute_run_lengths(pred_graph)
-            wgts = path_lengths / max(np.sum(path_lengths), 1)
+            run_lengths = gutils.compute_run_lengths(pred_graph)
+            wgt = run_lengths / np.sum(run_lengths)
 
-            self.erl[swc_id] = np.sum(wgts * path_lengths)
+            self.erl[swc_id] = np.sum(wgt * run_lengths)
             self.normalized_erl[swc_id] = self.erl[swc_id] / path_length
 
+            self.wgts[swc_id] = path_length
+            total_path_length += path_length
+
+        for swc_id in self.target_graphs.keys():
+            self.wgts[swc_id] = self.wgts[swc_id] / total_path_length
+
     def list_metrics(self):
         """
         Lists metrics that are computed by this module.
@@ -759,8 +782,8 @@ def list_metrics(self):
         metrics = [
             "# splits",
             "# merges",
-            "% omit edges",
-            "% merged edges",
+            "% omit",
+            "% merged",
             "edge accuracy",
             "erl",
             "normalized erl",