Skip to content

Commit 057cf34

Browse files
author
anna-grim
committed
refactor: optimized graph labeling and split detection
1 parent 0e5d7f3 commit 057cf34

File tree

8 files changed

+857
-366
lines changed

8 files changed

+857
-366
lines changed

demo/demo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import numpy as np
2-
from tifffile import imread
32
from xlwt import Workbook
43

54
from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric
5+
from segmentation_skeleton_metrics.utils.img_util import TiffReader
66

77

88
def evaluate():
99
# Initializations
10-
pred_labels = imread(pred_labels_path)
10+
pred_labels = TiffReader(pred_labels_path)
1111
skeleton_metric = SkeletonMetric(
1212
target_swcs_pointer,
1313
pred_labels,
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""
2+
Created on Wed March 11 16:00:00 2024
3+
4+
@author: Anna Grim
5+
6+
7+
Detects splits in a predicted segmentation by comparing the ground truth
8+
skeletons (i.e. graphs) to the predicted segmentation label mask.
9+
10+
"""
11+
12+
from collections import deque
13+
14+
import networkx as nx
15+
16+
from segmentation_skeleton_metrics.utils import graph_util as gutil
17+
18+
19+
def correct_graph_misalignments(process_id, graph):
    """
    Adjusts misalignments between a ground truth graph and the segmentation
    mask by walking the graph's edges: edges whose endpoints carry different
    positive labels are removed as splits, and runs of zero-valued labels are
    handed to "check_misalignment" to decide split vs. misalignment.

    Parameters
    ----------
    process_id : hashable
        Identifier for this graph; returned unchanged so the caller can match
        results produced by a process pool back to their inputs.
    graph : networkx.Graph
        Graph that represents a ground truth neuron. Nodes must have a
        "label" attribute and the graph a "n_edges" graph-level attribute.

    Returns
    -------
    process_id : hashable
        Same identifier that was passed in.
    graph : networkx.Graph
        Labeled graph with split edges removed and, where possible,
        zero-valued labels corrected in place.
    split_percent : float
        Number of split edges divided by the total number of edges
        (0 if the graph has no edges).

    """
    # Initializations
    split_cnt = 0
    source = gutil.sample_leaf(graph)
    dfs_edges = deque(nx.dfs_edges(graph, source=source))
    visited_edges = set()

    # Main
    while len(dfs_edges) > 0:
        # Check whether to visit edge
        i, j = dfs_edges.popleft()
        if frozenset({i, j}) in visited_edges:
            continue

        # Visit edge
        label_i = graph.nodes[i]["label"]
        label_j = graph.nodes[j]["label"]
        if is_split(label_i, label_j):
            graph.remove_edge(i, j)
            split_cnt += 1
        elif label_j == 0:
            check_misalignment(graph, visited_edges, i, j)
        visited_edges.add(frozenset({i, j}))

    # Finish -- guard against an edgeless graph to avoid ZeroDivisionError
    n_edges = graph.graph["n_edges"]
    split_percent = split_cnt / n_edges if n_edges else 0
    return process_id, graph, split_percent
60+
61+
62+
def check_misalignment(graph, visited_edges, nb, root):
    """
    Determines whether a zero-valued label corresponds to a split or a
    misalignment between the graph and segmentation mask. Runs a BFS over
    the connected run of zero-labeled nodes starting at "root" and collects
    the positive labels bordering that run; if exactly one positive label is
    found, the zero-valued nodes are relabeled with it (a misalignment).

    Parameters
    ----------
    graph : networkx.Graph
        Graph that represents a ground truth neuron.
    visited_edges : set[frozenset]
        Edges in "graph" that have been visited; updated in place with every
        edge traversed by the search.
    nb : int
        Neighbor of "root".
    root : int
        Node where possible split starts (i.e. zero-valued label).

    Returns
    -------
    None

    """
    # Search
    label_collisions = set()
    queue = deque([root])
    visited = set()
    while len(queue) > 0:
        # Visit node
        j = queue.popleft()
        if graph.nodes[j]["label"] != 0:
            label_collisions.add(graph.nodes[j]["label"])
        visited.add(j)

        # Update queue -- only expand through zero-labeled nodes
        if graph.nodes[j]["label"] == 0:
            for k in graph.neighbors(j):
                if k not in visited:
                    if frozenset({j, k}) not in visited_edges or k == nb:
                        queue.append(k)
                        visited_edges.add(frozenset({j, k}))

    # Update zero-labeled nodes when exactly one label borders the run
    if len(label_collisions) == 1:
        label = label_collisions.pop()
        # NOTE(review): the rebinding below is only effective if upd_labels
        # mutates "graph" in place, since this function returns None -- confirm
        graph = gutil.upd_labels(graph, visited, label)
110+
111+
112+
# -- Helpers --
113+
def is_split(a, b):
    """
    Checks whether labels "a" and "b" are both positive yet different,
    which marks the edge between their nodes as a split mistake.

    Parameters
    ----------
    a : int
        Label at node i.
    b : int
        Label at node j.

    Returns
    -------
    bool
        Indication of whether there is a split.

    """
    return min(a, b) > 0 and a != b

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 105 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,25 @@
88
"""
99

1010

11-
from concurrent.futures import ThreadPoolExecutor, as_completed
11+
from concurrent.futures import (
12+
as_completed,
13+
ProcessPoolExecutor,
14+
ThreadPoolExecutor,
15+
)
1216
from copy import deepcopy
13-
from scipy.spatial import KDTree
17+
from scipy.spatial import distance, KDTree
1418
from time import time
1519
from tqdm import tqdm
1620
from zipfile import ZipFile
1721

1822
import networkx as nx
1923
import numpy as np
2024
import os
21-
import tensorstore as ts
2225

23-
from segmentation_skeleton_metrics import split_detection
26+
from segmentation_skeleton_metrics import graph_segmentation_alignment as gsa
2427
from segmentation_skeleton_metrics.utils import (
2528
graph_util as gutil,
29+
img_util,
2630
swc_util,
2731
util
2832
)
@@ -108,8 +112,8 @@ def __init__(
108112
self.output_dir = output_dir
109113
self.preexisting_merges = preexisting_merges
110114

111-
# Load Labels, Graphs, Fragments
112-
print("\n(1) Initializations")
115+
# Load Data
116+
print("\n(1) Load Data")
113117
assert type(valid_labels) is set if valid_labels else True
114118
self.label_mask = pred_labels
115119
self.valid_labels = valid_labels
@@ -177,7 +181,7 @@ def init_graphs(self, paths):
177181
self.graphs[key]
178182
)
179183

180-
def set_node_labels(self, key):
184+
def set_node_labels(self, key, batch_size=128):
181185
"""
182186
Iterates over nodes in "graph" and stores the corresponding label from
183187
predicted segmentation mask (i.e. "self.label_mask") as a node-level
@@ -195,17 +199,62 @@ def set_node_labels(self, key):
195199
"""
196200
with ThreadPoolExecutor() as executor:
197201
# Assign threads
198-
threads = []
199-
for i in self.graphs[key].nodes:
200-
voxel = tuple(self.graphs[key].nodes[i]["voxel"])
201-
threads.append(executor.submit(self.get_label, i, voxel))
202+
batch = set()
203+
threads = list()
204+
visited = set()
205+
for i, j in nx.dfs_edges(self.graphs[key]):
206+
# Check for new batch
207+
if len(batch) == 0:
208+
root = i
209+
batch.add(i)
210+
visited.add(i)
211+
212+
# Check whether to submit batch
213+
is_node_far = self.dist(key, root, j) > 128
214+
is_batch_full = len(batch) >= batch_size
215+
if is_node_far or is_batch_full:
216+
threads.append(
217+
executor.submit(self.get_patch_labels, key, batch)
218+
)
219+
batch = set()
202220

203-
# Store label
204-
for thread in as_completed(threads):
205-
i, label = thread.result()
206-
self.graphs[key].nodes[i].update({"label": label})
221+
# Visit j
222+
if j not in visited:
223+
batch.add(j)
224+
visited.add(j)
225+
if len(batch) == 1:
226+
root = j
227+
228+
# Submit last thread
229+
threads.append(executor.submit(self.get_patch_labels, key, batch))
207230

208-
def get_label(self, i, voxel):
231+
# Process results
232+
for thread in as_completed(threads):
233+
node_to_label = thread.result()
234+
for i, label in node_to_label.items():
235+
self.graphs[key].nodes[i].update({"label": label})
236+
237+
def get_patch_labels(self, key, nodes):
    """
    Reads the segmentation labels for a batch of nodes with a single
    bounding-box read of "self.label_mask" instead of one read per node.

    Parameters
    ----------
    key : hashable
        ID of the graph in "self.graphs" that "nodes" belong to.
    nodes : Iterable[int]
        Nodes whose segmentation labels are to be read.

    Returns
    -------
    dict
        Mapping from node ID to its (adjusted) segmentation label.

    """
    # Tight bounding box over every node's voxel (max is exclusive).
    # Vectorized min/max replaces the per-node, per-axis loop; no copy of
    # the voxel attribute is needed since it is only read.
    voxels = np.array([self.graphs[key].nodes[i]["voxel"] for i in nodes])
    bbox = {
        "min": voxels.min(axis=0).tolist(),
        "max": (voxels.max(axis=0) + 1).tolist(),
    }

    # Read one patch, then map each node to its label within the patch
    label_patch = self.label_mask.read_with_bbox(bbox)
    node_to_label = dict()
    for i in nodes:
        voxel = self.to_local_voxels(key, i, bbox["min"])
        node_to_label[i] = self.adjust_label(label_patch[voxel])
    return node_to_label
256+
257+
def adjust_label(self, label):
209258
"""
210259
Gets label of voxel in "self.label_mask".
211260
@@ -222,18 +271,11 @@ def get_label(self, i, voxel):
222271
Label of voxel.
223272
224273
"""
225-
# Read label
226-
if isinstance(self.label_mask, ts.TensorStore):
227-
label = int(self.label_mask[voxel].read().result())
228-
else:
229-
label = self.label_mask[voxel]
230-
231-
# Check whether to update label
232274
if self.label_map:
233275
label = self.get_equivalent_label(label)
234276
elif self.valid_labels:
235277
label = 0 if label not in self.valid_labels else label
236-
return i, label
278+
return label
237279

238280
def get_equivalent_label(self, label):
239281
"""
@@ -443,15 +485,29 @@ def detect_splits(self):
443485
444486
"""
445487
t0 = time()
446-
for key, graph in tqdm(self.graphs.items(), desc="Split Detection:"):
447-
# Detection
448-
graph = split_detection.run(graph, self.graphs[key])
488+
pbar = tqdm(total=len(self.graphs), desc="Split Detection:")
489+
with ProcessPoolExecutor() as executor:
490+
# Assign processes
491+
processes = list()
492+
for key, graph in self.graphs.items():
493+
processes.append(
494+
executor.submit(
495+
gsa.correct_graph_misalignments,
496+
key,
497+
graph,
498+
)
499+
)
449500

450-
# Update graph by removing omits (i.e. nodes labeled 0)
451-
self.graphs[key] = gutil.delete_nodes(graph, 0)
452-
self.key_to_label_to_nodes[key] = gutil.init_label_to_nodes(
453-
self.graphs[key]
454-
)
501+
# Store results
502+
self.split_percent = dict()
503+
for process in as_completed(processes):
504+
key, graph, split_percent = process.result()
505+
self.graphs[key] = gutil.delete_nodes(graph, 0)
506+
self.key_to_label_to_nodes[key] = gutil.init_label_to_nodes(
507+
self.graphs[key]
508+
)
509+
self.split_percent[key] = split_percent
510+
pbar.update(1)
455511

456512
# Report runtime
457513
t, unit = util.time_writer(time() - t0)
@@ -505,15 +561,19 @@ def detect_merges(self):
505561

506562
# Count total merges
507563
if self.fragment_graphs:
564+
pbar = tqdm(total=len(self.graphs), desc="Count Merges:")
508565
for key, graph in self.graphs.items():
509566
if graph.number_of_nodes() > 0:
510567
kdtree = KDTree(gutil.to_array(graph))
511568
self.count_merges(key, kdtree)
569+
pbar.update(1)
512570

513571
# Process merges
572+
pbar = tqdm(total=len(self.graphs), desc="Compute Percent Merged:")
514573
for (key_1, key_2), label in self.find_label_intersections():
515574
self.process_merge(key_1, label, -1)
516575
self.process_merge(key_2, label, -1)
576+
pbar.update(1)
517577

518578
for key, label, xyz in self.merged_labels:
519579
self.process_merge(key, label, xyz, update_merged_labels=False)
@@ -583,7 +643,7 @@ def is_fragment_merge(self, key, label, kdtree):
583643
equivalent_label = label
584644

585645
# Record merge mistake
586-
xyz = util.to_physical(voxel)
646+
xyz = img_util.to_physical(voxel)
587647
self.merge_cnt[key] += 1
588648
self.merged_labels.add((key, equivalent_label, tuple(xyz)))
589649
if self.save_projections:
@@ -776,6 +836,7 @@ def generate_full_results(self):
776836
"# splits": generate_result(keys, self.split_cnt),
777837
"# merges": generate_result(keys, self.merge_cnt),
778838
"% omit": generate_result(keys, self.omit_percent),
839+
"% split": generate_result(keys, self.split_percent),
779840
"% merged": generate_result(keys, self.merged_percent),
780841
"edge accuracy": generate_result(keys, self.edge_accuracy),
781842
"erl": generate_result(keys, self.erl),
@@ -801,6 +862,7 @@ def generate_avg_results(self):
801862
"# splits": self.avg_result(self.split_cnt),
802863
"# merges": self.avg_result(self.merge_cnt),
803864
"% omit": self.avg_result(self.omit_percent),
865+
"% split": self.avg_result(self.split_percent),
804866
"% merged": self.avg_result(self.merged_percent),
805867
"edge accuracy": self.avg_result(self.edge_accuracy),
806868
"erl": self.avg_result(self.erl),
@@ -925,6 +987,11 @@ def list_metrics(self):
925987
return metrics
926988

927989
# -- util --
990+
def dist(self, key, i, j):
    """
    Computes the Euclidean distance between the voxel coordinates of nodes
    "i" and "j" in the graph "self.graphs[key]".

    Parameters
    ----------
    key : hashable
        ID of the graph containing both nodes.
    i : int
        Node ID.
    j : int
        Node ID.

    Returns
    -------
    float
        Euclidean distance between the two nodes' voxel coordinates.

    """
    nodes = self.graphs[key].nodes
    return distance.euclidean(nodes[i]["voxel"], nodes[j]["voxel"])
994+
928995
def init_counter(self):
929996
"""
930997
Initializes a dictionary that is used to count some type of mistake
@@ -942,6 +1009,11 @@ def init_counter(self):
9421009
"""
9431010
return {key: 0 for key in self.graphs}
9441011

1012+
def to_local_voxels(self, key, i, offset):
    """
    Converts the voxel coordinate of node "i" from global coordinates to
    coordinates local to a patch whose origin is "offset".

    Parameters
    ----------
    key : hashable
        ID of the graph containing node "i".
    i : int
        Node ID.
    offset : ArrayLike
        Global coordinate of the patch origin.

    Returns
    -------
    tuple
        Local voxel coordinate (i.e. global voxel minus offset).

    """
    global_voxel = np.array(self.graphs[key].nodes[i]["voxel"])
    return tuple(global_voxel - np.array(offset))
1016+
9451017

9461018
# -- util --
9471019
def find_sites(graphs, get_labels):

0 commit comments

Comments
 (0)