
Commit 1699b63

anna-grim authored
refactor: graph labeling moved to graph util (#135)
Co-authored-by: anna-grim <[email protected]>
1 parent 0505d46 commit 1699b63

File tree: 6 files changed (+268, -113 lines)


demo/demo.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def evaluate():
         fragments_pointer=fragments_pointer,
         output_dir=output_dir,
     )
-    skeleton_metric.run(output_dir)
+    skeleton_metric.run()


 if __name__ == "__main__":
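
For context, the caller-side change above means the output directory is supplied once, at construction time, and run() now takes no arguments. A minimal sketch of the updated call pattern (parameter names are taken from this commit's skeleton_metric.py hunks below; the placeholder inputs and the exact constructor signature are assumptions, not part of this diff):

    from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric

    # Hypothetical inputs -- the real demo builds these from its own sample data.
    gt_pointer = "path/to/ground_truth_swcs/"     # ground truth SWC files
    fragments_pointer = "path/to/fragment_swcs/"  # segmentation fragments
    label_mask = None                             # stand-in for an ImageReader over the predicted segmentation
    output_dir = "path/to/output/"

    skeleton_metric = SkeletonMetric(
        gt_pointer,
        label_mask,
        fragments_pointer=fragments_pointer,
        output_dir=output_dir,
    )
    skeleton_metric.run()  # results are written under output_dir; no argument needed anymore
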

demo/results-overview.txt

Lines changed: 149 additions & 0 deletions
@@ -11,3 +11,152 @@ Average Results...
 Total Results...
 # splits: 36
 # merges: 0
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+# Merges: 0
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+# Merges: 0
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+# Merges: 0
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+# Merges: 0
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+# Merges: 0
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+Average Results...
+# Splits: 7.9680
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+Average Results...
+# Splits: 7.9680
+# Merges: 0.0000
+% Split: 1.1996
+% Omit: 4.5086
+% Merged: 0.0000
+Edge Accuracy: 95.4914
+ERL: 180.8849
+Normalized ERL: 0.3949
+
+Total Results...
+# Splits: 31
+# Merges: 0

demo/results.csv

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+,# Splits,# Merges,% Split,% Omit,% Merged,Edge Accuracy,ERL,Normalized ERL,GT Run Length
+SNT_Data-002,0,,0.0,0.0,0.0,100.0,84.08,1.0,84.08
+SNT_Data-035,0,,0.0,0.0,0.0,100.0,163.37,1.0,163.37
+SNT_Data-037,1,,0.98,1.96,0.0,98.04,184.57,0.8585,214.98
+SNT_Data-038,1,,0.83,1.65,0.0,98.35,144.38,0.5667,254.79
+SNT_Data-041,4,,2.83,7.55,0.0,92.45,52.62,0.2286,230.17
+SNT_Data-043,2,,1.92,9.62,0.0,90.38,35.32,0.3334,105.96
+SNT_Data-051,0,,0.0,0.0,0.0,100.0,106.57,1.0,106.57
+SNT_Data-052,6,,1.37,6.48,0.0,93.52,265.1,0.4184,633.57
+SNT_Data-053,15,,1.31,4.46,0.0,95.54,207.79,0.1343,1547.45
+SNT_Data-062,2,,0.9,8.11,0.0,91.89,116.9,0.5144,227.24
+SNT_Data-074,0,,0.0,0.0,0.0,100.0,80.18,1.0,80.18

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 16 additions & 96 deletions
@@ -13,7 +13,6 @@
 from concurrent.futures import (
     as_completed,
     ProcessPoolExecutor,
-    ThreadPoolExecutor,
 )
 from copy import deepcopy
 from scipy.spatial import distance, KDTree

@@ -119,8 +118,7 @@ def __init__(
         )

         # Load data
-        self.label_mask = label_mask
-        self.load_groundtruth(gt_pointer)
+        self.load_groundtruth(gt_pointer, label_mask)
         self.load_fragments(fragments_pointer)

         # Initialize metrics

@@ -144,14 +142,16 @@ def __init__(
         self.metrics = pd.DataFrame(index=row_names, columns=col_names)

     # --- Load Data ---
-    def load_groundtruth(self, swc_pointer):
+    def load_groundtruth(self, swc_pointer, label_mask):
         """
         Loads ground truth graphs and initializes the "graphs" attribute.

         Parameters
         ----------
         swc_pointer : Any
             Pointer to ground truth SWC files.
+        label_mask : ImageReader
+            Predicted segmentation mask.

         Returns
         -------

@@ -160,19 +160,16 @@ def load_groundtruth(self, swc_pointer):
         """
         # Build graphs
         print("\n(1) Load Ground Truth")
-        graph_builder = gutil.GraphBuilder(
+        graph_loader = gutil.GraphLoader(
             anisotropy=self.anisotropy,
             is_groundtruth=True,
-            label_mask=self.label_mask,
+            label_handler=self.label_handler,
+            label_mask=label_mask,
             use_anisotropy=False,
         )
-        self.graphs = graph_builder.run(swc_pointer)
+        self.graphs = graph_loader.run(swc_pointer)
         self.gt_graphs = deepcopy(self.graphs)

-        # Label nodes
-        for key in tqdm(self.graphs, desc="Labeling Graphs"):
-            self.label_graphs(key)
-
     def load_fragments(self, swc_pointer):
         """
         Loads fragments generated from the segmentation and initializes the

@@ -190,13 +187,13 @@ def load_fragments(self, swc_pointer):
         """
         print("\n(2) Load Fragments")
         if swc_pointer:
-            graph_builder = gutil.GraphBuilder(
+            graph_loader = gutil.GraphLoader(
                 anisotropy=self.anisotropy,
                 is_groundtruth=False,
                 selected_ids=self.get_all_node_labels(),
                 use_anisotropy=self.use_anisotropy,
             )
-            self.fragment_graphs = graph_builder.run(swc_pointer)
+            self.fragment_graphs = graph_loader.run(swc_pointer)
             self.set_fragment_ids()
         else:
             self.fragment_graphs = None

@@ -219,86 +216,6 @@ def set_fragment_ids(self):
         for key in self.fragment_graphs:
             self.fragment_ids.add(util.get_segment_id(key))

-    def label_graphs(self, key):
-        """
-        Iterates over nodes in "graph" and stores the corresponding label from
-        "self.label_mask") as a node-level attribute called "labels".
-
-        Parameters
-        ----------
-        key : str
-            Unique identifier of graph to be labeled.
-
-        Returns
-        -------
-        None
-
-        """
-        with ThreadPoolExecutor() as executor:
-            # Assign threads
-            batch = set()
-            threads = list()
-            visited = set()
-            for i, j in nx.dfs_edges(self.graphs[key]):
-                # Check if starting new batch
-                if len(batch) == 0:
-                    root = i
-                    batch.add(i)
-                    visited.add(i)
-
-                # Check whether to submit batch
-                is_node_far = self.graphs[key].dist(root, j) > 128
-                is_batch_full = len(batch) >= 128
-                if is_node_far or is_batch_full:
-                    threads.append(
-                        executor.submit(self.get_patch_labels, key, batch)
-                    )
-                    batch = set()
-
-                # Visit j
-                if j not in visited:
-                    batch.add(j)
-                    visited.add(j)
-                    if len(batch) == 1:
-                        root = j
-
-            # Submit last batch
-            threads.append(executor.submit(self.get_patch_labels, key, batch))
-
-        # Store results
-        self.graphs[key].init_labels()
-        for thread in as_completed(threads):
-            node_to_label = thread.result()
-            for i, label in node_to_label.items():
-                self.graphs[key].labels[i] = label
-
-    def get_patch_labels(self, key, nodes):
-        """
-        Gets the segment labels for a given set of nodes within a specified
-        patch of the label mask.
-
-        Parameters
-        ----------
-        key : str
-            Unique identifier of graph to be labeled.
-        nodes : List[int]
-            Node IDs for which the labels are to be retrieved.
-
-        Returns
-        -------
-        dict
-            A dictionary that maps node IDs to their respective labels.
-
-        """
-        bbox = self.graphs[key].get_bbox(nodes)
-        label_patch = self.label_mask.read_with_bbox(bbox)
-        node_to_label = dict()
-        for i in nodes:
-            voxel = self.to_local_voxels(key, i, bbox["min"])
-            label = self.label_handler.get(label_patch[voxel])
-            node_to_label[i] = label
-        return node_to_label
-
     def get_all_node_labels(self):
         """
         Gets the set of unique node labels from all graphs in "self.graphs".
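
The two methods removed above are what the commit title refers to: per-node label lookup now lives in the graph util module instead of SkeletonMetric. As a rough illustration of the relocated technique only (a standalone sketch, not the repository's new API; the DFS-distance batching heuristic from label_graphs is omitted), the idea is to split the nodes into small batches, fetch each batch's labels from the segmentation in a worker thread, and merge the per-node results:

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def label_nodes_in_batches(nodes, read_batch_labels, batch_size=128):
        # Hypothetical helper: read_batch_labels(batch) must return a dict that
        # maps node ID -> segmentation label, e.g. by reading one bounding-box
        # patch of the label mask, as get_patch_labels did above.
        batches = [nodes[i:i + batch_size] for i in range(0, len(nodes), batch_size)]
        node_to_label = {}
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(read_batch_labels, batch) for batch in batches]
            for future in as_completed(futures):
                node_to_label.update(future.result())
        return node_to_label
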
@@ -407,7 +324,8 @@ def run(self):
         # Save results
         prefix = "corrected-" if self.connections_path else ""
         path = f"{self.output_dir}/{prefix}results.csv"
-        self.metrics.fillna(0)
+        if self.fragment_graphs is None:
+            self.metrics = self.metrics.drop("# Merges", axis=1)
         self.metrics.to_csv(path, index=True)

         # Report results
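
One behavioral note on the hunk above: the old self.metrics.fillna(0) call returned a new DataFrame and discarded it (pandas fillna does not modify in place unless inplace=True), so it had no effect on the saved CSV. The replacement instead drops the "# Merges" column when no fragment graphs were loaded. A small standalone pandas sketch with toy data (not from the repo) showing the difference:

    import pandas as pd

    metrics = pd.DataFrame(
        {"# Splits": [7, 3], "# Merges": [float("nan"), float("nan")]},
        index=["SNT_Data-002", "SNT_Data-035"],
    )

    metrics.fillna(0)          # old line: returns a copy; `metrics` still holds NaNs
    no_fragment_graphs = True  # stands in for `self.fragment_graphs is None`
    if no_fragment_graphs:
        metrics = metrics.drop("# Merges", axis=1)  # new behavior: omit the column

    metrics.to_csv("results.csv", index=True)
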
@@ -419,10 +337,12 @@ def run(self):
             util.update_txt(path, f" {column_name}: {avg:.4f}")

         n_splits = self.metrics["# Splits"].sum()
-        n_merges = self.metrics["# Merges"].sum()
         util.update_txt(path, "\nTotal Results...")
         util.update_txt(path, " # Splits: " + str(n_splits))
-        util.update_txt(path, " # Merges: " + str(n_merges))
+
+        if self.fragment_graphs is not None:
+            n_merges = self.metrics["# Merges"].sum()
+            util.update_txt(path, " # Merges: " + str(n_merges))

     # -- Split Detection --
     def detect_splits(self):