|
10 | 10 | """ |
11 | 11 |
|
12 | 12 | from abc import ABC, abstractmethod |
| 13 | +from copy import deepcopy |
13 | 14 | from collections import deque |
14 | 15 | from scipy.spatial import KDTree |
15 | 16 | from tqdm import tqdm |
@@ -363,6 +364,7 @@ class MergeCountMetric(SkeletonMetric): |
363 | 364 | """ |
364 | 365 | A skeleton metric subclass that counts the number merges. |
365 | 366 | """ |
| 367 | + merge_dist_threshold = 50 |
366 | 368 |
|
367 | 369 | def __init__(self, verbose=True): |
368 | 370 | """ |
@@ -452,7 +454,7 @@ def search_for_merges(self, gt_graph, fragment_graph): |
452 | 454 | dist, _ = gt_graph.kdtree.query(xyz) |
453 | 455 |
|
454 | 456 | # Check if distance to ground truth flags a merge mistake |
455 | | - if dist > 50: |
| 457 | + if dist > MergeCountMetric.merge_dist_threshold: |
456 | 458 | self.find_merge_site(gt_graph, fragment_graph, leaf, visited) |
457 | 459 |
|
458 | 460 | def find_merge_site(self, gt_graph, fragment_graph, source, visited): |
@@ -523,10 +525,11 @@ def verify_site(self, gt_graph, fragment_graph, gt_node, fragment_node): |
523 | 525 | self.fragments_with_merge.add(fragment_graph.name) |
524 | 526 | self.merge_sites.append( |
525 | 527 | { |
526 | | - "Segment_ID": fragment_graph.segment_id, |
| 528 | + "Segment_ID": fragment_graph.name, |
527 | 529 | "GroundTruth_ID": gt_graph.name, |
528 | 530 | "Voxel": tuple(map(int, voxel)), |
529 | 531 | "World": tuple([float(round(t, 2)) for t in xyz]), |
| 532 | + "Added Cable Length (μm)": 0.0 |
530 | 533 | } |
531 | 534 | ) |
532 | 535 |
|
@@ -877,3 +880,100 @@ def __call__(self, gt_graphs, results): |
877 | 880 | if self.verbose: |
878 | 881 | pbar.update(1) |
879 | 882 | return self.reformat(new_results) |
| 883 | + |
| 884 | + |
| 885 | +class AddedCableLengthMetric(SkeletonMetric): |
| 886 | + """ |
| 887 | + A skeleton metric subclass that computes added cable length. |
| 888 | + """ |
| 889 | + |
| 890 | + def __init__(self, verbose=True): |
| 891 | + """ |
| 892 | + Instantiates an AddedCableLengthMetric object. |
| 893 | +
|
| 894 | + Parameters |
| 895 | + ---------- |
| 896 | + verbose : bool, optional |
| 897 | + Indication of whether to display a progress bar. Default is True. |
| 898 | + """ |
| 899 | + # Call parent class |
| 900 | + super().__init__(verbose=verbose) |
| 901 | + |
| 902 | + # Instance attributes |
| 903 | + self.name = "Added Cable Length (μm)" |
| 904 | + |
| 905 | + def __call__(self, gt_graphs, fragment_graphs, merge_sites): |
| 906 | + """ |
| 907 | + Computes the normalized ERL of the given graphs. |
| 908 | +
|
| 909 | + Parameters |
| 910 | + ---------- |
| 911 | + gt_graphs : Dict[str, LabeledGraph] |
| 912 | + Graphs to be evaluated. |
| 913 | + fragment_graphs : Dict[str, FragmentGraph] |
| 914 | + Graphs corresponding to the predicted segmentation. |
| 915 | + merge_sites : pandas.DataFrame |
| 916 | + Data frame containing detected merge sites. |
| 917 | +
|
| 918 | + Returns |
| 919 | + ------- |
| 920 | + results : pandas.DataFrame |
| 921 | + DataFrame where the indices are the dictionary keys and values are |
| 922 | + stored under a column called "self.name". |
| 923 | + """ |
| 924 | + pbar = self.get_pbar(len(merge_sites.index)) |
| 925 | + pair_to_length = dict() |
| 926 | + for i in merge_sites.index: |
| 927 | + # Extract site info |
| 928 | + segment_id = merge_sites["Segment_ID"][i] |
| 929 | + gt_id = merge_sites["GroundTruth_ID"][i] |
| 930 | + pair_id = (segment_id, gt_id) |
| 931 | + |
| 932 | + # Check wheter to visit |
| 933 | + if pair_id in pair_to_length: |
| 934 | + merge_sites.loc[i, self.name] = pair_to_length[pair_id] |
| 935 | + else: |
| 936 | + # Get graphs |
| 937 | + gt_graph = gt_graphs[gt_id] |
| 938 | + fragment_graph = deepcopy(fragment_graphs[segment_id]) |
| 939 | + |
| 940 | + # Compute metric |
| 941 | + pair_to_length[pair_id] = self.compute_added_length( |
| 942 | + gt_graph, fragment_graph |
| 943 | + ) |
| 944 | + merge_sites.loc[i, self.name] = pair_to_length[pair_id] |
| 945 | + |
| 946 | + # Update progress bar |
| 947 | + if self.verbose: |
| 948 | + pbar.update(1) |
| 949 | + |
| 950 | + def compute_added_length(self, gt_graph, fragment_graph): |
| 951 | + """ |
| 952 | + Computes the total cable length of fragment components that are not |
| 953 | + sufficiently close to the ground-truth graph. |
| 954 | +
|
| 955 | + Parameters |
| 956 | + ---------- |
| 957 | + gt_graph : LabeledGraph |
| 958 | + Graph containing merge mistake. |
| 959 | + fragment_graph : FragmentGraph |
| 960 | + Fragment that is merged to the given ground truth graph. |
| 961 | +
|
| 962 | + Returns |
| 963 | + ------- |
| 964 | + cable_length : float |
| 965 | + Total cable length of fragment components that remain after pruning |
| 966 | + nodes near the ground-truth graph. |
| 967 | + """ |
| 968 | + # Remove nodes close to ground truth |
| 969 | + xyz_arr = fragment_graph.voxels * fragment_graph.anisotropy |
| 970 | + dists, _ = gt_graph.kdtree.query(xyz_arr) |
| 971 | + max_dist = MergeCountMetric.merge_dist_threshold |
| 972 | + fragment_graph.remove_nodes_from(np.where(dists < max_dist)[0]) |
| 973 | + |
| 974 | + # Compute cable length |
| 975 | + cable_length = 0 |
| 976 | + for nodes in nx.connected_components(fragment_graph): |
| 977 | + node = util.sample_once(nodes) |
| 978 | + cable_length += fragment_graph.run_length_from(node) |
| 979 | + return round(float(cable_length), 2) |
0 commit comments