2020from segmentation_skeleton_metrics import split_detection , swc_utils , utils
2121from segmentation_skeleton_metrics .swc_utils import save , to_graph
2222
# Minimum overlap (node count) required before two skeletons are considered
# to intersect — presumably used by split/merge detection; TODO confirm at call sites.
INTERSECTION_THRESHOLD = 10

# Maximum distance between skeleton nodes for a detected merge site to be
# accepted (units appear to be voxels — NOTE(review): confirm anisotropy handling).
MERGE_DIST_THRESHOLD = 40
2626
@@ -50,8 +50,8 @@ def __init__(
5050 equivalent_ids = None ,
5151 ignore_boundary_mistakes = False ,
5252 output_dir = None ,
53- valid_size_threshold = 40 ,
54- write_to_swc = False ,
53+ valid_size_threshold = 25 ,
54+ save_swc = False ,
5555 ):
5656 """
5757 Constructs skeleton metric object that evaluates the quality of a
@@ -87,7 +87,7 @@ def __init__(
8787 Threshold on the number of nodes contained in an swc file. Only swc
8888 files with more than "valid_size_threshold" nodes are stored in
8989 "self.valid_labels". The default is 40.
90- write_to_swc : bool, optional
90+ save_swc : bool, optional
9191 Indication of whether to write mistake sites to an swc file. The
9292 default is False.
9393
@@ -100,7 +100,7 @@ def __init__(
100100 self .anisotropy = anisotropy
101101 self .ignore_boundary_mistakes = ignore_boundary_mistakes
102102 self .output_dir = output_dir
103- self .write_to_swc = write_to_swc
103+ self .save = save_swc
104104
105105 self .init_black_holes (black_holes_xyz_id )
106106 self .black_hole_radius = black_hole_radius
@@ -372,11 +372,13 @@ def compute_metrics(self):
372372 """
373373 # Split evaluation
374374 print ("Detecting Splits..." )
375+ self .saved_site_cnt = 0
375376 self .detect_splits ()
376377 self .quantify_splits ()
377378
378379 # Merge evaluation
379380 print ("Detecting Merges..." )
381+ self .saved_site_cnt = 0
380382 self .detect_merges ()
381383 self .quantify_merges ()
382384
@@ -509,18 +511,18 @@ def detect_merges(self):
509511 merge = (frozenset ((target_id_1 , target_id_2 )), label )
510512 if merge not in detected_merges :
511513 detected_merges .add (merge )
512- site , d = self .locate (target_id_1 , target_id_2 , label )
513- if d < MERGE_DIST_THRESHOLD :
514- self .merge_cnts [target_id_1 ] += 1
515- self .merge_cnts [target_id_2 ] += 1
516- if self .write_to_swc :
514+ if self .save :
515+ site , d = self .locate (target_id_1 , target_id_2 , label )
516+ if d < MERGE_DIST_THRESHOLD :
517517 self .save_swc (site [0 ], site [1 ], "merge" )
518518
519519 # Update graph
520520 for (target_ids , label ) in detected_merges :
521521 target_id_1 , target_id_2 = tuple (target_ids )
522522 self .process_merge (target_id_1 , label )
523523 self .process_merge (target_id_2 , label )
524+ self .merge_cnts [target_id_1 ] += 1
525+ self .merge_cnts [target_id_2 ] += 1
524526
525527 # Report Runtime
526528 t , unit = utils .time_writer (time () - t0 )
@@ -541,6 +543,8 @@ def locate(self, target_id_1, target_id_2, label):
541543 if utils .dist (xyz_1 , xyz_2 ) < min_dist :
542544 min_dist = utils .dist (xyz_1 , xyz_2 )
543545 xyz_pair = [xyz_1 , xyz_2 ]
546+ if min_dist < MERGE_DIST_THRESHOLD :
547+ return xyz_pair , min_dist
544548 return xyz_pair , min_dist
545549
546550 def near_bdd (self , xyz ):
@@ -787,16 +791,11 @@ def list_metrics(self):
787791 return metrics
788792
789793 def save_swc (self , xyz_1 , xyz_2 , mistake_type ):
794+ self .saved_site_cnt += 1
790795 xyz_1 = utils .to_world (xyz_1 , self .anisotropy )
791796 xyz_2 = utils .to_world (xyz_2 , self .anisotropy )
792- if mistake_type == "split" :
793- color = "0.0 1.0 0.0"
794- cnt = 1 + np .sum (list (self .split_cnts .values ())) // 2
795- else :
796- color = "0.0 0.0 1.0"
797- cnt = 1 + np .sum (list (self .merge_cnts .values ())) // 2
798-
799- path = f"{ self .output_dir } /{ mistake_type } -{ cnt } .swc"
797+ color = "0.0 1.0 0.0" if mistake_type == "split" else "0.0 0.0 1.0"
798+ path = f"{ self .output_dir } /{ mistake_type } -{ self .saved_site_cnt } .swc"
800799 save (path , xyz_1 , xyz_2 , color = color )
801800
802801
0 commit comments