Skip to content

Commit 0686e8d

Browse files
anna-grimanna-grim
andauthored
upds (#43)
Co-authored-by: anna-grim <[email protected]>
1 parent 2856776 commit 0686e8d

File tree

1 file changed

+18
-24
lines changed

1 file changed

+18
-24
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,6 @@ def in_black_hole(self, xyz, print_nn=False):
111111
# Search black_holes
112112
radius = self.black_hole_radius
113113
pts = self.black_holes.query_ball_point(xyz, radius)
114-
if print_nn:
115-
dd, ii = self.black_holes.query([xyz], k=[1])
116-
print("Nearest neighbor:", dd)
117114
if len(pts) > 0:
118115
return True
119116
else:
@@ -430,7 +427,6 @@ def is_nonzero_misalignment(
430427
origin_label = pred_graph.nodes[nb]["pred_id"]
431428
hit_label = pred_graph.nodes[root]["pred_id"]
432429
parent = nb
433-
depth = 0
434430

435431
# Search
436432
queue = [root]
@@ -439,22 +435,20 @@ def is_nonzero_misalignment(
439435
j = queue.pop(0)
440436
label_j = pred_graph.nodes[j]["pred_id"]
441437
visited.add(j)
442-
depth += 1
443438
if label_j == origin_label:
444439
# misalignment
445440
pred_graph = gutils.upd_labels(
446441
pred_graph, visited, origin_label
447442
)
448443
return dfs_edges, pred_graph
449-
elif label_j == hit_label and depth < 16:
444+
elif label_j == hit_label:
450445
# continue search
451446
nbs = list(target_graph.neighbors(j))
452447
nbs.remove(parent)
453448
if len(nbs) == 1:
454-
if utils.check_edge(dfs_edges, (j, nbs[0])):
455-
parent = j
456-
queue.append(nbs[0])
457-
dfs_edges = remove_edge(dfs_edges, (j, nbs[0]))
449+
parent = j
450+
queue.append(nbs[0])
451+
dfs_edges = remove_edge(dfs_edges, (j, nbs[0]))
458452
else:
459453
pred_graph = gutils.remove_edge(pred_graph, nb, root)
460454
return dfs_edges, pred_graph
@@ -523,20 +517,16 @@ def detect_merges(self):
523517
pred_ids_2 = self.get_pred_ids(swc_id_2)
524518
intersection = pred_ids_1.intersection(pred_ids_2)
525519
for label in intersection:
526-
#merged_1 = self.label_to_node[swc_id_1][label]
527-
#merged_2 = self.label_to_node[swc_id_2][label]
528-
# too_small = min(len(merged_1), len(merged_2)) > 16
529-
if True: # not too_small:
530-
sites, dist = self.localize(swc_id_1, swc_id_2, label)
531-
xyz = utils.get_midpoint(sites[0], sites[1])
532-
if dist > 20 and not self.near_bdd(xyz):
533-
# Write site to swc
534-
if self.write_to_swc:
535-
self.save_swc(sites[0], sites[1], "merge")
536-
537-
# Process merge
538-
self.process_merge(swc_id_1, label)
539-
self.process_merge(swc_id_2, label)
520+
sites, dist = self.localize(swc_id_1, swc_id_2, label)
521+
xyz = utils.get_midpoint(sites[0], sites[1])
522+
if True: #dist > 20 and not self.near_bdd(xyz):
523+
# Write site to swc
524+
if self.write_to_swc:
525+
self.save_swc(sites[0], sites[1], "merge")
526+
527+
# Process merge
528+
self.process_merge(swc_id_1, label)
529+
self.process_merge(swc_id_2, label)
540530

541531
# Remove label to avoid reprocessing
542532
del self.label_to_node[swc_id_1][label]
@@ -660,6 +650,10 @@ def compile_results(self):
660650
# Summarize results
661651
swc_ids, results = self.generate_report()
662652
avg_results = dict([(k, np.mean(v)) for k, v in results.items()])
653+
654+
# Adjust certain stats
655+
n_detected_neurons = np.sum(np.array(results["% omit edges"]) < 1)
656+
avg_results["# splits"] = np.sum(results["# splits"]) / n_detected_neurons
663657
avg_results["# merges"] = avg_results["# merges"] / 2
664658

665659
# Reformat full results

0 commit comments

Comments
 (0)