Skip to content

Commit e17e18d

Browse files
author
anna-grim
committed
refactor: skeleton graph label attribute
1 parent 55d8de3 commit e17e18d

File tree

4 files changed

+183
-139
lines changed

4 files changed

+183
-139
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from scipy.spatial import distance
2+
3+
import networkx as nx
4+
import numpy as np
5+
6+
from segmentation_skeleton_metrics.utils import util
7+
8+
9+
class SkeletonGraph(nx.Graph):
10+
11+
def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
12+
# Call parent class
13+
super(SkeletonGraph, self).__init__()
14+
15+
# Instance attributes
16+
self.anisotropy = anisotropy
17+
self.run_length = 0
18+
19+
def set_labels(self):
20+
self.labels = np.zeros((self.number_of_nodes()), dtype=int)
21+
22+
def set_nodes(self):
23+
num_nodes = len(self.voxels)
24+
self.add_nodes_from(np.arange(num_nodes))
25+
26+
def set_voxels(self, voxels):
27+
self.voxels = np.array(voxels, dtype=np.int32)
28+
29+
# --- Getters ---
30+
def get_labels(self):
31+
return np.unique(self.labels)
32+
33+
def nodes_with_label(self, label):
34+
return np.where(self.labels == label)[0]
35+
36+
# --- Computation ---
37+
def dist(self, i, j):
38+
"""
39+
Computes the Euclidean distance between the voxel coordinates
40+
cooresponding to the given nodes.
41+
42+
Parameters
43+
----------
44+
i : int
45+
Node ID.
46+
j : int
47+
Node ID.
48+
49+
Returns
50+
-------
51+
float
52+
Distance between voxel coordinates of the given nodes.
53+
54+
"""
55+
return distance.euclidean(self.voxels[i], self.voxels[j])
56+
57+
def physical_dist(self, i, j):
58+
"""
59+
Computes the Euclidean distance between the physical coordinates
60+
cooresponding to the given nodes.
61+
62+
Parameters
63+
----------
64+
i : int
65+
Node ID.
66+
j : int
67+
Node ID.
68+
69+
Returns
70+
-------
71+
float
72+
Distance between physical coordinates of the given nodes.
73+
74+
"""
75+
xyz_i = self.voxels[i] * self.anisotropy
76+
xyz_j = self.voxels[j] * self.anisotropy
77+
return distance.euclidean(xyz_i, xyz_j)
78+
79+
def get_bbox(self, nodes):
80+
bbox_min = np.inf * np.ones(3)
81+
bbox_max = np.zeros(3)
82+
for i in nodes:
83+
bbox_min = np.minimum(bbox_min, self.voxels[i])
84+
bbox_max = np.maximum(bbox_max, self.voxels[i] + 1)
85+
return {"min": bbox_min.astype(int), "max": bbox_max.astype(int)}
86+
87+
def remove_nodes_with_label(self, label):
88+
"""
89+
Removes nodes with the given label
90+
91+
Parameters
92+
----------
93+
label : int
94+
Label to be deleted from graph.
95+
96+
Returns
97+
-------
98+
None
99+
100+
"""
101+
nodes = self.nodes_with_label(label)
102+
self.remove_nodes_from(nodes)
103+
104+
def run_lengths(self):
105+
"""
106+
Computes the path length of each connected component.
107+
108+
Parameters
109+
----------
110+
None
111+
112+
Returns
113+
-------
114+
numpy.ndarray
115+
Array containing run lengths of each connected component.
116+
117+
"""
118+
run_lengths = []
119+
if self.number_of_nodes() > 0:
120+
for nodes in nx.connected_components(self):
121+
root = util.sample_once(nodes)
122+
run_lengths.append(self.run_length_from(root))
123+
else:
124+
run_lengths.append(0)
125+
return np.array(run_lengths)
126+
127+
def run_length_from(self, root):
128+
"""
129+
Computes the path length of the connected component that contains
130+
"root".
131+
132+
Parameters
133+
----------
134+
graph : networkx.Graph
135+
Graph to be parsed.
136+
137+
Returns
138+
-------
139+
float
140+
Path length.
141+
142+
"""
143+
run_length = 0
144+
for i, j in nx.dfs_edges(self, source=root):
145+
run_length += self.physical_dist(i, j)
146+
return run_length
147+
148+
def upd_labels(self, nodes, label):
149+
"""
150+
Updates the label of each node in "nodes" with "label".
151+
152+
Parameters
153+
----------
154+
nodes : List[int]
155+
Nodes to be updated.
156+
label : int
157+
Updated label.
158+
159+
Returns
160+
-------
161+
None
162+
163+
"""
164+
for i in nodes:
165+
self.labels[i] = label

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,11 @@ def label_graphs(self, key, batch_size=64):
221221
threads.append(executor.submit(self.get_patch_labels, key, batch))
222222

