Skip to content

Commit 019ed4d

Browse files
anna-grimanna-grim
andauthored
Test merge detection (#45)
* upds * bug: averaged results --------- Co-authored-by: anna-grim <[email protected]>
1 parent c288821 commit 019ed4d

File tree

1 file changed

+51
-28
lines changed

1 file changed

+51
-28
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 51 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ def is_zero_misalignment(
412412
if utils.check_edge(dfs_edges, (j, k)):
413413
queue.append(k)
414414
dfs_edges = remove_edge(dfs_edges, (j, k))
415+
elif k == nb:
416+
queue.append(k)
415417

416418
# Upd zero nodes
417419
if len(collision_labels) == 1 and not black_hole:
@@ -426,16 +428,15 @@ def is_nonzero_misalignment(
426428
# Initialize
427429
origin_label = pred_graph.nodes[nb]["pred_id"]
428430
hit_label = pred_graph.nodes[root]["pred_id"]
429-
parent = nb
430431

431432
# Search
432-
queue = [root]
433+
queue = [(nb, root)]
433434
visited = set([nb])
434435
while len(queue) > 0:
435-
j = queue.pop(0)
436+
parent, j = queue.pop(0)
436437
label_j = pred_graph.nodes[j]["pred_id"]
437438
visited.add(j)
438-
if label_j == origin_label:
439+
if label_j == origin_label and len(queue) == 0:
439440
# misalignment
440441
pred_graph = gutils.upd_labels(
441442
pred_graph, visited, origin_label
@@ -444,20 +445,19 @@ def is_nonzero_misalignment(
444445
elif label_j == hit_label:
445446
# continue search
446447
nbs = list(target_graph.neighbors(j))
447-
nbs.remove(parent)
448-
if len(nbs) == 1:
449-
parent = j
450-
queue.append(nbs[0])
451-
dfs_edges = remove_edge(dfs_edges, (j, nbs[0]))
452-
else:
453-
pred_graph = gutils.remove_edge(pred_graph, nb, root)
454-
return dfs_edges, pred_graph
448+
for k in [k for k in nbs if k not in visited]:
449+
queue.append((j, k))
450+
dfs_edges = remove_edge(dfs_edges, (j, k))
455451
else:
456452
# left hit label
457453
dfs_edges.insert(0, (parent, j))
458454
pred_graph = gutils.remove_edge(pred_graph, nb, root)
459455
return dfs_edges, pred_graph
460456

457+
# End of search
458+
pred_graph = gutils.remove_edge(pred_graph, nb, root)
459+
return dfs_edges, pred_graph
460+
461461
def quantify_splits(self):
462462
"""
463463
Counts the number of splits, number of omit edges, and percent of omit
@@ -483,6 +483,7 @@ def quantify_splits(self):
483483
self.split_cnts[swc_id] = n_splits
484484
self.omit_cnts[swc_id] = n_target_edges - n_pred_edges
485485
self.omit_percents[swc_id] = 1 - n_pred_edges / n_target_edges
486+
print(swc_id, n_pred_edges / n_target_edges)
486487

487488
def detect_merges(self):
488489
"""
@@ -519,7 +520,7 @@ def detect_merges(self):
519520
for label in intersection:
520521
sites, dist = self.localize(swc_id_1, swc_id_2, label)
521522
xyz = utils.get_midpoint(sites[0], sites[1])
522-
if True: #dist > 20 and not self.near_bdd(xyz):
523+
if dist > 20 and not self.near_bdd(xyz):
523524
# Write site to swc
524525
if self.write_to_swc:
525526
self.save_swc(sites[0], sites[1], "merge")
@@ -648,13 +649,8 @@ def compile_results(self):
648649
self.compute_erl()
649650

650651
# Summarize results
651-
swc_ids, results = self.generate_report()
652-
avg_results = dict([(k, np.mean(v)) for k, v in results.items()])
653-
654-
# Adjust certain stats
655-
n_detected_neurons = np.sum(np.array(results["% omit edges"]) < 1)
656-
avg_results["# splits"] = np.sum(results["# splits"]) / n_detected_neurons
657-
avg_results["# merges"] = avg_results["# merges"] / 2
652+
swc_ids, results = self.generate_full_results()
653+
avg_results = self.generate_avg_results()
658654

659655
# Reformat full results
660656
full_results = dict()
@@ -665,7 +661,7 @@ def compile_results(self):
665661

666662
return full_results, avg_results
667663

668-
def generate_report(self):
664+
def generate_full_results(self):
669665
"""
670666
Generates a report by creating a list of the results for each metric.
671667
Each item in this list corresponds to a graph in "self.pred_graphs"
@@ -689,14 +685,33 @@ def generate_report(self):
689685
stats = {
690686
"# splits": generate_result(swc_ids, self.split_cnts),
691687
"# merges": generate_result(swc_ids, self.merge_cnts),
692-
"% omit edges": generate_result(swc_ids, self.omit_percents),
693-
"% merged edges": generate_result(swc_ids, self.merged_percents),
688+
"% omit": generate_result(swc_ids, self.omit_percents),
689+
"% merged": generate_result(swc_ids, self.merged_percents),
694690
"edge accuracy": generate_result(swc_ids, self.edge_accuracy),
695691
"erl": generate_result(swc_ids, self.erl),
696692
"normalized erl": generate_result(swc_ids, self.normalized_erl),
697693
}
698694
return swc_ids, stats
699695

696+
def generate_avg_results(self):
697+
avg_stats = {
698+
"# splits": self.avg_result(self.split_cnts),
699+
"# merges": self.avg_result(self.merge_cnts),
700+
"% omit": self.avg_result(self.omit_percents),
701+
"% merged": self.avg_result(self.merged_percents),
702+
"edge accuracy": self.avg_result(self.edge_accuracy),
703+
"erl": self.avg_result(self.erl),
704+
"normalized erl": self.avg_result(self.normalized_erl),
705+
}
706+
return avg_stats
707+
708+
def avg_result(self, stats):
709+
result = []
710+
for swc_id, wgt in self.wgts.items():
711+
if self.omit_percents[swc_id] < 1:
712+
result.append(wgt * stats[swc_id])
713+
return np.sum(result)
714+
700715
def compute_edge_accuracy(self):
701716
"""
702717
Computes the edge accuracy of each pred_graph.
@@ -731,17 +746,25 @@ def compute_erl(self):
731746
"""
732747
self.erl = dict()
733748
self.normalized_erl = dict()
749+
self.wgts = dict()
750+
total_path_length = 0
734751
for swc_id in self.target_graphs.keys():
735752
pred_graph = self.pred_graphs[swc_id]
736753
target_graph = self.target_graphs[swc_id]
737754

738755
path_length = gutils.compute_path_length(target_graph)
739-
path_lengths = gutils.compute_run_lengths(pred_graph)
740-
wgts = path_lengths / max(np.sum(path_lengths), 1)
756+
run_lengths = gutils.compute_run_lengths(pred_graph)
757+
wgt = run_lengths / np.sum(run_lengths)
741758

742-
self.erl[swc_id] = np.sum(wgts * path_lengths)
759+
self.erl[swc_id] = np.sum(wgt * run_lengths)
743760
self.normalized_erl[swc_id] = self.erl[swc_id] / path_length
744761

762+
self.wgts[swc_id] = path_length
763+
total_path_length += path_length
764+
765+
for swc_id in self.target_graphs.keys():
766+
self.wgts[swc_id] = self.wgts[swc_id] / total_path_length
767+
745768
def list_metrics(self):
746769
"""
747770
Lists metrics that are computed by this module.
@@ -759,8 +782,8 @@ def list_metrics(self):
759782
metrics = [
760783
"# splits",
761784
"# merges",
762-
"% omit edges",
763-
"% merged edges",
785+
"% omit",
786+
"% merged",
764787
"edge accuracy",
765788
"erl",
766789
"normalized erl",

0 commit comments

Comments
 (0)