Skip to content

Commit 85c935e

Browse files
anna-grimanna-grim
andauthored
Feat add cablelength (#170)
* refactor: improved find merge * removed test blocks * feat: added cable length --------- Co-authored-by: anna-grim <anna.grim@alleninstitute.org>
1 parent f3605bf commit 85c935e

File tree

4 files changed

+158
-9
lines changed

4 files changed

+158
-9
lines changed

src/segmentation_skeleton_metrics/data_handling/graph_classes.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ def prune_branches(self):
199199
"""
200200
pass
201201

202+
def run_length_from(self):
203+
"""
204+
Placeholder method to be implemented by subclasses.
205+
"""
206+
pass
207+
202208
# --- Writers ---
203209
def to_zipped_swc(self, zip_writer):
204210
"""
@@ -487,3 +493,33 @@ def prune_branches(self, depth=24):
487493
elif self.degree(j) > 2:
488494
self.remove_nodes_from(branch)
489495
break
496+
497+
def run_length_from(self, root):
498+
"""
499+
Computes the physical path length of the connected component that
500+
contains "root".
501+
502+
Parameters
503+
----------
504+
root : int
505+
Node contained in connected component to compute run length of.
506+
507+
Returns
508+
-------
509+
run_length : float
510+
Physical path length.
511+
"""
512+
run_length = 0
513+
queue = [(root, root)]
514+
visited = set([root])
515+
while queue:
516+
# Visit node
517+
i, j = queue.pop()
518+
run_length += self.physical_dist(i, j)
519+
520+
# Update queue
521+
for k in self.neighbors(j):
522+
if k not in visited:
523+
queue.append((j, k))
524+
visited.add(k)
525+
return run_length

src/segmentation_skeleton_metrics/data_handling/swc_loading.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,10 @@ def read_from_gcs_swcs(self, bucket_name, swc_paths):
323323
# Assign threads
324324
threads = list()
325325
for path in swc_paths:
326+
327+
if "003" not in path: # TEMP
328+
continue # TEMP
329+
326330
threads.append(
327331
executor.submit(self.read_from_gcs_swc, bucket_name, path)
328332
)

src/segmentation_skeleton_metrics/evaluate.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import pandas as pd
1717

1818
from segmentation_skeleton_metrics.skeleton_metrics import (
19+
AddedCableLengthMetric,
1920
MergeCountMetric,
2021
MergeRateMetric,
2122
MergedEdgePercentMetric,
@@ -101,7 +102,7 @@ def evaluate(
101102

102103
# Run evaluation
103104
evaluator = Evaluator(output_dir, results_filename, verbose)
104-
evaluator.run(gt_graphs, fragment_graphs)
105+
evaluator(gt_graphs, fragment_graphs)
105106

106107
# Optional saves
107108
if save_merges:
@@ -179,7 +180,7 @@ def __init__(self, output_dir, results_filename, verbose=True):
179180
}
180181

181182
# --- Core Routines ---
182-
def run(self, gt_graphs, fragment_graphs=None):
183+
def __call__(self, gt_graphs, fragment_graphs=None):
183184
"""
184185
Computes evaluation metrics for neuron reconstructions and saves a CSV
185186
report.
@@ -212,6 +213,17 @@ def run(self, gt_graphs, fragment_graphs=None):
212213
elif name != "Merge Rate":
213214
results[name] = metric(gt_graphs, results)
214215

216+
# Compute special metrics
217+
metric = AddedCableLengthMetric(verbose=self.verbose)
218+
metric(
219+
gt_graphs, fragment_graphs, self.metrics["# Merges"].merge_sites
220+
)
221+
222+
# Save merge sites
223+
filename = f"{self.results_filename}-merge_sites.csv"
224+
path = os.path.join(self.output_dir, filename)
225+
self.metrics["# Merges"].merge_sites.to_csv(path, index=True)
226+
215227
# Save report
216228
path = f"{self.output_dir}/{self.results_filename}.csv"
217229
results.to_csv(path, index=True)
@@ -337,7 +349,8 @@ def save_merge_results(self, gt_graphs, fragment_graphs, output_dir):
337349
Directory that results are written to.
338350
"""
339351
# Initialize a writer
340-
zip_path = os.path.join(output_dir, "merged_fragments.zip")
352+
filename = f"{self.results_filename}-merged_fragments.zip"
353+
zip_path = os.path.join(output_dir, filename)
341354
util.rm_file(zip_path)
342355
zip_writer = ZipFile(zip_path, "a")
343356

@@ -346,10 +359,6 @@ def save_merge_results(self, gt_graphs, fragment_graphs, output_dir):
346359
self.save_skeletons_with_merge(gt_graphs, fragment_graphs, zip_writer)
347360
zip_writer.close()
348361

349-
# Save CSV file
350-
path = os.path.join(output_dir, "merge_sites.csv")
351-
self.metrics["# Merges"].merge_sites.to_csv(path, index=True)
352-
353362
def save_merge_sites(self, zip_writer):
354363
"""
355364
Saves merge site coordinates into a ZIP archive.

src/segmentation_skeleton_metrics/skeleton_metrics.py

Lines changed: 102 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"""
1111

1212
from abc import ABC, abstractmethod
13+
from copy import deepcopy
1314
from collections import deque
1415
from scipy.spatial import KDTree
1516
from tqdm import tqdm
@@ -363,6 +364,7 @@ class MergeCountMetric(SkeletonMetric):
363364
"""
364365
A skeleton metric subclass that counts the number merges.
365366
"""
367+
merge_dist_threshold = 50
366368

367369
def __init__(self, verbose=True):
368370
"""
@@ -452,7 +454,7 @@ def search_for_merges(self, gt_graph, fragment_graph):
452454
dist, _ = gt_graph.kdtree.query(xyz)
453455

454456
# Check if distance to ground truth flags a merge mistake
455-
if dist > 50:
457+
if dist > MergeCountMetric.merge_dist_threshold:
456458
self.find_merge_site(gt_graph, fragment_graph, leaf, visited)
457459

458460
def find_merge_site(self, gt_graph, fragment_graph, source, visited):
@@ -523,10 +525,11 @@ def verify_site(self, gt_graph, fragment_graph, gt_node, fragment_node):
523525
self.fragments_with_merge.add(fragment_graph.name)
524526
self.merge_sites.append(
525527
{
526-
"Segment_ID": fragment_graph.segment_id,
528+
"Segment_ID": fragment_graph.name,
527529
"GroundTruth_ID": gt_graph.name,
528530
"Voxel": tuple(map(int, voxel)),
529531
"World": tuple([float(round(t, 2)) for t in xyz]),
532+
"Added Cable Length (μm)": 0.0
530533
}
531534
)
532535

@@ -877,3 +880,100 @@ def __call__(self, gt_graphs, results):
877880
if self.verbose:
878881
pbar.update(1)
879882
return self.reformat(new_results)
883+
884+
885+
class AddedCableLengthMetric(SkeletonMetric):
886+
"""
887+
A skeleton metric subclass that computes added cable length.
888+
"""
889+
890+
def __init__(self, verbose=True):
891+
"""
892+
Instantiates an AddedCableLengthMetric object.
893+
894+
Parameters
895+
----------
896+
verbose : bool, optional
897+
Indication of whether to display a progress bar. Default is True.
898+
"""
899+
# Call parent class
900+
super().__init__(verbose=verbose)
901+
902+
# Instance attributes
903+
self.name = "Added Cable Length (μm)"
904+
905+
def __call__(self, gt_graphs, fragment_graphs, merge_sites):
906+
"""
907+
Computes the normalized ERL of the given graphs.
908+
909+
Parameters
910+
----------
911+
gt_graphs : Dict[str, LabeledGraph]
912+
Graphs to be evaluated.
913+
fragment_graphs : Dict[str, FragmentGraph]
914+
Graphs corresponding to the predicted segmentation.
915+
merge_sites : pandas.DataFrame
916+
Data frame containing detected merge sites.
917+
918+
Returns
919+
-------
920+
results : pandas.DataFrame
921+
DataFrame where the indices are the dictionary keys and values are
922+
stored under a column called "self.name".
923+
"""
924+
pbar = self.get_pbar(len(merge_sites.index))
925+
pair_to_length = dict()
926+
for i in merge_sites.index:
927+
# Extract site info
928+
segment_id = merge_sites["Segment_ID"][i]
929+
gt_id = merge_sites["GroundTruth_ID"][i]
930+
pair_id = (segment_id, gt_id)
931+
932+
# Check wheter to visit
933+
if pair_id in pair_to_length:
934+
merge_sites.loc[i, self.name] = pair_to_length[pair_id]
935+
else:
936+
# Get graphs
937+
gt_graph = gt_graphs[gt_id]
938+
fragment_graph = deepcopy(fragment_graphs[segment_id])
939+
940+
# Compute metric
941+
pair_to_length[pair_id] = self.compute_added_length(
942+
gt_graph, fragment_graph
943+
)
944+
merge_sites.loc[i, self.name] = pair_to_length[pair_id]
945+
946+
# Update progress bar
947+
if self.verbose:
948+
pbar.update(1)
949+
950+
def compute_added_length(self, gt_graph, fragment_graph):
951+
"""
952+
Computes the total cable length of fragment components that are not
953+
sufficiently close to the ground-truth graph.
954+
955+
Parameters
956+
----------
957+
gt_graph : LabeledGraph
958+
Graph containing merge mistake.
959+
fragment_graph : FragmentGraph
960+
Fragment that is merged to the given ground truth graph.
961+
962+
Returns
963+
-------
964+
cable_length : float
965+
Total cable length of fragment components that remain after pruning
966+
nodes near the ground-truth graph.
967+
"""
968+
# Remove nodes close to ground truth
969+
xyz_arr = fragment_graph.voxels * fragment_graph.anisotropy
970+
dists, _ = gt_graph.kdtree.query(xyz_arr)
971+
max_dist = MergeCountMetric.merge_dist_threshold
972+
fragment_graph.remove_nodes_from(np.where(dists < max_dist)[0])
973+
974+
# Compute cable length
975+
cable_length = 0
976+
for nodes in nx.connected_components(fragment_graph):
977+
node = util.sample_once(nodes)
978+
cable_length += fragment_graph.run_length_from(node)
979+
return round(float(cable_length), 2)

0 commit comments

Comments
 (0)