Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/segmentation_skeleton_metrics/data_handling/graph_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,12 @@ def prune_branches(self):
"""
pass

def run_length_from(self, root):
    """
    Placeholder method to be implemented by subclasses.

    Computes the physical path length of the connected component that
    contains "root".

    Parameters
    ----------
    root : int
        Node contained in connected component to compute run length of.

    Returns
    -------
    None
        Subclasses return the physical path length as a float.
    """
    # NOTE: signature now matches the subclass override, which requires
    # a "root" node argument.
    pass

# --- Writers ---
def to_zipped_swc(self, zip_writer):
"""
Expand Down Expand Up @@ -487,3 +493,33 @@ def prune_branches(self, depth=24):
elif self.degree(j) > 2:
self.remove_nodes_from(branch)
break

def run_length_from(self, root):
    """
    Computes the physical path length of the connected component that
    contains "root".

    Parameters
    ----------
    root : int
        Node contained in connected component to compute run length of.

    Returns
    -------
    run_length : float
        Physical path length.
    """
    # Depth-first traversal over edges; each popped (parent, child) pair
    # contributes the physical length of that edge. The initial self-pair
    # (root, root) contributes zero length.
    total = 0
    seen = {root}
    stack = [(root, root)]
    while stack:
        parent, child = stack.pop()
        total += self.physical_dist(parent, child)

        # Enqueue edges to unvisited neighbors of the current node
        unseen = [nb for nb in self.neighbors(child) if nb not in seen]
        seen.update(unseen)
        stack.extend((child, nb) for nb in unseen)
    return total
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ def read_from_gcs_swcs(self, bucket_name, swc_paths):
# Assign threads
threads = list()
for path in swc_paths:

if "003" not in path: # TEMP
continue # TEMP

threads.append(
executor.submit(self.read_from_gcs_swc, bucket_name, path)
)
Expand Down
23 changes: 16 additions & 7 deletions src/segmentation_skeleton_metrics/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pandas as pd

