Skip to content

Commit 1bc336e

Browse files
anna-grimanna-grim
andauthored
refactor: valid labels (#51)
Co-authored-by: anna-grim <[email protected]>
1 parent 62356e6 commit 1bc336e

File tree

2 files changed

+150
-102
lines changed

2 files changed

+150
-102
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from segmentation_skeleton_metrics import graph_utils as gutils
2020
from segmentation_skeleton_metrics import split_detection, utils
21+
from segmentation_skeleton_metrics import swc_utils
2122
from segmentation_skeleton_metrics.swc_utils import (
2223
get_xyz_coords,
2324
save,
@@ -49,65 +50,88 @@ def __init__(
4950
pred_swc_paths,
5051
target_swc_paths,
5152
anisotropy=[1.0, 1.0, 1.0],
52-
ignore_boundary_mistakes=False,
5353
black_holes_xyz_id=None,
5454
black_hole_radius=24,
5555
equivalent_ids=None,
56-
valid_ids=None,
57-
write_to_swc=False,
56+
ignore_boundary_mistakes=False,
5857
output_dir=None,
58+
pred_on_cloud=False,
59+
valid_size_threshold=40.0,
60+
write_to_swc=False,
5961
):
6062
"""
6163
Constructs skeleton metric object that evaluates the quality of a
6264
predicted segmentation.
6365
6466
Parameters
6567
----------
68+
pred_labels : numpy.ndarray or tensorstore.TensorStore
69+
Predicted segmentation mask.
70+
pred_swc_paths : list[str]
71+
List of paths to swc files where each file corresponds to a
72+
neuron in the prediction.
6673
target_swc_paths : list[str]
67-
List of paths to swc files such that each file corresponds to a
74+
List of paths to swc files where each file corresponds to a
6875
neuron in the ground truth.
69-
labels : numpy.ndarray or tensorstore.TensorStore
70-
Predicted segmentation mask.
7176
anisotropy : list[float], optional
7277
Image to real-world coordinates scaling factors applied to swc
7378
files. The default is [1.0, 1.0, 1.0]
74-
black_holes_xyz_id : list
79+
black_holes_xyz_id : list, optional
7580
...
76-
black_hole_radius : float
81+
black_hole_radius : float, optional
7782
...
7883
equivalent_ids : ...
7984
...
80-
valid_ids : set
85+
ignore_boundary_mistakes : bool, optional
86+
Indication of whether to ignore mistakes near boundary of bounding
87+
box. The default is False.
88+
output_dir : str, optional
89+
Path to directory that each mistake site is written to. The default
90+
is None.
91+
pred_on_cloud : bool, optional
92+
Indication of whether predicted swc files in "pred_swc_paths" are
93+
on the cloud in a GCS bucket. The default is False.
94+
valid_size_threshold : float, optional
8195
...
96+
write_to_swc : bool, optional
97+
Indication of whether to write mistake sites to an swc file. The
98+
default is False.
8299
83100
Returns
84101
-------
85102
None.
86103
87104
"""
88-
# Store label options
89-
self.valid_ids = valid_ids
90-
self.label_mask = pred_labels
91-
105+
# Store options
92106
self.anisotropy = anisotropy
93107
self.ignore_boundary_mistakes = ignore_boundary_mistakes
108+
self.output_dir = output_dir
109+
self.write_to_swc = write_to_swc
110+
94111
self.init_black_holes(black_holes_xyz_id)
95112
self.black_hole_radius = black_hole_radius
96113

97-
self.write_to_swc = write_to_swc
98-
self.output_dir = output_dir
99-
100114
# Build Graphs
115+
self.label_mask = pred_labels
116+
self.pred_swc_paths = pred_swc_paths
117+
self.init_valid_labels(valid_size_threshold)
118+
101119
self.target_graphs = self.init_graphs(target_swc_paths, anisotropy)
102120
self.labeled_target_graphs = self.init_labeled_target_graphs()
103-
self.pred_swc_paths = pred_swc_paths
104121

105122
# Build kdtree
106123
self.init_xyz_to_id_node()
107124
self.init_kdtree()
108125
self.rm_spurious_intersections()
109126

110127
# -- Initialize and Label Graphs --
128+
def init_valid_labels(self, valid_size_threshold):
129+
self.valid_labels = set()
130+
for path in self.pred_swc_paths:
131+
contents = swc_utils.read(path)
132+
if len(contents) > valid_size_threshold:
133+
self.valid_labels.add(utils.get_swc_id(path))
134+
111135
def init_graphs(self, paths, anisotropy):
112136
"""
113137
Initializes "self.target_graphs" by iterating over "paths" which
@@ -247,7 +271,7 @@ def __read_label(self, coord):
247271
def is_valid(self, label):
248272
"""
249273
Validates label by checking whether it is contained in
250-
"self.valid_ids".
274+
"self.valid_labels".
251275
252276
Parameters
253277
----------
@@ -258,12 +282,12 @@ def is_valid(self, label):
258282
-------
259283
label : int
260284
There are two possibilities: (1) original label if either "label"
261-
is contained in "self.valid_ids" or "self.valid_labels" is None,
262-
or (2) 0 if "label" is not contained in self.valid_ids.
285+
is contained in "self.valid_labels" or "self.valid_labels" is
286+
None, or (2) 0 if "label" is not contained in self.valid_labels.
263287
264288
"""
265-
if self.valid_ids:
266-
if label not in self.valid_ids:
289+
if self.valid_labels:
290+
if label not in self.valid_labels:
267291
return 0
268292
return label
269293

src/segmentation_skeleton_metrics/swc_utils.py

Lines changed: 104 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13,100 +13,31 @@
1313
from segmentation_skeleton_metrics import utils
1414

1515

16-
def make_entry(node_id, parent_id, xyz):
17-
"""
18-
Makes an entry to be written in an swc file.
16+
def read(path, cloud_read=False):
17+
return read_from_cloud(path) if cloud_read else read_from_local(path)
1918

20-
Parameters
21-
----------
22-
graph : networkx.Graph
23-
Graph that "node_id" and "parent_id" belong to.
24-
node_id : int
25-
Node that entry corresponds to.
26-
parent_id : int
27-
Parent of node "node_id".
28-
xyz : ...
29-
xyz coordinate of "node_id".
3019

31-
Returns
32-
-------
33-
entry : str
34-
Entry to be written in an swc file.
35-
36-
"""
37-
x, y, z = tuple(xyz)
38-
entry = f"{node_id} 2 {x} {y} {z} 8 {parent_id}"
39-
return entry
20+
def read_from_cloud(path):
21+
pass
4022

4123

42-
def save(path, xyz_1, xyz_2, color=None):
24+
def read_from_local(path):
4325
"""
44-
Writes an swc file.
26+
Reads swc file stored at "path" on local machine.
4527
4628
Parameters
4729
----------
48-
path : str
49-
Path on local machine that swc file will be written to.
50-
entry_list : list[list]
51-
List of entries that will be written to an swc file.
52-
color : str, optional
53-
Color of nodes. The default is None.
54-
55-
Returns
56-
-------
57-
None.
58-
59-
"""
60-
with open(path, "w") as f:
61-
# Preamble
62-
if color is not None:
63-
f.write("# COLOR " + color)
64-
else:
65-
f.write("# id, type, z, y, x, r, pid")
66-
f.write("\n")
67-
68-
# Entries
69-
f.write(make_entry(1, -1, xyz_1))
70-
f.write("\n")
71-
f.write(make_entry(2, 1, xyz_2))
72-
73-
74-
def to_graph(path, anisotropy=[1.0, 1.0, 1.0]):
75-
"""
76-
Reads an swc file and builds an undirected graph from it.
77-
78-
Parameters
79-
----------
80-
path : str
30+
Path : str
8131
Path to swc file to be read.
82-
anisotropy : list[float], optional
83-
Image to real-world coordinates scaling factors for (x, y, z) that is
84-
applied to swc files. The default is [1.0, 1.0, 1.0].
8532
8633
Returns
8734
-------
88-
networkx.Graph
89-
Graph constructed from an swc file.
35+
list
36+
List such that each entry is a line from the swc file.
9037
9138
"""
92-
graph = nx.Graph(swc_id=utils.get_swc_id(path))
93-
with open(path, "r") as f:
94-
offset = [0, 0, 0]
95-
for line in f.readlines():
96-
if line.startswith("# OFFSET"):
97-
parts = line.split()
98-
offset = read_xyz(parts[2:5])
99-
if not line.startswith("#"):
100-
parts = line.split()
101-
child = int(parts[0])
102-
parent = int(parts[-1])
103-
xyz = read_xyz(
104-
parts[2:5], anisotropy=anisotropy, offset=offset
105-
)
106-
graph.add_node(child, xyz=xyz)
107-
if parent != -1:
108-
graph.add_edge(parent, child)
109-
return graph
39+
with open(path, "r") as file:
40+
return file.readlines()
11041

11142

11243
def get_xyz_coords(path, anisotropy=[1.0, 1.0, 1.0]):
@@ -167,3 +98,96 @@ def read_xyz(xyz, anisotropy=[1.0, 1.0, 1.0], offset=[0, 0, 0]):
16798
"""
16899
xyz = [float(xyz[i]) + offset[i] for i in range(3)]
169100
return np.array([xyz[i] / anisotropy[i] for i in range(3)], dtype=int)
101+
102+
103+
def save(path, xyz_1, xyz_2, color=None):
104+
"""
105+
Writes an swc file.
106+
107+
Parameters
108+
----------
109+
path : str
110+
Path on local machine that swc file will be written to.
111+
entry_list : list[list]
112+
List of entries that will be written to an swc file.
113+
color : str, optional
114+
Color of nodes. The default is None.
115+
116+
Returns
117+
-------
118+
None.
119+
120+
"""
121+
with open(path, "w") as f:
122+
# Preamble
123+
if color is not None:
124+
f.write("# COLOR " + color)
125+
else:
126+
f.write("# id, type, z, y, x, r, pid")
127+
f.write("\n")
128+
129+
# Entries
130+
f.write(make_entry(1, -1, xyz_1))
131+
f.write("\n")
132+
f.write(make_entry(2, 1, xyz_2))
133+
134+
135+
def make_entry(node_id, parent_id, xyz):
136+
"""
137+
Makes an entry to be written in an swc file.
138+
139+
Parameters
140+
----------
141+
graph : networkx.Graph
142+
Graph that "node_id" and "parent_id" belong to.
143+
node_id : int
144+
Node that entry corresponds to.
145+
parent_id : int
146+
Parent of node "node_id".
147+
xyz : ...
148+
xyz coordinate of "node_id".
149+
150+
Returns
151+
-------
152+
entry : str
153+
Entry to be written in an swc file.
154+
155+
"""
156+
x, y, z = tuple(xyz)
157+
entry = f"{node_id} 2 {x} {y} {z} 8 {parent_id}"
158+
return entry
159+
160+
161+
def to_graph(path, anisotropy=[1.0, 1.0, 1.0]):
162+
"""
163+
Reads an swc file and builds an undirected graph from it.
164+
165+
Parameters
166+
----------
167+
path : str
168+
Path to swc file to be read.
169+
anisotropy : list[float], optional
170+
Image to real-world coordinates scaling factors for (x, y, z) that is
171+
applied to swc files. The default is [1.0, 1.0, 1.0].
172+
173+
Returns
174+
-------
175+
networkx.Graph
176+
Graph built from an swc file.
177+
178+
"""
179+
graph = nx.Graph(swc_id=utils.get_swc_id(path))
180+
offset = [0, 0, 0]
181+
for line in read(path):
182+
if line.startswith("# OFFSET"):
183+
parts = line.split()
184+
offset = read_xyz(parts[2:5])
185+
if not line.startswith("#"):
186+
parts = line.split()
187+
child = int(parts[0])
188+
parent = int(parts[-1])
189+
xyz = read_xyz(parts[2:5], anisotropy=anisotropy, offset=offset)
190+
graph.add_node(child, xyz=xyz)
191+
if parent != -1:
192+
graph.add_edge(parent, child)
193+
return graph

0 commit comments

Comments
 (0)