Skip to content

Commit c667865

Browse files
author
anna-grim
committed
major refactor
1 parent 8a6cf74 commit c667865

File tree

13 files changed

+1101
-1108
lines changed

13 files changed

+1101
-1108
lines changed

demo/demo.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,34 +8,24 @@
88
99
"""
1010

11-
from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
11+
from segmentation_skeleton_metrics.evaluate import evaluate
1212
from segmentation_skeleton_metrics.utils.img_util import TiffReader
1313

1414

15-
def evaluate():
15+
def main():
1616
"""
1717
Evaluates the accuracy of a predicted segmentation by comparing it to a
1818
set of ground truth skeletons, then reports and saves various performance
1919
metrics.
20-
21-
Parameters
22-
----------
23-
None
24-
25-
Returns
26-
-------
27-
None
28-
2920
"""
3021
# Initializations
3122
pred_labels = TiffReader(pred_labels_path, swap_axes=False)
32-
skeleton_metric = SkeletonMetric(
23+
skeleton_metric = evaluate(
3324
groundtruth_pointer,
3425
pred_labels,
26+
output_dir,
3527
fragments_pointer=fragments_pointer,
36-
output_dir=output_dir,
3728
)
38-
skeleton_metric.run()
3929

4030

4131
if __name__ == "__main__":
@@ -46,4 +36,4 @@ def evaluate():
4636
groundtruth_pointer = "./data/target_swcs.zip"
4737

4838
# Run
49-
evaluate()
39+
main()

demo/results-overview.txt

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Average Results...
2-
# Splits: 7.9680
2+
# Splits: 7.2780
33
# Merges: 0.0000
4-
Split Rate: 114.6940
5-
Merge Rate: 0.0000
4+
Split Rate: 124.3079
5+
Merge Rate: nan
66
% Split Edges: 1.1996
77
% Omit Edges: 4.5086
88
% Merged Edges: 0.0000
@@ -11,20 +11,5 @@ Average Results...
1111
Normalized ERL: 0.3949
1212

1313
Total Results...
14-
# Splits: 31
15-
# Merges: 0
16-
Average Results...
17-
# Splits: 7.9680
18-
# Merges: 0.0000
19-
Split Rate: 114.6940
20-
Merge Rate: 0.0000
21-
% Split Edges: 1.1996
22-
% Omit Edges: 4.5086
23-
% Merged Edges: 0.0000
24-
Edge Accuracy: 95.4914
25-
ERL: 180.8849
26-
Normalized ERL: 0.3949
27-
28-
Total Results...
29-
# Splits: 31
14+
# Splits: 27
3015
# Merges: 0

demo/results.csv

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
,# Splits,# Merges,Split Rate,Merge Rate,% Split Edges,% Omit Edges,% Merged Edges,Edge Accuracy,ERL,Normalized ERL,GT Run Length
2-
SNT_Data-002,0,,84.08376161857129,,0.0,0.0,0.0,100.0,84.08,1.0,84.08
3-
SNT_Data-035,0,,163.37074041968347,,0.0,0.0,0.0,100.0,163.37,1.0,163.37
4-
SNT_Data-037,1,,211.9805590314331,,0.98,1.96,0.0,98.04,184.57,0.8585,214.98
5-
SNT_Data-038,1,,251.78664303157595,,0.83,1.65,0.0,98.35,144.38,0.5667,254.79
6-
SNT_Data-041,4,,55.46881994891758,,2.83,7.55,0.0,92.45,52.62,0.2286,230.17
7-
SNT_Data-043,2,,46.354813254114404,,1.92,9.62,0.0,90.38,35.32,0.3334,105.96
8-
SNT_Data-051,0,,106.56843799568803,,0.0,0.0,0.0,100.0,106.57,1.0,106.57
9-
SNT_Data-052,6,,99.17738848187402,,1.37,6.48,0.0,93.52,265.1,0.4184,633.57
10-
SNT_Data-053,15,,98.7153068880868,,1.31,4.46,0.0,95.54,207.79,0.1343,1547.45
11-
SNT_Data-062,2,,105.1897956496423,,0.9,8.11,0.0,91.89,116.9,0.5144,227.24
12-
SNT_Data-074,0,,80.17991450259146,,0.0,0.0,0.0,100.0,80.18,1.0,80.18
1+
# Splits,# Merges,% Split Edges,% Omit Edges,% Merged Edges,ERL,Normalized ERL,Edge Accuracy,Split Rate,Merge Rate
2+
0.0,,0.0,0.0,0.0,84.08,1.0,100.0,,
3+
0.0,,0.0,0.0,0.0,163.37,1.0,100.0,,
4+
1.0,,0.9803921568627451,0.9803921568627451,0.0,40.27,0.1873,98.04,62.82,
5+
1.0,,0.8264462809917356,5.785123966942149,0.0,146.36,0.5744,93.39,224.52,
6+
3.0,,2.830188679245283,0.0,0.0,45.66,0.1984,97.17,51.2,
7+
1.0,,1.9230769230769231,1.9230769230769231,0.0,30.3,0.286,96.15,51.19,
8+
0.0,,0.0,0.0,0.0,106.57,1.0,100.0,,
9+
5.0,,1.3651877133105803,2.3890784982935154,0.0,293.17,0.4627,96.25,103.53,
10+
14.0,,1.3123359580052494,0.7874015748031497,0.0,217.71,0.1407,97.9,92.8,
11+
2.0,,0.9009009009009009,4.504504504504505,0.0,116.9,0.5144,94.59,105.19,
12+
0.0,,0.0,0.0,0.0,80.18,1.0,100.0,,

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ dependencies = [
2525
"scikit-image",
2626
"tensorstore",
2727
"tifffile",
28-
"xlwt",
2928
"zarr",
3029
]
3130

src/segmentation_skeleton_metrics/data_handling/graph_loading.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
import networkx as nx
1818
import numpy as np
1919

20-
from segmentation_skeleton_metrics.data_handling.skeleton_graph import SkeletonGraph
21-
from segmentation_skeleton_metrics.utils import swc_util, util
20+
from segmentation_skeleton_metrics.data_handling import swc_loading
21+
from segmentation_skeleton_metrics.data_handling.skeleton_graph import (
22+
FragmentGraph, LabeledGraph
23+
)
24+
from segmentation_skeleton_metrics.utils import util
2225

2326

2427
class DataLoader:
@@ -93,6 +96,7 @@ def load_fragments(self, swc_pointer, gt_graphs):
9396
graph_loader = GraphLoader(
9497
anisotropy=self.anisotropy,
9598
is_groundtruth=False,
99+
label_handler=self.label_handler,
96100
selected_ids=selected_ids,
97101
use_anisotropy=self.use_anisotropy,
98102
)
@@ -112,10 +116,10 @@ def get_all_node_labels(self, graphs):
112116
labels : Set[int]
113117
Unique node labels across all graphs.
114118
"""
115-
labels = set()
119+
node_labels = set()
116120
for graph in graphs.values():
117-
labels |= self.label_handler.get_node_labels(graph)
118-
return labels
121+
node_labels |= self.label_handler.get_node_labels(graph)
122+
return node_labels
119123

