@@ -111,9 +111,6 @@ def in_black_hole(self, xyz, print_nn=False):
111111 # Search black_holes
112112 radius = self .black_hole_radius
113113 pts = self .black_holes .query_ball_point (xyz , radius )
114- if print_nn :
115- dd , ii = self .black_holes .query ([xyz ], k = [1 ])
116- print ("Nearest neighbor:" , dd )
117114 if len (pts ) > 0 :
118115 return True
119116 else :
@@ -430,7 +427,6 @@ def is_nonzero_misalignment(
430427 origin_label = pred_graph .nodes [nb ]["pred_id" ]
431428 hit_label = pred_graph .nodes [root ]["pred_id" ]
432429 parent = nb
433- depth = 0
434430
435431 # Search
436432 queue = [root ]
@@ -439,22 +435,20 @@ def is_nonzero_misalignment(
439435 j = queue .pop (0 )
440436 label_j = pred_graph .nodes [j ]["pred_id" ]
441437 visited .add (j )
442- depth += 1
443438 if label_j == origin_label :
444439 # misalignment
445440 pred_graph = gutils .upd_labels (
446441 pred_graph , visited , origin_label
447442 )
448443 return dfs_edges , pred_graph
449- elif label_j == hit_label and depth < 16 :
444+ elif label_j == hit_label :
450445 # continue search
451446 nbs = list (target_graph .neighbors (j ))
452447 nbs .remove (parent )
453448 if len (nbs ) == 1 :
454- if utils .check_edge (dfs_edges , (j , nbs [0 ])):
455- parent = j
456- queue .append (nbs [0 ])
457- dfs_edges = remove_edge (dfs_edges , (j , nbs [0 ]))
449+ parent = j
450+ queue .append (nbs [0 ])
451+ dfs_edges = remove_edge (dfs_edges , (j , nbs [0 ]))
458452 else :
459453 pred_graph = gutils .remove_edge (pred_graph , nb , root )
460454 return dfs_edges , pred_graph
@@ -523,20 +517,16 @@ def detect_merges(self):
523517 pred_ids_2 = self .get_pred_ids (swc_id_2 )
524518 intersection = pred_ids_1 .intersection (pred_ids_2 )
525519 for label in intersection :
526- #merged_1 = self.label_to_node[swc_id_1][label]
527- #merged_2 = self.label_to_node[swc_id_2][label]
528- # too_small = min(len(merged_1), len(merged_2)) > 16
529- if True : # not too_small:
530- sites , dist = self .localize (swc_id_1 , swc_id_2 , label )
531- xyz = utils .get_midpoint (sites [0 ], sites [1 ])
532- if dist > 20 and not self .near_bdd (xyz ):
533- # Write site to swc
534- if self .write_to_swc :
535- self .save_swc (sites [0 ], sites [1 ], "merge" )
536-
537- # Process merge
538- self .process_merge (swc_id_1 , label )
539- self .process_merge (swc_id_2 , label )
520+ sites , dist = self .localize (swc_id_1 , swc_id_2 , label )
521+ xyz = utils .get_midpoint (sites [0 ], sites [1 ])
522+ if True : #dist > 20 and not self.near_bdd(xyz):
523+ # Write site to swc
524+ if self .write_to_swc :
525+ self .save_swc (sites [0 ], sites [1 ], "merge" )
526+
527+ # Process merge
528+ self .process_merge (swc_id_1 , label )
529+ self .process_merge (swc_id_2 , label )
540530
541531 # Remove label to avoid reprocessing
542532 del self .label_to_node [swc_id_1 ][label ]
@@ -660,6 +650,10 @@ def compile_results(self):
660650 # Summarize results
661651 swc_ids , results = self .generate_report ()
662652 avg_results = dict ([(k , np .mean (v )) for k , v in results .items ()])
653+
654+ # Adjust certain stats
655+ n_detected_neurons = np .sum (np .array (results ["% omit edges" ]) < 1 )
656+ avg_results ["# splits" ] = np .sum (results ["# splits" ]) / n_detected_neurons
663657 avg_results ["# merges" ] = avg_results ["# merges" ] / 2
664658
665659 # Reformat full results
0 commit comments