Skip to content

Commit fab0849

Browse files
anna-grimanna-grim
andauthored
feat: detect nonzero misalignments (#41)
Co-authored-by: anna-grim <[email protected]>
1 parent 950f53a commit fab0849

File tree

2 files changed

+97
-41
lines changed

2 files changed

+97
-41
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 88 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from segmentation_skeleton_metrics import graph_utils as gutils
2020
from segmentation_skeleton_metrics import utils
21-
from segmentation_skeleton_metrics.swc_utils import to_graph
21+
from segmentation_skeleton_metrics.swc_utils import save, to_graph
2222

2323

2424
class SkeletonMetric:
@@ -42,7 +42,7 @@ def __init__(
4242
labels,
4343
anisotropy=[1.0, 1.0, 1.0],
4444
ignore_boundary_mistakes=False,
45-
black_holes=None,
45+
black_holes_xyz_id=None,
4646
black_hole_radius=24,
4747
equivalent_ids=None,
4848
valid_ids=None,
@@ -63,7 +63,7 @@ def __init__(
6363
anisotropy : list[float], optional
6464
Image to real-world coordinates scaling factors applied to swc
6565
files. The default is [1.0, 1.0, 1.0]
66-
black_holes : numpy.ndarray
66+
black_holes_xyz_id : list
6767
...
6868
black_hole_radius : float
6969
...
@@ -80,32 +80,40 @@ def __init__(
8080
# Store label options
8181
self.valid_ids = valid_ids
8282
self.labels = labels
83+
84+
self.anisotropy = anisotropy
8385
self.ignore_boundary_mistakes = ignore_boundary_mistakes
84-
self.black_hole_labels = set()
86+
self.init_black_holes(black_holes_xyz_id)
8587
self.black_hole_radius = black_hole_radius
86-
self.init_black_holes(black_holes)
87-
self.write_to_swc = False
88+
89+
self.write_to_swc = write_to_swc
8890
self.output_dir = output_dir
8991

9092
# Build Graphs
9193
self.init_target_graphs(swc_paths, anisotropy)
9294
self.init_pred_graphs()
93-
self.black_hole_labels.discard(0)
9495

9596
def init_black_holes(self, black_holes):
9697
if black_holes:
97-
self.black_holes = KDTree(black_holes)
98+
black_holes_xyz = [bh_dict["xyz"] for bh_dict in black_holes]
99+
black_holes_id = [bh_dict["swc_id"] for bh_dict in black_holes]
100+
self.black_holes = KDTree(black_holes_xyz)
101+
self.black_hole_labels = set(black_holes_id)
98102
else:
99103
self.black_holes = None
104+
self.black_hole_labels = set()
100105

101-
def in_black_hole(self, xyz):
106+
def in_black_hole(self, xyz, print_nn=False):
102107
# Check whether black_holes exists
103108
if self.black_holes is None:
104109
return False
105110

106111
# Search black_holes
107112
radius = self.black_hole_radius
108113
pts = self.black_holes.query_ball_point(xyz, radius)
114+
if print_nn:
115+
dd, ii = self.black_holes.query([xyz], k=[1])
116+
print("Nearest neighbor:", dd)
109117
if len(pts) > 0:
110118
return True
111119
else:
@@ -188,6 +196,7 @@ def label_graph(self, target_graph):
188196
for i in pred_graph.nodes:
189197
img_coord = gutils.get_coord(pred_graph, i)
190198
threads.append(executor.submit(self.get_label, img_coord, i))
199+
191200
# Store results
192201
for thread in as_completed(threads):
193202
i, label = thread.result()
@@ -215,7 +224,6 @@ def get_label(self, img_coord, return_node=False):
215224
"""
216225
label = self.__read_label(img_coord)
217226
if self.in_black_hole(img_coord):
218-
self.black_hole_labels.add(label)
219227
label = -1
220228
return self.output_label(label, return_node)
221229

@@ -337,9 +345,12 @@ def detect_splits(self):
337345
label_i = pred_graph.nodes[i]["pred_id"]
338346
label_j = pred_graph.nodes[j]["pred_id"]
339347
if is_split(label_i, label_j):
340-
pred_graph = gutils.remove_edge(pred_graph, i, j)
348+
# pred_graph = gutils.remove_edge(pred_graph, i, j)
349+
dfs_edges, pred_graph = self.is_nonzero_misalignment(
350+
target_graph, pred_graph, dfs_edges, i, j
351+
)
341352
elif label_j == 0 or label_j == -1:
342-
dfs_edges, pred_graph = self.split_search(
353+
dfs_edges, pred_graph = self.is_zero_misalignment(
343354
target_graph, pred_graph, dfs_edges, i, j
344355
)
345356

@@ -354,7 +365,9 @@ def detect_splits(self):
354365
t, unit = utils.time_writer(time() - t0)
355366
print(f"\nRuntime: {round(t, 2)} {unit}\n")
356367

357-
def split_search(self, target_graph, pred_graph, dfs_edges, nb, root):
368+
def is_zero_misalignment(
369+
self, target_graph, pred_graph, dfs_edges, nb, root
370+
):
358371
"""
359372
Determines whether zero-valued labels correspond to a split or
360373
misalignment between "target_graph" and the predicted segmentation
@@ -382,33 +395,76 @@ def split_search(self, target_graph, pred_graph, dfs_edges, nb, root):
382395
383396
"""
384397
# Search
398+
black_hole = False
399+
collision_labels = set([pred_graph.nodes[nb]["pred_id"]])
385400
queue = [root]
386401
visited = set()
387-
collision_labels = set()
388-
collision_nodes = set()
389402
while len(queue) > 0:
390403
j = queue.pop(0)
391404
label_j = pred_graph.nodes[j]["pred_id"]
392405
visited.add(j)
393406
if label_j > 0:
394407
collision_labels.add(label_j)
395408
else:
409+
# Check for black hole
410+
if label_j == -1:
411+
black_hole = True
412+
413+
# Add nbs to queue
396414
nbs = target_graph.neighbors(j)
397415
for k in [k for k in nbs if k not in visited]:
398416
if utils.check_edge(dfs_edges, (j, k)):
399417
queue.append(k)
400418
dfs_edges = remove_edge(dfs_edges, (j, k))
401-
elif k == nb:
402-
queue.append(k)
403419

404420
# Upd zero nodes
405-
if len(collision_labels) == 1:
421+
if len(collision_labels) == 1 and not black_hole:
406422
label = collision_labels.pop()
407-
visited = visited.difference(collision_nodes)
408423
pred_graph = gutils.upd_labels(pred_graph, visited, label)
409424

410425
return dfs_edges, pred_graph
411426

427+
def is_nonzero_misalignment(
428+
self, target_graph, pred_graph, dfs_edges, nb, root
429+
):
430+
# Initialize
431+
origin_label = pred_graph.nodes[nb]["pred_id"]
432+
hit_label = pred_graph.nodes[root]["pred_id"]
433+
parent = nb
434+
depth = 0
435+
436+
# Search
437+
queue = [root]
438+
visited = set([nb])
439+
while len(queue) > 0:
440+
j = queue.pop(0)
441+
label_j = pred_graph.nodes[j]["pred_id"]
442+
visited.add(j)
443+
depth += 1
444+
if label_j == origin_label:
445+
# misalignment
446+
pred_graph = gutils.upd_labels(
447+
pred_graph, visited, origin_label
448+
)
449+
return dfs_edges, pred_graph
450+
elif label_j == hit_label and depth < 16:
451+
# continue search
452+
nbs = list(target_graph.neighbors(j))
453+
nbs.remove(parent)
454+
if len(nbs) == 1:
455+
if utils.check_edge(dfs_edges, (j, nbs[0])):
456+
parent = j
457+
queue.append(nbs[0])
458+
dfs_edges = remove_edge(dfs_edges, (j, nbs[0]))
459+
else:
460+
pred_graph = gutils.remove_edge(pred_graph, nb, root)
461+
return dfs_edges, pred_graph
462+
else:
463+
# left hit label
464+
dfs_edges.insert(0, (parent, j))
465+
pred_graph = gutils.remove_edge(pred_graph, nb, root)
466+
return dfs_edges, pred_graph
467+
412468
def quantify_splits(self):
413469
"""
414470
Counts the number of splits, number of omit edges, and percent of omit
@@ -468,16 +524,16 @@ def detect_merges(self):
468524
pred_ids_2 = self.get_pred_ids(swc_id_2)
469525
intersection = pred_ids_1.intersection(pred_ids_2)
470526
for label in intersection:
471-
merged_1 = self.label_to_node[swc_id_1][label]
472-
merged_2 = self.label_to_node[swc_id_2][label]
473-
too_small = min(len(merged_1), len(merged_2)) > 16
474-
if not too_small:
475-
site, dist = self.localize(swc_id_1, swc_id_2, label)
476-
near_bdd = self.near_bdd(site)
477-
if not near_bdd:
527+
#merged_1 = self.label_to_node[swc_id_1][label]
528+
#merged_2 = self.label_to_node[swc_id_2][label]
529+
# too_small = min(len(merged_1), len(merged_2)) > 16
530+
if True: # not too_small:
531+
sites, dist = self.localize(swc_id_1, swc_id_2, label)
532+
xyz = utils.get_midpoint(sites[0], sites[1])
533+
if dist > 20 and not self.near_bdd(xyz):
478534
# Write site to swc
479535
if self.write_to_swc:
480-
self.save_swc(site[0], site[1], "merge")
536+
self.save_swc(sites[0], sites[1], "merge")
481537

482538
# Process merge
483539
self.process_merge(swc_id_1, label)
@@ -509,7 +565,6 @@ def localize(self, swc_id_1, swc_id_2, label):
509565
xyz_pair = [xyz_1, xyz_2]
510566
return xyz_pair, min_dist
511567

512-
513568
def near_bdd(self, xyz_pair):
514569
near_bdd_bool = False
515570
if self.ignore_boundary_mistakes:
@@ -720,14 +775,17 @@ def list_metrics(self):
720775
return metrics
721776

722777
def save_swc(self, xyz_1, xyz_2, mistake_type):
778+
xyz_1 = utils.to_world(xyz_1, self.anisotropy)
779+
xyz_2 = utils.to_world(xyz_2, self.anisotropy)
723780
if mistake_type == "split":
724781
color = "1.0 0.0 0.0"
725-
cnt = 1 + np.sum(self.split_cnts) // 2
782+
cnt = 1 + np.sum(list(self.split_cnts.values())) // 2
726783
else:
727784
color = "0.0 1.0 0.0"
728-
cnt = 1 + np.sum(self.merge_cnts) // 2
785+
cnt = 1 + np.sum(list(self.merge_cnts.values())) // 2
729786

730787
path = f"{self.output_dir}/{mistake_type}-{cnt}.swc"
788+
save(path, xyz_1, xyz_2, color=color)
731789

732790

733791
# -- utils --

src/segmentation_skeleton_metrics/swc_utils.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
import networkx as nx
1313
import numpy as np
1414

15-
from segmentation_skeleton_metrics import graph_utils as gutils
16-
1715

1816
def make_entry(node_id, parent_id, xyz):
1917
"""
@@ -32,12 +30,12 @@ def make_entry(node_id, parent_id, xyz):
3230
applied to swc files.
3331
3432
"""
35-
x, y, z = tuple(xyz.tolist())
36-
entry = f"{node_id} 2 {x} {y} {z} 7.5 {parent_id}"
33+
x, y, z = tuple(xyz)
34+
entry = f"{node_id} 2 {x} {y} {z} 8 {parent_id}"
3735
return entry
3836

3937

40-
def write_swc(path, xyz_1, xyz_2, color=None):
38+
def save(path, xyz_1, xyz_2, color=None):
4139
"""
4240
Writes an swc file.
4341
@@ -55,18 +53,18 @@ def write_swc(path, xyz_1, xyz_2, color=None):
5553
None.
5654
5755
"""
58-
# Preamble
5956
with open(path, "w") as f:
57+
# Preamble
6058
if color is not None:
61-
f.write("# COLOR" + color)
59+
f.write("# COLOR " + color)
6260
else:
6361
f.write("# id, type, z, y, x, r, pid")
6462
f.write("\n")
6563

66-
# Entries
67-
f.write(write_entry(1, -1, xyz_1))
68-
f.write("\n")
69-
f.write(write_entry(2, 1, xyz_2))
64+
# Entries
65+
f.write(make_entry(1, -1, xyz_1))
66+
f.write("\n")
67+
f.write(make_entry(2, 1, xyz_2))
7068

7169

7270
def to_graph(path, anisotropy=[1.0, 1.0, 1.0]):

0 commit comments

Comments
 (0)