120124

121125
class GraphLoader:
@@ -146,7 +150,7 @@ def __init__(
146150
label_mask : ImageReader, optional
147151
Predicted segmentation mask.
148152
selected_ids : Set[int], optional
149-
Only SWC files with an swc_id contained in this set are read.
153+
Only SWC files with a name contained in this set are read.
150154
Default is None.
151155
use_anisotropy : bool, optional
152156
Indication of whether coordinates in SWC files should be converted
@@ -161,7 +165,7 @@ def __init__(
161165

162166
# Reader
163167
anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0)
164-
self.swc_reader = swc_util.Reader(
168+
self.swc_reader = swc_loading.Reader(
165169
anisotropy, selected_ids=selected_ids
166170
)
167171

@@ -181,11 +185,11 @@ def run(self, swc_pointer):
181185
Dictionary where the keys are unique identifiers (i.e. filenames
182186
of SWC files) and values are the corresponding SkeletonGraph.
183187
"""
184-
graph_dict = self._build_graphs_from_swcs(swc_pointer)
188+
graphs = self._build_graphs_from_swcs(swc_pointer)
185189
if self.label_mask:
186-
for key in graph_dict:
187-
self._label_graph(graph_dict[key])
188-
return graph_dict
190+
for name in graphs:
191+
self._label_graph(graphs[name])
192+
return graphs
189193

190194
# --- Build Graphs ---
191195
def _build_graphs_from_swcs(self, swc_pointer):
@@ -246,25 +250,40 @@ def to_graph(self, swc_dict):
246250
Graph built from an SWC file.
247251
"""
248252
# Initialize graph
249-
graph = SkeletonGraph(
250-
anisotropy=self.anisotropy, is_groundtruth=self.is_groundtruth
251-
)
252-
graph.init_voxels(swc_dict["voxel"])
253-
graph.set_filename(swc_dict["swc_id"] + ".swc")
254-
graph.set_nodes(len(swc_dict["id"]))
253+
graph = self._init_graph(swc_dict)
255254

256-
# Build graph
255+
# Build graph structure
257256
id_lookup = dict()
258257
for i, id_i in enumerate(swc_dict["id"]):
259258
id_lookup[id_i] = i
260259
if swc_dict["pid"][i] != -1:
261260
parent = id_lookup[swc_dict["pid"][i]]
262261
graph.add_edge(i, parent)
263262
graph.run_length += graph.dist(i, parent)
263+
graph.prune_branches()
264+
return {graph.name: graph}
265+
266+
def _init_graph(self, swc_dict):
267+
# Instantiate graph
268+
if self.is_groundtruth:
269+
graph = LabeledGraph(
270+
anisotropy=self.anisotropy, name=swc_dict["swc_name"]
271+
)
272+
else:
273+
segment_id = util.get_segment_id(swc_dict["swc_name"])
274+
label = self.label_handler.get(segment_id)
275+
graph = FragmentGraph(
276+
anisotropy=self.anisotropy,
277+
name=swc_dict["swc_name"],
278+
label=label,
279+
segment_id=segment_id
280+
)
264281

265-
# Set graph-level attributes
266-
graph.graph["n_initial_edges"] = graph.number_of_edges()
267-
return {swc_dict["swc_id"]: graph}
282+
# Set class attributes
283+
graph.init_voxels(swc_dict["voxel"])
284+
graph.set_filename(swc_dict["swc_name"] + ".swc")
285+
graph.set_nodes(len(swc_dict["id"]))
286+
return graph
268287

269288
# --- Label Graphs ---
270289
def _label_graph(self, graph):
@@ -311,11 +330,11 @@ def _label_graph(self, graph):
311330
)
312331

313332
# Store results
314-
graph.init_labels()
333+
graph.init_node_labels()
315334
for thread in as_completed(threads):
316335
node_to_label = thread.result()
317336
for i, label in node_to_label.items():
318-
graph.labels[i] = label
337+
graph.node_labels[i] = label
319338

320339
def get_patch_labels(self, graph, nodes):
321340
"""
@@ -366,6 +385,9 @@ def to_local_voxels(self, graph, i, offset):
366385
offset = np.array(offset)
367386
return tuple(voxel - offset)
368387

388+
def fix_label_misalignments(self, graph):
389+
pass
390+
369391

370392
class LabelHandler:
371393
"""
@@ -527,7 +549,7 @@ def get_node_labels(self, graph):
527549
labels : Set[int]
528550
Labels corresponding to nodes in the graph identified by "key".
529551
"""
530-
labels = graph.get_labels()
552+
labels = graph.get_node_labels()
531553
if self.use_mapping():
532554
labels = set().union(*(self.inverse_mapping[l] for l in labels))
533555
return labels

0 commit comments

Comments
 (0)