Skip to content

Commit 99f8ad2

Browse files
author
anna-grim
committed
doc: skeleton graph
1 parent 9857dea commit 99f8ad2

File tree

5 files changed

+164
-14
lines changed

5 files changed

+164
-14
lines changed

demo/demo.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1+
"""
2+
Created on Wed Dec 21 19:00:00 2022
3+
4+
@author: Anna Grim
5+
6+
7+
Code that runs of demo of using this library to compute skeleton metrics.
8+
9+
"""
10+
111
import numpy as np
212
from xlwt import Workbook
313

@@ -6,18 +16,32 @@
616

717

818
def evaluate():
19+
"""
20+
Evaluates the accuracy of a predicted segmentation by comparing it to a
21+
set of ground truth skeletons, then reports and saves various performance
22+
metrics.
23+
24+
Parameters
25+
----------
26+
None
27+
28+
Returns
29+
-------
30+
None
31+
32+
"""
933
# Initializations
1034
pred_labels = TiffReader(pred_labels_path)
1135
skeleton_metric = SkeletonMetric(
12-
target_swcs_pointer,
36+
groundtruth_pointer,
1337
pred_labels,
14-
fragments_pointer=pred_swcs_pointer,
38+
fragments_pointer=fragments_pointer,
1539
output_dir=output_dir,
1640
)
1741
full_results, avg_results = skeleton_metric.run()
1842

1943
# Report results
20-
print(f"Averaged Results...")
44+
print(f"\nAveraged Results...")
2145
for key in avg_results.keys():
2246
print(f" {key}: {round(avg_results[key], 4)}")
2347

@@ -57,8 +81,8 @@ def save_results(path, stats):
5781
# Initializations
5882
output_dir = "./"
5983
pred_labels_path = "./pred_labels.tif"
60-
pred_swcs_pointer = "./pred_swcs.zip"
61-
target_swcs_pointer = "./target_swcs.zip"
84+
fragments_pointer = "./pred_swcs.zip"
85+
groundtruth_pointer = "./target_swcs.zip"
6286

6387
# Run
6488
evaluate()

src/segmentation_skeleton_metrics/skeleton_graph.py

Lines changed: 132 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
"""
2+
Created on Wed Dec 21 19:00:00 2022
3+
4+
@author: Anna Grim
5+
6+
7+
8+
Implementation of a custom subclass of NetworkX.Graph called SkeletonGraph.
9+
10+
"""
11+
112
from scipy.spatial import distance
213

314
import networkx as nx
@@ -7,30 +18,128 @@
718

819

920
class SkeletonGraph(nx.Graph):
21+
"""
22+
A subclass of the NetworkX.Graph that represents a skeleton graph with
23+
additional functionality for handling labels and voxel coordinates
24+
corresponding to the nodes. Note that node IDs directly index into the
25+
"labels" and "voxels" attributes.
26+
27+
Attributes
28+
----------
29+
anisotropy : ArrayLike
30+
Image to physical coordinates scaling factors to account for the
31+
anisotropy of the microscope.
32+
run_length : float
33+
Physical path length of the graph.
34+
labels : numpy.ndarray
35+
A 1D array that contains a label value associated with each node.
36+
voxels : numpy.ndarray
37+
A 3D array that contains a voxel coordinate of each node.
38+
39+
"""
1040

1141
def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
42+
"""
43+
Initializes a SkeletonGraph, including setting the anisotropy and
44+
initializing the run length attribute.
45+
46+
Parameters
47+
----------
48+
anisotropy : ArrayLike, optional
49+
Image to physical coordinates scaling factors to account for the
50+
anisotropy of the microscope. The default is (1.0, 1.0, 1.0).
51+
52+
Returns
53+
-------
54+
None
55+
56+
"""
1257
# Call parent class
1358
super(SkeletonGraph, self).__init__()
1459

1560
# Instance attributes
16-
self.anisotropy = anisotropy
61+
self.anisotropy = np.array(anisotropy)
1762
self.run_length = 0
1863

