
Commit 3b18e27

Author: anna-grim (committed)
bug: evaluate corrected segmentation
1 parent: e681f26

4 files changed (+141, -165 lines)


src/segmentation_skeleton_metrics/skeleton_graph.py

Lines changed: 71 additions & 3 deletions
@@ -9,6 +9,7 @@
 """
 
+from io import StringIO
 from scipy.spatial import distance
 
 import networkx as nx
@@ -59,7 +60,10 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
 
         # Instance attributes
         self.anisotropy = np.array(anisotropy)
+        self.filename = None
+        self.labels = None
         self.run_length = 0
+        self.voxels = None
 
     def init_labels(self):
         """
@@ -91,6 +95,23 @@ def init_voxels(self, voxels):
         """
         self.voxels = np.array(voxels, dtype=np.int32)
 
+    def set_filename(self, filename):
+        """
+        Sets the filename attribute which corresponds to the SWC file that the
+        graph is built from.
+
+        Parameters
+        ----------
+        filename : str
+            Name of SWC file that graph is built from.
+
+        Returns
+        -------
+        None
+
+        """
+        self.filename = filename
+
     def set_nodes(self):
         """
         Adds nodes to the graph. The nodes are assigned indices from 0 to the
@@ -111,7 +132,7 @@ def set_nodes(self):
     # --- Getters ---
     def get_labels(self):
         """
-        Gets the unique label values in the "labels" attribute.
+        Gets the unique non-zero label values in the "labels" attribute.
 
         Parameters
         ----------
@@ -120,10 +141,13 @@ def get_labels(self):
         Returns
         -------
         numpy.ndarray
-            A 1D array of unique labels assigned to nodes in the graph.
+            A 1D array of unique non-zero labels assigned to nodes in the
+            graph.
 
         """
-        return np.unique(self.labels)
+        labels = set(np.unique(self.labels))
+        labels.discard(0)
+        return labels
 
     def nodes_with_label(self, label):
         """
@@ -289,3 +313,47 @@ def upd_labels(self, nodes, label):
         """
         for i in nodes:
             self.labels[i] = label
+
+    def to_zipped_swc(self, zip_writer, color=None):
+        """
+        Writes a graph to an SWC file that is to be stored in a zip.
+
+        Parameters
+        ----------
+        zip_writer : zipfile.ZipFile
+            ...
+        color : str, optional
+            ...
+
+        Returns
+        -------
+        None
+
+        """
+        with StringIO() as text_buffer:
+            # Preamble
+            text_buffer.write("# COLOR " + color) if color else None
+            text_buffer.write("# id, type, z, y, x, r, pid")
+
+            # Write entries
+            n_entries = 0
+            node_to_idx = dict()
+            r = 5 if color else 3
+            for i, j in nx.dfs_edges(self):
+                # Special Case: Root
+                x, y, z = tuple(self.voxels[i] * self.anisotropy)
+                if len(node_to_idx) == 0:
+                    parent = -1
+                    node_to_idx[i] = 1
+                    text_buffer.write("\n" + f"1 2 {x} {y} {z} {r} {parent}")
+                    n_entries += 1
+
+                # General Case
+                node = n_entries + 1
+                parent = node_to_idx[i]
+                node_to_idx[j] = n_entries + 1
+                text_buffer.write("\n" + f"{node} 2 {x} {y} {z} {r} {parent}")
+                n_entries += 1
+
+            # Finish
+            zip_writer.writestr(self.filename, text_buffer.getvalue())
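
Usage note (not part of the diff): the new to_zipped_swc method expects a zipfile.ZipFile opened for writing and uses the name stored via set_filename as the archive entry name. A minimal sketch, assuming a SkeletonGraph instance named graph whose voxels and edges are already populated; the file names and color string below are hypothetical:

from zipfile import ZipFile

# Hypothetical names; "graph" is assumed to be a populated SkeletonGraph.
graph.set_filename("fragment_123.swc")
with ZipFile("projections.zip", "w") as zip_writer:
    # The optional color string is emitted as a "# COLOR ..." preamble line.
    graph.to_zipped_swc(zip_writer, color="1.0 0.0 0.5")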

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 56 additions & 95 deletions
@@ -23,12 +23,9 @@
 from segmentation_skeleton_metrics.utils import (
     graph_util as gutil,
     img_util,
-    swc_util,
     util
 )
 
-MIN_CNT = 40
-
 
 class SkeletonMetric:
     """
@@ -155,10 +152,8 @@ def load_groundtruth(self, swc_pointer):
     def load_fragments(self, swc_pointer):
         print("\n(2) Load Fragments")
         if swc_pointer:
-            coords_only = False  # not self.save_projections
             graph_builder = gutil.GraphBuilder(
                 anisotropy=self.anisotropy,
-                coords_only=coords_only,
                 selected_ids=self.get_all_node_labels(),
                 use_anisotropy=True,
             )
@@ -172,7 +167,7 @@ def set_fragment_ids(self):
         for key in self.fragment_graphs:
             self.fragment_ids.add(util.get_segment_id(key))
 
-    def label_graphs(self, key, batch_size=64):
+    def label_graphs(self, key, batch_size=128):
         """
         Iterates over nodes in "graph" and stores the corresponding label from
         predicted segmentation mask (i.e. "self.label_mask") as a node-level
@@ -201,7 +196,7 @@ def label_graphs(self, key, batch_size=64):
                visited.add(i)
 
                # Check whether to submit batch
-               is_node_far = self.graphs[key].dist(root, j) > 128
+               is_node_far = self.graphs[key].dist(root, j) > batch_size
               is_batch_full = len(batch) >= batch_size
               if is_node_far or is_batch_full:
                   threads.append(
@@ -306,9 +301,7 @@ def init_zip_writer(self):
         self.zip_writer = dict()
         for key in self.graphs.keys():
             self.zip_writer[key] = ZipFile(f"{output_dir}/{key}.zip", "w")
-            swc_util.to_zipped_swc(
-                self.zip_writer[key], self.graphs[key],
-            )
+            self.graphs[key].to_zipped_swc(self.zip_writer[key])
 
     # -- Main Routine --
     def run(self):
@@ -331,11 +324,6 @@ def run(self):
         self.detect_splits()
         self.quantify_splits()
 
-        # Check for prexisting merges
-        if self.preexisting_merges:
-            for key in self.graphs:
-                self.adjust_metrics(key)
-
         # Merge evaluation
         self.detect_merges()
         self.quantify_merges()
@@ -344,39 +332,6 @@ def run(self):
         full_results, avg_results = self.compile_results()
         return full_results, avg_results
 
-    def adjust_metrics(self, key):
-        """
-        Adjusts the metrics of the graph associated with the given key by
-        removing nodes corresponding to known merges and their corresponding
-        subgraphs. Updates the total number of edges and run lengths in the
-        graph.
-
-        Parameters
-        ----------
-        key : str
-            Identifier for the graph to adjust.
-
-        Returns
-        -------
-        None
-
-        """
-        for label in self.preexisting_merges:
-            label = self.label_map[label] if self.label_map else label
-            if label in self.graphs[key].get_labels():
-                # Extract subgraph
-                nodes = self.graphs[key].nodes_with_label(label)
-                subgraph = self.graphs[key].subgraph(nodes)
-
-                # Adjust metrics
-                n_edges = subgraph.number_of_edges()
-                rls = gutil.compute_run_lengths(subgraph)
-                self.graphs[key].graph["run_length"] -= np.sum(rls)
-                self.graphs[key].graph["n_edges"] -= n_edges
-
-                # Update graph
-                self.graphs[key].remove_nodes_from(nodes)
-
     # -- Split Detection --
     def detect_splits(self):
         """
@@ -393,7 +348,7 @@ def detect_splits(self):
 
         """
         pbar = tqdm(total=len(self.graphs), desc="Split Detection")
-        with ProcessPoolExecutor() as executor:
+        with ProcessPoolExecutor(max_workers=8) as executor:
             # Assign processes
             processes = list()
             for key, graph in self.graphs.items():
@@ -470,7 +425,12 @@ def detect_merges(self):
                 self.count_merges(key, kdtree)
                 pbar.update(1)
 
-        # Process merges
+        # Adjust metrics (if applicable)
+        if self.preexisting_merges:
+            for key in self.graphs:
+                self.adjust_metrics(key)
+
+        # Find graphs with common node labels
         for (key_1, key_2), label in self.find_label_intersections():
             self.process_merge(key_1, label, -1)
             self.process_merge(key_2, label, -1)
@@ -502,7 +462,7 @@ def count_merges(self, key, kdtree):
         """
         for label in self.get_node_labels(key):
             nodes = self.graphs[key].nodes_with_label(label)
-            if len(nodes) > MIN_CNT:
+            if len(nodes) > 50:
                 for label in self.label_handler.get_class(label):
                     if label in self.fragment_ids:
                         self.is_fragment_merge(key, label, kdtree)
@@ -539,16 +499,45 @@ def is_fragment_merge(self, key, label, kdtree):
                 self.merged_labels.add((key, equiv_label, tuple(xyz)))
 
                 # Save merged fragment (if applicable)
-                if self.save_projections and label in self.fragment_graphs:
-                    swc_util.to_zipped_swc(
-                        self.zip_writer[key], self.fragment_graphs[label]
-                    )
+                if self.save_projections:
+                    fragment_graph.to_zipped_swc(self.zip_writer[key])
                 break
 
-    def find_graph_from_label(self, label):
-        for key in self.fragment_graphs:
-            if label == util.get_segment_id(key):
-                return self.fragment_graphs[key]
+    def adjust_metrics(self, key):
+        """
+        Adjusts the metrics of the graph associated with the given key by
+        removing nodes corresponding to known merges and their corresponding
+        subgraphs. Updates the total number of edges and run lengths in the
+        graph.
+
+        Parameters
+        ----------
+        key : str
+            Identifier for the graph to adjust.
+
+        Returns
+        -------
+        None
+
+        """
+        visited = set()
+        for label in self.preexisting_merges:
+            label = self.label_handler.mapping[label]
+            if label in self.graphs[key].get_labels():
+                if label not in visited and label != 0:
+                    # Get component with label
+                    nodes = self.graphs[key].nodes_with_label(label)
+                    root = util.sample_once(list(nodes))
+
+                    # Adjust metrics
+                    rl = self.graphs[key].run_length_from(root)
+                    self.graphs[key].run_length -= np.sum(rl)
+                    self.graphs[key].graph["n_edges"] -= len(nodes) - 1
+
+                    # Update graph
+                    self.graphs[key].remove_nodes_from(nodes)
+                    visited.add(label)
+                    print("# nodes deleted:", len(nodes))
 
     def find_label_intersections(self):
         """
@@ -673,7 +662,7 @@ def get_merged_label(self, label):
         for l in self.label_handler.get_class(label):
             if l in self.fragment_graphs.keys():
                 return l
-        return self.inverse_label_map[label]
+        return self.label_handler.inverse_mapping[label]
 
     # -- Compute Metrics --
     def compile_results(self):
@@ -866,7 +855,13 @@ def list_metrics(self):
         ]
         return metrics
 
-    # -- util --
+    # -- Helpers --
+    def find_graph_from_label(self, label):
+        for key in self.fragment_graphs:
+            if label == util.get_segment_id(key):
+                return self.fragment_graphs[key]
+        return None
+
     def physical_dist(self, voxel_1, voxel_2):
         xyz_1 = img_util.to_physical(voxel_1, self.anisotropy)
         xyz_2 = img_util.to_physical(voxel_2, self.anisotropy)
@@ -896,40 +891,6 @@ def to_local_voxels(self, key, i, offset):
 
 
 # -- util --
-def find_sites(graphs, get_labels):
-    """
-    Detects merges between ground truth graphs which are considered to be
-    potential merge sites.
-
-    Parameters
-    ----------
-    graphs : dict
-        Dictionary where the keys are graph ids and values are graphs.
-    get_labels : func
-        Gets the label of a node in "graphs".
-
-    Returns
-    -------
-    merge_ids : set[tuple]
-        Set of tuples containing a tuple of graph ids and common label between
-        the graphs.
-
-    """
-    merge_ids = set()
-    visited = set()
-    for key_1 in graphs:
-        for key_2 in graphs:
-            keys = frozenset((key_1, key_2))
-            if key_1 != key_2 and keys not in visited:
-                visited.add(keys)
-                intersection = get_labels(key_1).intersection(
-                    get_labels(key_2)
-                )
-                for label in intersection:
-                    merge_ids.add((keys, label))
-    return merge_ids
-
-
 def generate_result(keys, stats):
     """
     Reorders items in "stats" with respect to the order defined by "keys".
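
Design note (not part of the diff): the deleted module-level find_sites helper walked every ordered pair of graphs and used a visited set to skip duplicate pairs. The same pairwise label-intersection idea can be written more compactly with itertools.combinations; this is an illustrative sketch equivalent to the removed code, not code from the repository:

from itertools import combinations

def find_sites(graphs, get_labels):
    # Record every label shared by two distinct graphs, keyed by the
    # unordered pair of graph ids, matching the removed helper's output.
    merge_ids = set()
    for key_1, key_2 in combinations(graphs, 2):
        for label in get_labels(key_1) & get_labels(key_2):
            merge_ids.add((frozenset((key_1, key_2)), label))
    return merge_ids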

0 commit comments
