Skip to content

Commit 82bf082

Browse files
anna-grimanna-grim
andauthored
Feat add cablelength (#169)
* refactor: improved find merge * removed test blocks --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent d6936cf commit 82bf082

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

src/segmentation_skeleton_metrics/skeleton_metrics.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"""
1111

1212
from abc import ABC, abstractmethod
13-
from copy import deepcopy
13+
from collections import deque
1414
from scipy.spatial import KDTree
1515
from tqdm import tqdm
1616

@@ -410,7 +410,7 @@ def __call__(self, gt_graphs, fragment_graphs):
410410
labels = gt_graph.get_node_labels()
411411
for fragment_graph in fragment_graphs.values():
412412
if fragment_graph.label in labels:
413-
self.search_for_merges(gt_graph, deepcopy(fragment_graph))
413+
self.search_for_merges(gt_graph, fragment_graph)
414414

415415
# Update progress bar
416416
if self.verbose:
@@ -441,12 +441,6 @@ def search_for_merges(self, gt_graph, fragment_graph):
441441
fragment_graph : FragmentGraph
442442
Graph corresponding to a segment in the predicted segmentation.
443443
"""
444-
# Remove nodes that are too far
445-
xyz_arr = fragment_graph.voxels * fragment_graph.anisotropy
446-
dists, _ = gt_graph.kdtree.query(xyz_arr)
447-
fragment_graph.remove_nodes_from(np.where(dists > 200)[0])
448-
449-
# Search remaining graph
450444
visited = set()
451445
for leaf in util.get_leafs(fragment_graph):
452446
# Check whether to visit
@@ -479,19 +473,23 @@ def find_merge_site(self, gt_graph, fragment_graph, source, visited):
479473
Node IDs from "fragment_graphs" that have already been visited,
480474
used to avoid redundant exploration.
481475
"""
482-
# Traverse until close to ground truth
483-
for _, node in nx.dfs_edges(fragment_graph, source=source):
484-
# Check whether to visit
485-
if node in visited or visited.add(node):
486-
continue
487-
488-
# Check if close to ground truth
489-
xyz = fragment_graph.get_xyz(node)
490-
dist, gt_node = gt_graph.kdtree.query(xyz)
491-
if dist < 6:
492-
self.verify_site(gt_graph, fragment_graph, gt_node, node)
476+
queue = deque([source])
477+
visited.add(source)
478+
while len(queue) > 0:
479+
# Visit node
480+
i = queue.pop()
481+
xyz_i = fragment_graph.get_xyz(i)
482+
dist_i, gt_node = gt_graph.kdtree.query(xyz_i)
483+
if dist_i < 6:
484+
self.verify_site(gt_graph, fragment_graph, gt_node, i)
493485
break
494486

487+
# Update queue
488+
for j in fragment_graph.neighbors(i):
489+
if j not in visited:
490+
queue.append(j)
491+
visited.add(j)
492+
495493
def verify_site(self, gt_graph, fragment_graph, gt_node, fragment_node):
496494
"""
497495
Verifies whether a given site in a fragment graph corresponds to a

0 commit comments

Comments
 (0)