from segmentation_skeleton_metrics.skeleton_metrics import (
AddedCableLengthMetric,
MergeCountMetric,
MergeRateMetric,
MergedEdgePercentMetric,
Expand Down Expand Up @@ -101,7 +102,7 @@ def evaluate(

# Run evaluation
evaluator = Evaluator(output_dir, results_filename, verbose)
evaluator.run(gt_graphs, fragment_graphs)
evaluator(gt_graphs, fragment_graphs)

# Optional saves
if save_merges:
Expand Down Expand Up @@ -179,7 +180,7 @@ def __init__(self, output_dir, results_filename, verbose=True):
}

# --- Core Routines ---
def run(self, gt_graphs, fragment_graphs=None):
def __call__(self, gt_graphs, fragment_graphs=None):
"""
Computes evaluation metrics for neuron reconstructions and saves a CSV
report.
Expand Down Expand Up @@ -212,6 +213,17 @@ def run(self, gt_graphs, fragment_graphs=None):
elif name != "Merge Rate":
results[name] = metric(gt_graphs, results)

# Compute special metrics
metric = AddedCableLengthMetric(verbose=self.verbose)
metric(
gt_graphs, fragment_graphs, self.metrics["# Merges"].merge_sites
)

# Save merge sites
filename = f"{self.results_filename}-merge_sites.csv"
path = os.path.join(self.output_dir, filename)
self.metrics["# Merges"].merge_sites.to_csv(path, index=True)

# Save report
path = f"{self.output_dir}/{self.results_filename}.csv"
results.to_csv(path, index=True)
Expand Down Expand Up @@ -337,7 +349,8 @@ def save_merge_results(self, gt_graphs, fragment_graphs, output_dir):
Directory that results are written to.
"""
# Initialize a writer
zip_path = os.path.join(output_dir, "merged_fragments.zip")
filename = f"{self.results_filename}-merged_fragments.zip"
zip_path = os.path.join(output_dir, filename)
util.rm_file(zip_path)
zip_writer = ZipFile(zip_path, "a")

Expand All @@ -346,10 +359,6 @@ def save_merge_results(self, gt_graphs, fragment_graphs, output_dir):
self.save_skeletons_with_merge(gt_graphs, fragment_graphs, zip_writer)
zip_writer.close()

# Save CSV file
path = os.path.join(output_dir, "merge_sites.csv")
self.metrics["# Merges"].merge_sites.to_csv(path, index=True)

def save_merge_sites(self, zip_writer):
"""
Saves merge site coordinates into a ZIP archive.
Expand Down
104 changes: 102 additions & 2 deletions src/segmentation_skeleton_metrics/skeleton_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

from abc import ABC, abstractmethod
from copy import deepcopy
from collections import deque
from scipy.spatial import KDTree
from tqdm import tqdm
Expand Down Expand Up @@ -363,6 +364,7 @@ class MergeCountMetric(SkeletonMetric):
"""
A skeleton metric subclass that counts the number merges.
"""
merge_dist_threshold = 50

def __init__(self, verbose=True):
"""
Expand Down Expand Up @@ -452,7 +454,7 @@ def search_for_merges(self, gt_graph, fragment_graph):
dist, _ = gt_graph.kdtree.query(xyz)

# Check if distance to ground truth flags a merge mistake
if dist > 50:
if dist > MergeCountMetric.merge_dist_threshold:
self.find_merge_site(gt_graph, fragment_graph, leaf, visited)

def find_merge_site(self, gt_graph, fragment_graph, source, visited):
Expand Down Expand Up @@ -523,10 +525,11 @@ def verify_site(self, gt_graph, fragment_graph, gt_node, fragment_node):
self.fragments_with_merge.add(fragment_graph.name)
self.merge_sites.append(
{
"Segment_ID": fragment_graph.segment_id,
"Segment_ID": fragment_graph.name,
"GroundTruth_ID": gt_graph.name,
"Voxel": tuple(map(int, voxel)),
"World": tuple([float(round(t, 2)) for t in xyz]),
"Added Cable Length (μm)": 0.0
}
)

Expand Down Expand Up @@ -877,3 +880,100 @@ def __call__(self, gt_graphs, results):
if self.verbose:
pbar.update(1)
return self.reformat(new_results)


class AddedCableLengthMetric(SkeletonMetric):
    """
    A skeleton metric subclass that computes added cable length.

    For each detected merge site, this is the total cable length of the
    merged fragment that lies beyond the merge-distance threshold from the
    ground-truth skeleton — i.e. cable erroneously "added" by the merge.
    """

    def __init__(self, verbose=True):
        """
        Instantiates an AddedCableLengthMetric object.

        Parameters
        ----------
        verbose : bool, optional
            Indication of whether to display a progress bar. Default is True.
        """
        # Call parent class
        super().__init__(verbose=verbose)

        # Instance attributes
        self.name = "Added Cable Length (μm)"

    def __call__(self, gt_graphs, fragment_graphs, merge_sites):
        """
        Computes the added cable length for each detected merge site and
        writes the result into the "merge_sites" data frame in place.

        Parameters
        ----------
        gt_graphs : Dict[str, LabeledGraph]
            Graphs to be evaluated.
        fragment_graphs : Dict[str, FragmentGraph]
            Graphs corresponding to the predicted segmentation.
        merge_sites : pandas.DataFrame
            Data frame containing detected merge sites. Modified in place:
            the column given by "self.name" is populated for every row.

        Returns
        -------
        None
            Results are stored in the "merge_sites" data frame.
        """
        pbar = self.get_pbar(len(merge_sites.index))
        # Cache results per (fragment, ground truth) pair, since every merge
        # site between the same pair yields the same added cable length.
        pair_to_length = dict()
        for i in merge_sites.index:
            # Extract site info
            segment_id = merge_sites["Segment_ID"][i]
            gt_id = merge_sites["GroundTruth_ID"][i]
            pair_id = (segment_id, gt_id)

            # Check whether to visit
            if pair_id in pair_to_length:
                merge_sites.loc[i, self.name] = pair_to_length[pair_id]
            else:
                # Get graphs (deep copy: compute_added_length prunes nodes)
                gt_graph = gt_graphs[gt_id]
                fragment_graph = deepcopy(fragment_graphs[segment_id])

                # Compute metric
                pair_to_length[pair_id] = self.compute_added_length(
                    gt_graph, fragment_graph
                )
                merge_sites.loc[i, self.name] = pair_to_length[pair_id]

            # Update progress bar
            if self.verbose:
                pbar.update(1)

    def compute_added_length(self, gt_graph, fragment_graph):
        """
        Computes the total cable length of fragment components that are not
        sufficiently close to the ground-truth graph.

        Parameters
        ----------
        gt_graph : LabeledGraph
            Graph containing merge mistake.
        fragment_graph : FragmentGraph
            Fragment that is merged to the given ground truth graph. Mutated:
            nodes near the ground truth are removed, so callers should pass a
            copy if the original must be preserved.

        Returns
        -------
        cable_length : float
            Total cable length of fragment components that remain after pruning
            nodes near the ground-truth graph, rounded to 2 decimal places.
        """
        # Remove nodes close to ground truth (within the same distance
        # threshold used to flag merges in MergeCountMetric)
        xyz_arr = fragment_graph.voxels * fragment_graph.anisotropy
        dists, _ = gt_graph.kdtree.query(xyz_arr)
        max_dist = MergeCountMetric.merge_dist_threshold
        fragment_graph.remove_nodes_from(np.where(dists < max_dist)[0])

        # Compute cable length: sum run lengths over surviving components,
        # seeding each traversal from an arbitrary node of the component
        cable_length = 0
        for nodes in nx.connected_components(fragment_graph):
            node = util.sample_once(nodes)
            cable_length += fragment_graph.run_length_from(node)
        return round(float(cable_length), 2)
Loading