19-
def set_labels(self):
64+
def init_labels(self):
65+
"""
66+
Initializes the "labels" attribute for the graph.
67+
68+
Parameters
69+
----------
70+
None
71+
72+
Returns
73+
-------
74+
None
75+
76+
"""
2077
self.labels = np.zeros((self.number_of_nodes()), dtype=int)
2178

79+
def init_voxels(self, voxels):
80+
"""
81+
Initializes the "voxels" attribute for the graph.
82+
83+
Parameters
84+
----------
85+
None
86+
87+
Returns
88+
-------
89+
None
90+
91+
"""
92+
self.voxels = np.array(voxels, dtype=np.int32)
93+
2294
def set_nodes(self):
95+
"""
96+
Adds nodes to the graph. The nodes are assigned indices from 0 to the
97+
total number of voxels in the image.
98+
99+
Parameters
100+
----------
101+
None
102+
103+
Returns
104+
-------
105+
None
106+
107+
"""
23108
num_nodes = len(self.voxels)
24109
self.add_nodes_from(np.arange(num_nodes))
25110

26-
def set_voxels(self, voxels):
27-
self.voxels = np.array(voxels, dtype=np.int32)
28-
29111
# --- Getters ---
30112
def get_labels(self):
113+
"""
114+
Gets the unique label values in the "labels" attribute.
115+
116+
Parameters
117+
----------
118+
None
119+
120+
Returns
121+
-------
122+
numpy.ndarray
123+
A 1D array of unique labels assigned to nodes in the graph.
124+
125+
"""
31126
return np.unique(self.labels)
32127

33128
def nodes_with_label(self, label):
129+
"""
130+
Gets the IDs of nodes that have the specified label value.
131+
132+
Parameters
133+
----------
134+
label : int
135+
Label value to search for in the "labels" attribute.
136+
137+
Returns
138+
-------
139+
numpy.ndarray
140+
A 1D array of node IDs that have the specified label.
141+
142+
"""
34143
return np.where(self.labels == label)[0]
35144

36145
# --- Computation ---
@@ -69,14 +178,31 @@ def physical_dist(self, i, j):
69178
Returns
70179
-------
71180
float
72-
Distance between physical coordinates of the given nodes.
181+
Euclidea distance between physical coordinates of the given nodes.
73182
74183
"""
75184
xyz_i = self.voxels[i] * self.anisotropy
76185
xyz_j = self.voxels[j] * self.anisotropy
77186
return distance.euclidean(xyz_i, xyz_j)
78187

79188
def get_bbox(self, nodes):
189+
"""
190+
Calculates the bounding box that contains the voxel coordinates for a
191+
given collection of nodes.
192+
193+
Parameters
194+
----------
195+
nodes : Container
196+
A collection of node indices for which to compute the bounding box.
197+
198+
Returns
199+
-------
200+
dict
201+
Dictionary containing the bounding box coordinates:
202+
- "min": minimum voxel coordinates along each axis.
203+
- "max": maximum voxel coordinates along each axis.
204+
205+
"""
80206
bbox_min = np.inf * np.ones(3)
81207
bbox_max = np.zeros(3)
82208
for i in nodes:

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def label_graphs(self, key, batch_size=64):
220220
threads.append(executor.submit(self.get_patch_labels, key, batch))
221221

222222
# Process results
223-
self.graphs[key].set_labels()
223+
self.graphs[key].init_labels()
224224
for thread in as_completed(threads):
225225
node_to_label = thread.result()
226226
for i, label in node_to_label.items():

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def to_graph(self, swc_dict):
8989
"""
9090
# Initialize graph
9191
graph = SkeletonGraph(anisotropy=self.anisotropy)
92-
graph.set_voxels(swc_dict["voxel"])
92+
graph.init_voxels(swc_dict["voxel"])
9393

9494
# Build graph
9595
if not self.coords_only:

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525

2626
from collections import deque
2727
from concurrent.futures import (
28+
as_completed,
2829
ProcessPoolExecutor,
2930
ThreadPoolExecutor,
30-
as_completed,
3131
)
3232
from io import StringIO
3333
from tqdm import tqdm

0 commit comments

Comments
 (0)