|
10 | 10 | """ |
11 | 11 |
|
12 | 12 | from abc import ABC, abstractmethod |
13 | | -from copy import deepcopy |
| 13 | +from collections import deque |
14 | 14 | from scipy.spatial import KDTree |
15 | 15 | from tqdm import tqdm |
16 | 16 |
|
@@ -410,7 +410,7 @@ def __call__(self, gt_graphs, fragment_graphs): |
410 | 410 | labels = gt_graph.get_node_labels() |
411 | 411 | for fragment_graph in fragment_graphs.values(): |
412 | 412 | if fragment_graph.label in labels: |
413 | | - self.search_for_merges(gt_graph, deepcopy(fragment_graph)) |
| 413 | + self.search_for_merges(gt_graph, fragment_graph) |
414 | 414 |
|
415 | 415 | # Update progress bar |
416 | 416 | if self.verbose: |
@@ -441,12 +441,6 @@ def search_for_merges(self, gt_graph, fragment_graph): |
441 | 441 | fragment_graph : FragmentGraph |
442 | 442 | Graph corresponding to a segment in the predicted segmentation. |
443 | 443 | """ |
444 | | - # Remove nodes that are too far |
445 | | - xyz_arr = fragment_graph.voxels * fragment_graph.anisotropy |
446 | | - dists, _ = gt_graph.kdtree.query(xyz_arr) |
447 | | - fragment_graph.remove_nodes_from(np.where(dists > 200)[0]) |
448 | | - |
449 | | - # Search remaining graph |
450 | 444 | visited = set() |
451 | 445 | for leaf in util.get_leafs(fragment_graph): |
452 | 446 | # Check whether to visit |
@@ -479,19 +473,23 @@ def find_merge_site(self, gt_graph, fragment_graph, source, visited): |
479 | 473 | Node IDs from "fragment_graphs" that have already been visited, |
480 | 474 | used to avoid redundant exploration. |
481 | 475 | """ |
482 | | - # Traverse until close to ground truth |
483 | | - for _, node in nx.dfs_edges(fragment_graph, source=source): |
484 | | - # Check whether to visit |
485 | | - if node in visited or visited.add(node): |
486 | | - continue |
487 | | - |
488 | | - # Check if close to ground truth |
489 | | - xyz = fragment_graph.get_xyz(node) |
490 | | - dist, gt_node = gt_graph.kdtree.query(xyz) |
491 | | - if dist < 6: |
492 | | - self.verify_site(gt_graph, fragment_graph, gt_node, node) |
| 476 | + queue = deque([source]) |
| 477 | + visited.add(source) |
| 478 | + while len(queue) > 0: |
| 479 | + # Visit node |
| 480 | + i = queue.pop() |
| 481 | + xyz_i = fragment_graph.get_xyz(i) |
| 482 | + dist_i, gt_node = gt_graph.kdtree.query(xyz_i) |
| 483 | + if dist_i < 6: |
| 484 | + self.verify_site(gt_graph, fragment_graph, gt_node, i) |
493 | 485 | break |
494 | 486 |
|
| 487 | + # Update queue |
| 488 | + for j in fragment_graph.neighbors(i): |
| 489 | + if j not in visited: |
| 490 | + queue.append(j) |
| 491 | + visited.add(j) |
| 492 | + |
495 | 493 | def verify_site(self, gt_graph, fragment_graph, gt_node, fragment_node): |
496 | 494 | """ |
497 | 495 | Verifies whether a given site in a fragment graph corresponds to a |
|
0 commit comments