223223
# Process results
224-
n_nodes = self.graphs[key].number_of_nodes()
225-
self.graphs[key].graph["label"] = np.zeros((n_nodes), dtype=int)
224+
self.graphs[key].set_labels()
226225
for thread in as_completed(threads):
227226
node_to_label = thread.result()
228227
for i, label in node_to_label.items():
229-
self.graphs[key].graph["label"][i] = label
228+
self.graphs[key].labels[i] = label
230229

231230
def get_patch_labels(self, key, nodes):
232231
bbox = self.graphs[key].get_bbox(nodes)
@@ -411,7 +410,7 @@ def detect_splits(self):
411410
self.split_percent = dict()
412411
for process in as_completed(processes):
413412
key, graph, split_percent = process.result()
414-
self.graphs[key] = gutil.remove_nodes(graph, 0)
413+
self.graphs[key] = graph
415414
self.split_percent[key] = split_percent
416415
pbar.update(1)
417416

src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313

1414
import networkx as nx
1515

16-
from segmentation_skeleton_metrics.utils import graph_util as gutil
17-
1816

1917
def run(process_id, graph):
2018
"""
@@ -33,7 +31,7 @@ def run(process_id, graph):
3331
"""
3432
# Initializations
3533
split_cnt = 0
36-
source = gutil.sample_leaf(graph)
34+
source = get_leaf(graph)
3735
dfs_edges = deque(list(nx.dfs_edges(graph, source=source)))
3836
visited_edges = set()
3937

@@ -45,8 +43,8 @@ def run(process_id, graph):
4543
continue
4644

4745
# Visit edge
48-
label_i = int(graph.graph["label"][i])
49-
label_j = int(graph.graph["label"][j])
46+
label_i = int(graph.labels[i])
47+
label_j = int(graph.labels[j])
5048
if is_split(label_i, label_j):
5149
graph.remove_edge(i, j)
5250
split_cnt += 1
@@ -56,6 +54,7 @@ def run(process_id, graph):
5654

5755
# Finish
5856
split_percent = split_cnt / graph.graph["n_edges"]
57+
graph.remove_nodes_with_label(0)
5958
return process_id, graph, split_percent
6059

6160

@@ -91,7 +90,7 @@ def check_misalignment(graph, visited_edges, nb, root):
9190
while len(queue) > 0:
9291
# Visit node
9392
j = queue.popleft()
94-
label_j = int(graph.graph["label"][j])
93+
label_j = int(graph.labels[j])
9594
if label_j != 0:
9695
label_collisions.add(label_j)
9796
visited.add(j)
@@ -107,7 +106,7 @@ def check_misalignment(graph, visited_edges, nb, root):
107106
# Upd zero nodes
108107
if len(label_collisions) == 1:
109108
label = label_collisions.pop()
110-
upd_labels(graph, visited, label)
109+
graph.upd_labels(visited, label)
111110

112111

113112
# -- Helpers --
@@ -131,24 +130,21 @@ def is_split(a, b):
131130
return (a > 0 and b > 0) and (a != b)
132131

133132

134-
def upd_labels(graph, nodes, label):
133+
def get_leaf(graph):
135134
"""
136-
Updates the label of each node in "nodes" with "label".
135+
Gets a leaf node from "graph".
137136
138137
Parameters
139138
----------
140139
graph : networkx.Graph
141-
Graph to be updated.
142-
nodes : list
143-
List of nodes to be updated.
144-
label : int
145-
New label of each node in "nodes".
140+
Graph to be sampled from.
146141
147142
Returns
148143
-------
149-
networkx.Graph
150-
Updated graph.
144+
int
145+
Leaf node of "graph"
151146
152147
"""
153-
for i in nodes:
154-
graph.graph["label"][i] = label
148+
for i in graph.nodes:
149+
if graph.degree[i] == 1:
150+
return i

0 commit comments

Comments
 (0)