Skip to content

Commit 063d2c3

Browse files
authored
major refactor
1 parent 5a03524 commit 063d2c3

File tree

2 files changed

+691
-16
lines changed

2 files changed

+691
-16
lines changed

src/segmentation_skeleton_metrics/evaluate.py

Lines changed: 125 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,28 @@
88
99
"""
1010

11-
from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
11+
from zipfile import ZipFile
12+
13+
import numpy as np
14+
import os
15+
import pandas as pd
16+
17+
from segmentation_skeleton_metrics.skeleton_metrics import (
18+
MergeCountMetric,
19+
MergeRateMetric,
20+
MergedEdgePercentMetric,
21+
OmitEdgePercentMetric,
22+
SplitEdgePercentMetric,
23+
SplitCountMetric,
24+
SplitRateMetric,
25+
EdgeAccuracyMetric,
26+
ERLMetric,
27+
NormalizedERLMetric
28+
)
1229
from segmentation_skeleton_metrics.data_handling.graph_loading import (
1330
DataLoader, LabelHandler
1431
)
32+
from segmentation_skeleton_metrics.utils import util
1533

1634

1735
def evaluate(
@@ -21,6 +39,7 @@ def evaluate(
2139
anisotropy=(1.0, 1.0, 1.0),
2240
connections_path=None,
2341
fragments_pointer=None,
42+
results_filename="results",
2443
save_merges=False,
2544
save_fragments=False,
2645
use_anisotropy=False,
@@ -56,27 +75,117 @@ def evaluate(
5675
anisotropy=anisotropy,
5776
use_anisotropy=use_anisotropy,
5877
verbose=verbose
59-
)
78+
)
6079
gt_graphs = dataloader.load_groundtruth(gt_pointer, label_mask)
6180
fragment_graphs = dataloader.load_fragments(fragments_pointer, gt_graphs)
6281

63-
# Evaluator
64-
skeleton_metric = SkeletonMetric(
65-
gt_graphs,
66-
fragment_graphs,
67-
label_handler,
68-
anisotropy=anisotropy,
69-
output_dir=output_dir,
70-
save_merges=save_merges,
71-
save_fragments=save_fragments,
72-
)
73-
skeleton_metric.run()
82+
# Run evaluation
83+
evaluator = Evaluator(output_dir, results_filename, verbose)
84+
evaluator.run(gt_graphs, fragment_graphs)
7485

75-
# Compute metrics
86+
# Optional saves
87+
if save_merges:
88+
evaluator.save_merge_results()
7689

77-
# Report results
90+
if save_fragments and fragment_graphs:
91+
evaluator.save_fragments(gt_graphs, fragment_graphs)
7892

7993

8094
# --- Evaluator ---
8195
class Evaluator:
82-
pass
96+
97+
def __init__(self, output_dir, results_filename, verbose=True):
98+
# Instance attributes
99+
self.output_dir = output_dir
100+
self.results_filename = results_filename
101+
self.verbose = verbose
102+
103+
# Set core metrics
104+
self.metrics = {
105+
"# Splits": SplitCountMetric(verbose=verbose),
106+
"# Merges": MergeCountMetric(verbose=verbose),
107+
"% Split Edges": SplitEdgePercentMetric(verbose=verbose),
108+
"% Omit Edges": OmitEdgePercentMetric(verbose=verbose),
109+
"% Merged Edges": MergedEdgePercentMetric(verbose=verbose),
110+
"ERL": ERLMetric(verbose=verbose)
111+
}
112+
113+
# Set derived metrics
114+
self.derived_metrics = {
115+
"Normalized ERL": NormalizedERLMetric(verbose=verbose),
116+
"Edge Accuracy": EdgeAccuracyMetric(verbose=verbose),
117+
"Split Rate": SplitRateMetric(verbose=verbose),
118+
"Merge Rate": MergeRateMetric(verbose=verbose),
119+
}
120+
121+
# --- Core Routines ---
122+
def run(self, gt_graphs, fragment_graphs=None):
123+
# Printout step
124+
if self.verbose:
125+
print("\n(3) Compute Metrics")
126+
127+
# Compute core metrics
128+
results = self.init_results(gt_graphs)
129+
for name, metric in self.metrics.items():
130+
if name == "# Merges" and fragment_graphs:
131+
results[name] = metric.compute(gt_graphs, fragment_graphs)
132+
elif name != "# Merges":
133+
results.update(metric.compute(gt_graphs))
134+
135+
# Compute derived metrics
136+
for name, metric in self.derived_metrics.items():
137+
if name == "Merge Rate" and fragment_graphs:
138+
results[name] = metric.compute(gt_graphs, results)
139+
elif name != "Merge Rate":
140+
results[name] = metric.compute(gt_graphs, results)
141+
142+
# Save report
143+
path = f"{self.output_dir}/{self.results_filename}.csv"
144+
results.to_csv(path, index=False)
145+
self.report_summary(results)
146+
147+
def init_results(self, gt_graphs):
148+
cols = list(self.metrics.keys()) + list(self.derived_metrics.keys())
149+
index = list(gt_graphs.keys())
150+
index.sort()
151+
return pd.DataFrame(np.nan, index=index, columns=cols)
152+
153+
def report_summary(self, results):
154+
pass
155+
156+
# --- Writers ---
157+
def save_fragments(self):
158+
pass
159+
160+
def save_merge_results(self, gt_graphs, fragment_graphs, output_dir):
161+
# Initialize a writer
162+
zip_path = os.path.join(output_dir, "merged_fragments.zip")
163+
util.rm_file(zip_path)
164+
zip_writer = ZipFile(zip_path, "a")
165+
166+
# Save SWC files
167+
self.save_merge_sites(zip_writer)
168+
self.save_skeletons_with_merge(gt_graphs, fragment_graphs, zip_writer)
169+
zip_writer.close()
170+
171+
# Save CSV file
172+
path = os.path.join(output_dir, "merge_sites.csv")
173+
self.merge_sites.to_csv(path, index=True)
174+
175+
def save_merge_sites(self, zip_writer):
176+
merge_sites = self.metrics["# Merges"].merge_sites
177+
for i in range(len(merge_sites)):
178+
filename = merge_sites.index[i]
179+
xyz = merge_sites["World"].iloc[i]
180+
util.to_zipped_point(zip_writer, filename, xyz)
181+
182+
def save_skeletons_with_merge(
183+
self, gt_graphs, fragment_graphs, zip_writer
184+
):
185+
# Save ground truth skeletons
186+
for key in self.merge_sites["GroundTruth_ID"].unique():
187+
gt_graphs[key].to_zipped_swc(zip_writer)
188+
189+
# Save fragments
190+
for key in self.fragments_with_merge:
191+
fragment_graphs[key].to_zipped_swc(zip_writer)

0 commit comments

Comments
 (0)