Skip to content

Commit e2161cf

Browse files
anna-grimanna-grim
andauthored
feat: generates report of evaluation (#36)
Co-authored-by: anna-grim <[email protected]>
1 parent a4bfb76 commit e2161cf

File tree

3 files changed

+209
-11
lines changed

3 files changed

+209
-11
lines changed

src/segmentation_skeleton_metrics/graph_utils.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,10 @@
66
77
88
"""
9-
from scipy.spatial.distance import euclidean as dist
109
from random import sample
1110

12-
import math
1311
import networkx as nx
12+
from scipy.spatial.distance import euclidean as dist
1413

1514

1615
# -- edit graph --
@@ -164,10 +163,27 @@ def count_splits(graph):
164163

165164

166165
def compute_run_lengths(graph):
166+
"""
167+
Computes the path length of each connected component in "graph".
168+
169+
Parameters
170+
----------
171+
graph : networkx.Graph
172+
Graph to be parsed.
173+
174+
Returns
175+
-------
176+
run_lengths : list
177+
List of run lengths of each connected component in "graph".
178+
179+
"""
167180
run_lengths = []
168-
for nodes in nx.connected_components(graph):
169-
subgraph = graph.subgraph(nodes)
170-
run_lengths.append(compute_path_length(subgraph))
181+
if graph.number_of_nodes():
182+
for nodes in nx.connected_components(graph):
183+
subgraph = graph.subgraph(nodes)
184+
run_lengths.append(compute_path_length(subgraph))
185+
else:
186+
run_lengths.append(0)
171187
return run_lengths
172188

173189

@@ -193,6 +209,7 @@ def compute_path_length(graph):
193209
path_length += dist(xyz_1, xyz_2)
194210
return path_length
195211

212+
196213
# -- miscellaneous --
197214
def sample_leaf(graph):
198215
"""

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 165 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,20 @@ def init_target_graphs(self, paths, anisotropy):
9999
self.target_graphs[swc_id] = to_graph(path, anisotropy=anisotropy)
100100

101101
def init_pred_graphs(self):
102+
"""
103+
Initializes "self.pred_graphs" by copying each graph in
104+
"self.target_graphs", then labels each node with the label in
105+
"self.labels" that coincides with it.
106+
107+
Parameters
108+
----------
109+
None
110+
111+
Returns
112+
-------
113+
None
114+
115+
"""
102116
print("Labelling Target Graphs...")
103117
t0 = time()
104118
self.pred_graphs = dict()
@@ -227,7 +241,8 @@ def compute_metrics(self):
227241
self.quantify_merges()
228242

229243
# Compute metrics
230-
self.compile_results()
244+
full_results, avg_results = self.compile_results()
245+
return full_results, avg_results
231246

232247
def detect_splits(self):
233248
"""
@@ -397,9 +412,39 @@ def detect_merges(self):
397412
print(f"\nRuntime: {round(t, 2)} {unit}\n")
398413

399414
def init_merge_counter(self):
415+
"""
416+
Initializes a dictionary that is used to count the number of merge
417+
type mistakes for each pred_graph.
418+
419+
Parameters
420+
----------
421+
None
422+
423+
Returns
424+
-------
425+
dict
426+
Dictionary used to count number of merge type mistakes.
427+
428+
"""
400429
return dict([(swc_id, 0) for swc_id in self.pred_graphs.keys()])
401430

402431
def process_merge(self, swc_id, label):
432+
"""
433+
Once a merge has been detected that corresponds to "label", every node
434+
in "self.pred_graph[swc_id]" with that label is deleted.
435+
436+
Parameters
437+
----------
438+
swc_id : str
439+
Key associated with the pred_graph to be searched.
440+
label : int
441+
Label assocatied with a merge.
442+
443+
Returns
444+
-------
445+
None
446+
447+
"""
403448
# Update graph
404449
graph = self.pred_graphs[swc_id].copy()
405450
graph, merged_cnt = gutils.delete_nodes(graph, label, return_cnt=True)
@@ -411,29 +456,124 @@ def process_merge(self, swc_id, label):
411456
self.merged_cnts[swc_id] += merged_cnt
412457

413458
def quantify_merges(self):
459+
"""
460+
Computes the percentage of merged edges for each pred_graph.
461+
462+
Parameters
463+
----------
464+
None
465+
466+
Returns
467+
-------
468+
None
469+
470+
"""
414471
self.merged_percents = dict()
415472
for swc_id in self.target_graphs.keys():
416473
n_edges = self.target_graphs[swc_id].number_of_edges()
417474
self.merged_percents[swc_id] = self.merged_cnts[swc_id] / n_edges
418475

419476
def compile_results(self):
477+
"""
478+
Compiles a dictionary containing the metrics computed by this module.
479+
480+
Parameters
481+
----------
482+
None
483+
484+
Returns
485+
-------
486+
full_results : dict
487+
Dictionary where the keys are swc_ids and the values are the result
488+
of computing each metric for the corresponding graphs.
489+
avg_result : dict
490+
Dictionary where the keys are names of metrics computed by this
491+
module and values are the averaged result over all swc_ids.
492+
493+
"""
420494
# Compute remaining metrics
421495
self.compute_edge_accuracy()
422496
self.compute_erl()
423497

498+
# Summarize results
499+
swc_ids, full_results = self.generate_report()
500+
avg_results = dict([(k, np.mean(v)) for k, v in full_results.items()])
501+
full_results = dict(zip(swc_ids, full_results))
502+
return full_results, avg_results
503+
504+
def generate_report(self):
505+
"""
506+
Generates a report by creating a list of the results for each metric.
507+
Each item in this list corresponds to a graph in "self.pred_graphs"
508+
and this list is ordered with respect to "swc_ids".
509+
510+
Parameters
511+
----------
512+
None
513+
514+
Results
515+
-------
516+
swc_ids : list[str]
517+
Specifies the ordering of results for each value in "stats".
518+
stats : dict
519+
Dictionary where the keys are metrics and values are the result of
520+
computing that metric for each graph in "self.pred_graphs".
521+
522+
"""
523+
swc_ids = list(self.pred_graphs.keys())
524+
swc_ids.sort()
525+
stats = {
526+
"# splits": generate_result(swc_ids, self.split_cnts),
527+
"# merges": generate_result(swc_ids, self.merge_cnts),
528+
"% omit edges": generate_result(swc_ids, self.omit_percents),
529+
"% merged edges": generate_result(swc_ids, self.merged_percents),
530+
"edge accuracy": generate_result(swc_ids, self.edge_accuracy),
531+
"erl": generate_result(swc_ids, self.erl),
532+
"normalized erl": generate_result(swc_ids, self.normalized_erl),
533+
}
534+
return swc_ids, stats
535+
424536
def compute_edge_accuracy(self):
537+
"""
538+
Computes the edge accuracy of each pred_graph.
539+
540+
Parameters
541+
----------
542+
None
543+
544+
Returns
545+
-------
546+
None
547+
548+
"""
425549
self.edge_accuracy = dict()
426550
for swc_id in self.target_graphs.keys():
427551
omit_percent = self.omit_percents[swc_id]
428552
merged_percent = self.merged_percents[swc_id]
429553
self.edge_accuracy[swc_id] = 1 - omit_percent - merged_percent
430554

431555
def compute_erl(self):
556+
"""
557+
Computes the expected run length (ERL) of each pred_graph.
558+
559+
Parameters
560+
----------
561+
None
562+
563+
Returns
564+
-------
565+
None
566+
567+
"""
432568
self.erl = dict()
569+
self.normalized_erl = dict()
433570
for swc_id in self.target_graphs.keys():
434-
graph = self.pred_graphs[swc_id]
435-
path_lengths = gutils.compute_run_lengths(graph)
571+
pred_graph = self.pred_graphs[swc_id]
572+
target_graph = self.target_graphs[swc_id]
573+
path_lengths = gutils.compute_run_lengths(pred_graph)
574+
path_length = gutils.compute_path_length(target_graph)
436575
self.erl[swc_id] = np.mean(path_lengths)
576+
self.normalized_erl[swc_id] = np.mean(path_lengths) / path_length
437577

438578

439579
# -- utils --
@@ -480,3 +620,25 @@ def remove_edge(dfs_edges, edge):
480620
elif (edge[1], edge[0]) in dfs_edges:
481621
dfs_edges.remove((edge[1], edge[0]))
482622
return dfs_edges
623+
624+
625+
def generate_result(swc_ids, stats):
626+
"""
627+
Reorders items in "stats" with respect to the order defined by "swc_ids".
628+
629+
Parameters
630+
----------
631+
swc_ids : list[str]
632+
List of all swc_ids of graphs in "self.pred_graphs".
633+
stats : dict
634+
Dictionary where the keys are swc_ids and values are the result of
635+
computing some metrics.
636+
637+
Returns
638+
-------
639+
list
640+
Reorded items in "stats" with respect to the order defined by
641+
"swc_ids".
642+
643+
"""
644+
return [stats[swc_id] for swc_id in swc_ids]

src/segmentation_skeleton_metrics/utils.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def open_tensorstore(path, driver):
8383
Sparse image volume.
8484
8585
"""
86-
assert driver in SUPPORTED_DRIVERS, "Error! Driver is not supported!"
86+
assert driver in SUPPORTED_DRIVERS, "Driver is not supported!"
8787
arr = ts.open(
8888
{
8989
"driver": driver,
@@ -112,7 +112,7 @@ def read_tensorstore(path):
112112
Parameters
113113
----------
114114
path : str
115-
Path to directory containing shard files.
115+
Path to directory containing shardsS.
116116
117117
Returns
118118
-------
@@ -205,6 +205,25 @@ def time_writer(t, unit="seconds"):
205205

206206

207207
def progress_bar(current, total, bar_length=50):
208+
"""
209+
Reports the progress of completing some process.
210+
211+
Parameters
212+
----------
213+
current : int
214+
Current iteration of process.
215+
total : int
216+
Total number of iterations to be completed
217+
bar_length : int, optional
218+
Length of progress bar
219+
220+
Returns
221+
-------
222+
None
223+
224+
"""
208225
progress = int(current / total * bar_length)
209-
bar = f"[{'=' * progress}{' ' * (bar_length - progress)}] {current}/{total}"
226+
bar = (
227+
f"[{'=' * progress}{' ' * (bar_length - progress)}] {current}/{total}"
228+
)
210229
print(f"\r{bar}", end="", flush=True)

0 commit comments

Comments
 (0)