1818
1919from segmentation_skeleton_metrics import graph_utils as gutils
2020from segmentation_skeleton_metrics import utils
21- from segmentation_skeleton_metrics .swc_utils import to_graph
21+ from segmentation_skeleton_metrics .swc_utils import save , to_graph
2222
2323
2424class SkeletonMetric :
@@ -42,7 +42,7 @@ def __init__(
4242 labels ,
4343 anisotropy = [1.0 , 1.0 , 1.0 ],
4444 ignore_boundary_mistakes = False ,
45- black_holes = None ,
45+ black_holes_xyz_id = None ,
4646 black_hole_radius = 24 ,
4747 equivalent_ids = None ,
4848 valid_ids = None ,
@@ -63,7 +63,7 @@ def __init__(
6363 anisotropy : list[float], optional
6464 Image to real-world coordinates scaling factors applied to swc
6565 files. The default is [1.0, 1.0, 1.0]
66- black_holes : numpy.ndarray
66+ black_holes_xyz_id : list
6767 ...
6868 black_hole_radius : float
6969 ...
@@ -80,32 +80,40 @@ def __init__(
8080 # Store label options
8181 self .valid_ids = valid_ids
8282 self .labels = labels
83+
84+ self .anisotropy = anisotropy
8385 self .ignore_boundary_mistakes = ignore_boundary_mistakes
84- self .black_hole_labels = set ( )
86+ self .init_black_holes ( black_holes_xyz_id )
8587 self .black_hole_radius = black_hole_radius
86- self . init_black_holes ( black_holes )
87- self .write_to_swc = False
88+
89+ self .write_to_swc = write_to_swc
8890 self .output_dir = output_dir
8991
9092 # Build Graphs
9193 self .init_target_graphs (swc_paths , anisotropy )
9294 self .init_pred_graphs ()
93- self .black_hole_labels .discard (0 )
9495
9596 def init_black_holes (self , black_holes ):
9697 if black_holes :
97- self .black_holes = KDTree (black_holes )
98+ black_holes_xyz = [bh_dict ["xyz" ] for bh_dict in black_holes ]
99+ black_holes_id = [bh_dict ["swc_id" ] for bh_dict in black_holes ]
100+ self .black_holes = KDTree (black_holes_xyz )
101+ self .black_hole_labels = set (black_holes_id )
98102 else :
99103 self .black_holes = None
104+ self .black_hole_labels = set ()
100105
101- def in_black_hole (self , xyz ):
106+ def in_black_hole (self , xyz , print_nn = False ):
102107 # Check whether black_holes exists
103108 if self .black_holes is None :
104109 return False
105110
106111 # Search black_holes
107112 radius = self .black_hole_radius
108113 pts = self .black_holes .query_ball_point (xyz , radius )
114+ if print_nn :
115+ dd , ii = self .black_holes .query ([xyz ], k = [1 ])
116+ print ("Nearest neighbor:" , dd )
109117 if len (pts ) > 0 :
110118 return True
111119 else :
@@ -188,6 +196,7 @@ def label_graph(self, target_graph):
188196 for i in pred_graph .nodes :
189197 img_coord = gutils .get_coord (pred_graph , i )
190198 threads .append (executor .submit (self .get_label , img_coord , i ))
199+
191200 # Store results
192201 for thread in as_completed (threads ):
193202 i , label = thread .result ()
@@ -215,7 +224,6 @@ def get_label(self, img_coord, return_node=False):
215224 """
216225 label = self .__read_label (img_coord )
217226 if self .in_black_hole (img_coord ):
218- self .black_hole_labels .add (label )
219227 label = - 1
220228 return self .output_label (label , return_node )
221229
@@ -337,9 +345,12 @@ def detect_splits(self):
337345 label_i = pred_graph .nodes [i ]["pred_id" ]
338346 label_j = pred_graph .nodes [j ]["pred_id" ]
339347 if is_split (label_i , label_j ):
340- pred_graph = gutils .remove_edge (pred_graph , i , j )
348+ # pred_graph = gutils.remove_edge(pred_graph, i, j)
349+ dfs_edges , pred_graph = self .is_nonzero_misalignment (
350+ target_graph , pred_graph , dfs_edges , i , j
351+ )
341352 elif label_j == 0 or label_j == - 1 :
342- dfs_edges , pred_graph = self .split_search (
353+ dfs_edges , pred_graph = self .is_zero_misalignment (
343354 target_graph , pred_graph , dfs_edges , i , j
344355 )
345356
@@ -354,7 +365,9 @@ def detect_splits(self):
354365 t , unit = utils .time_writer (time () - t0 )
355366 print (f"\n Runtime: { round (t , 2 )} { unit } \n " )
356367
357- def split_search (self , target_graph , pred_graph , dfs_edges , nb , root ):
368+ def is_zero_misalignment (
369+ self , target_graph , pred_graph , dfs_edges , nb , root
370+ ):
358371 """
359372 Determines whether zero-valued labels correspond to a split or
360373 misalignment between "target_graph" and the predicted segmentation
@@ -382,33 +395,76 @@ def split_search(self, target_graph, pred_graph, dfs_edges, nb, root):
382395
383396 """
384397 # Search
398+ black_hole = False
399+ collision_labels = set ([pred_graph .nodes [nb ]["pred_id" ]])
385400 queue = [root ]
386401 visited = set ()
387- collision_labels = set ()
388- collision_nodes = set ()
389402 while len (queue ) > 0 :
390403 j = queue .pop (0 )
391404 label_j = pred_graph .nodes [j ]["pred_id" ]
392405 visited .add (j )
393406 if label_j > 0 :
394407 collision_labels .add (label_j )
395408 else :
409+ # Check for black hole
410+ if label_j == - 1 :
411+ black_hole = True
412+
413+ # Add nbs to queue
396414 nbs = target_graph .neighbors (j )
397415 for k in [k for k in nbs if k not in visited ]:
398416 if utils .check_edge (dfs_edges , (j , k )):
399417 queue .append (k )
400418 dfs_edges = remove_edge (dfs_edges , (j , k ))
401- elif k == nb :
402- queue .append (k )
403419
404420 # Upd zero nodes
405- if len (collision_labels ) == 1 :
421+ if len (collision_labels ) == 1 and not black_hole :
406422 label = collision_labels .pop ()
407- visited = visited .difference (collision_nodes )
408423 pred_graph = gutils .upd_labels (pred_graph , visited , label )
409424
410425 return dfs_edges , pred_graph
411426
427+ def is_nonzero_misalignment (
428+ self , target_graph , pred_graph , dfs_edges , nb , root
429+ ):
430+ # Initialize
431+ origin_label = pred_graph .nodes [nb ]["pred_id" ]
432+ hit_label = pred_graph .nodes [root ]["pred_id" ]
433+ parent = nb
434+ depth = 0
435+
436+ # Search
437+ queue = [root ]
438+ visited = set ([nb ])
439+ while len (queue ) > 0 :
440+ j = queue .pop (0 )
441+ label_j = pred_graph .nodes [j ]["pred_id" ]
442+ visited .add (j )
443+ depth += 1
444+ if label_j == origin_label :
445+ # misalignment
446+ pred_graph = gutils .upd_labels (
447+ pred_graph , visited , origin_label
448+ )
449+ return dfs_edges , pred_graph
450+ elif label_j == hit_label and depth < 16 :
451+ # continue search
452+ nbs = list (target_graph .neighbors (j ))
453+ nbs .remove (parent )
454+ if len (nbs ) == 1 :
455+ if utils .check_edge (dfs_edges , (j , nbs [0 ])):
456+ parent = j
457+ queue .append (nbs [0 ])
458+ dfs_edges = remove_edge (dfs_edges , (j , nbs [0 ]))
459+ else :
460+ pred_graph = gutils .remove_edge (pred_graph , nb , root )
461+ return dfs_edges , pred_graph
462+ else :
463+ # left hit label
464+ dfs_edges .insert (0 , (parent , j ))
465+ pred_graph = gutils .remove_edge (pred_graph , nb , root )
466+ return dfs_edges , pred_graph
467+
412468 def quantify_splits (self ):
413469 """
414470 Counts the number of splits, number of omit edges, and percent of omit
@@ -468,16 +524,16 @@ def detect_merges(self):
468524 pred_ids_2 = self .get_pred_ids (swc_id_2 )
469525 intersection = pred_ids_1 .intersection (pred_ids_2 )
470526 for label in intersection :
471- merged_1 = self .label_to_node [swc_id_1 ][label ]
472- merged_2 = self .label_to_node [swc_id_2 ][label ]
473- too_small = min (len (merged_1 ), len (merged_2 )) > 16
474- if not too_small :
475- site , dist = self .localize (swc_id_1 , swc_id_2 , label )
476- near_bdd = self . near_bdd ( site )
477- if not near_bdd :
527+ # merged_1 = self.label_to_node[swc_id_1][label]
528+ # merged_2 = self.label_to_node[swc_id_2][label]
529+ # too_small = min(len(merged_1), len(merged_2)) > 16
530+ if True : # not too_small:
531+ sites , dist = self .localize (swc_id_1 , swc_id_2 , label )
532+ xyz = utils . get_midpoint ( sites [ 0 ], sites [ 1 ] )
533+ if dist > 20 and not self . near_bdd ( xyz ) :
478534 # Write site to swc
479535 if self .write_to_swc :
480- self .save_swc (site [0 ], site [1 ], "merge" )
536+ self .save_swc (sites [0 ], sites [1 ], "merge" )
481537
482538 # Process merge
483539 self .process_merge (swc_id_1 , label )
@@ -509,7 +565,6 @@ def localize(self, swc_id_1, swc_id_2, label):
509565 xyz_pair = [xyz_1 , xyz_2 ]
510566 return xyz_pair , min_dist
511567
512-
513568 def near_bdd (self , xyz_pair ):
514569 near_bdd_bool = False
515570 if self .ignore_boundary_mistakes :
@@ -720,14 +775,17 @@ def list_metrics(self):
720775 return metrics
721776
722777 def save_swc (self , xyz_1 , xyz_2 , mistake_type ):
778+ xyz_1 = utils .to_world (xyz_1 , self .anisotropy )
779+ xyz_2 = utils .to_world (xyz_2 , self .anisotropy )
723780 if mistake_type == "split" :
724781 color = "1.0 0.0 0.0"
725- cnt = 1 + np .sum (self .split_cnts ) // 2
782+ cnt = 1 + np .sum (list ( self .split_cnts . values ()) ) // 2
726783 else :
727784 color = "0.0 1.0 0.0"
728- cnt = 1 + np .sum (self .merge_cnts ) // 2
785+ cnt = 1 + np .sum (list ( self .merge_cnts . values ()) ) // 2
729786
730787 path = f"{ self .output_dir } /{ mistake_type } -{ cnt } .swc"
788+ save (path , xyz_1 , xyz_2 , color = color )
731789
732790
733791# -- utils --
0 commit comments