Skip to content

Commit ad48fc9

Browse files
anna-grim and anna-grim authored
refactor: optimized swc reader (#97)
Co-authored-by: anna-grim <[email protected]>
1 parent 4a85876 commit ad48fc9

File tree

3 files changed: +61 -32 lines changed

3 files changed: +61 -32 lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -108,6 +108,7 @@ def __init__(
108108
self.connections_path = connections_path
109109
self.output_dir = output_dir
110110
self.preexisting_merges = preexisting_merges
111+
self.save_projections = save_projections
111112

112113
# Label handler
113114
self.label_handler = gutil.LabelHandler(
@@ -116,16 +117,15 @@ def __init__(
116117

117118
# Load data
118119
self.label_mask = pred_labels
119-
self.load_groundtruth(gt_pointer, valid_labels)
120+
self.load_groundtruth(gt_pointer)
120121
self.load_fragments(fragments_pointer)
121122

122123
# Initialize writer
123-
self.save_projections = save_projections
124124
if self.save_projections:
125125
self.init_zip_writer()
126126

127127
# --- Load Data ---
128-
def load_groundtruth(self, swc_pointer, valid_labels):
128+
def load_groundtruth(self, swc_pointer):
129129
"""
130130
Initializes "self.graphs" by iterating over "paths" which corresponds
131131
to neurons in the ground truth.
@@ -147,7 +147,6 @@ def load_groundtruth(self, swc_pointer, valid_labels):
147147
anisotropy=self.anisotropy,
148148
label_mask=self.label_mask,
149149
use_anisotropy=False,
150-
valid_labels=valid_labels,
151150
)
152151
self.graphs = graph_builder.run(swc_pointer)
153152

@@ -158,16 +157,18 @@ def load_groundtruth(self, swc_pointer, valid_labels):
158157
def load_fragments(self, swc_pointer):
159158
print("\n(2) Load Fragments")
160159
if swc_pointer:
160+
coords_only = False #not self.save_projections
161161
graph_builder = gutil.GraphBuilder(
162162
anisotropy=self.anisotropy,
163+
coords_only=coords_only,
163164
selected_ids=self.get_all_node_labels(),
164165
use_anisotropy=True,
165166
)
166167
self.fragment_graphs = graph_builder.run(swc_pointer)
167168
else:
168169
self.fragment_graphs = None
169170

170-
def label_graphs(self, key, batch_size=128):
171+
def label_graphs(self, key, batch_size=64):
171172
"""
172173
Iterates over nodes in "graph" and stores the corresponding label from
173174
predicted segmentation mask (i.e. "self.label_mask") as a node-level

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 24 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
77
88
"""
9-
from concurrent.futures import ProcessPoolExecutor, as_completed
9+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
1010
from random import sample
1111
from scipy.spatial import distance
1212
from tqdm import tqdm
@@ -29,20 +29,20 @@ class GraphBuilder:
2929
def __init__(
3030
self,
3131
anisotropy=(1.0, 1.0, 1.0),
32+
coords_only=False,
3233
label_mask=None,
3334
selected_ids=None,
3435
use_anisotropy=True,
35-
valid_labels=None,
3636
):
3737
# Instance attributes
3838
self.anisotropy = anisotropy
39+
self.coords_only = coords_only
3940
self.label_mask = label_mask
4041
self.selected_ids = selected_ids
41-
self.valid_labels = valid_labels
4242

4343
# Reader
4444
anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0)
45-
self.swc_reader = swc_util.Reader(anisotropy)
45+
self.swc_reader = swc_util.Reader(anisotropy, selected_ids=selected_ids)
4646

4747
def run(self, swc_pointer):
4848
graphs = self._build_graphs_from_swcs(swc_pointer)
@@ -51,14 +51,14 @@ def run(self, swc_pointer):
5151

5252
# --- Build Graphs ---
5353
def _build_graphs_from_swcs(self, swc_pointer):
54-
with ProcessPoolExecutor() as executor:
54+
with ThreadPoolExecutor() as executor:
5555
# Assign processes
5656
processes = list()
5757
swc_dicts = self.swc_reader.load(swc_pointer)
5858
while len(swc_dicts) > 0:
5959
swc_dict = swc_dicts.pop()
60-
if self._process_swc_dict(swc_dict["swc_id"]):
61-
processes.append(executor.submit(self.to_graph, swc_dict))
60+
#if self._process_swc_dict(swc_dict["swc_id"]):
61+
processes.append(executor.submit(self.to_graph, swc_dict))
6262

6363
# Store results
6464
graphs = dict()
@@ -96,17 +96,19 @@ def to_graph(self, swc_dict):
9696
graph.set_voxels(swc_dict["voxel"])
9797

9898
# Build graph
99-
id_lookup = dict()
100-
run_length = 0
101-
for i, id_i in enumerate(swc_dict["id"]):
102-
id_lookup[id_i] = i
103-
if swc_dict["pid"][i] != -1:
104-
parent = id_lookup[swc_dict["pid"][i]]
105-
graph.add_edge(i, parent)
106-
graph.run_length += graph.dist(i, parent)
107-
108-
# Set graph-level attributes
109-
graph.graph["n_edges"] = graph.number_of_edges()
99+
if not self.coords_only:
100+
#graph.set_nodes()
101+
id_lookup = dict()
102+
run_length = 0
103+
for i, id_i in enumerate(swc_dict["id"]):
104+
id_lookup[id_i] = i
105+
if swc_dict["pid"][i] != -1:
106+
parent = id_lookup[swc_dict["pid"][i]]
107+
graph.add_edge(i, parent)
108+
graph.run_length += graph.dist(i, parent)
109+
110+
# Set graph-level attributes
111+
graph.graph["n_edges"] = graph.number_of_edges()
110112
return {swc_dict["swc_id"]: graph}
111113

112114
# --- Label Graphs ---
@@ -127,6 +129,10 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
127129
self.anisotropy = anisotropy
128130
self.run_length = 0
129131

132+
def set_nodes(self):
133+
num_nodes = len(self.voxels)
134+
self.add_nodes_from(np.arange(num_nodes))
135+
130136
def set_voxels(self, voxels):
131137
self.voxels = np.array(voxels, dtype=np.int32)
132138

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 31 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ class Reader:
4848
4949
"""
5050

51-
def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
51+
def __init__(self, anisotropy=(1.0, 1.0, 1.0), selected_ids=None):
5252
"""
5353
Initializes a Reader object that loads swc files.
5454
@@ -65,6 +65,7 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
6565
6666
"""
6767
self.anisotropy = anisotropy
68+
self.selected_ids = selected_ids or set()
6869

6970
# --- Load Data ---
7071
def load(self, swc_pointer):
@@ -138,7 +139,10 @@ def load_from_local_path(self, path):
138139
"""
139140
content = util.read_txt(path)
140141
filename = os.path.basename(path)
141-
return self.parse(content, filename)
142+
if self.confirm_load(filename):
143+
return self.parse(content, filename)
144+
else:
145+
return None
142146

143147
def load_from_local_paths(self, paths):
144148
"""
@@ -161,9 +165,11 @@ def load_from_local_paths(self, paths):
161165
threads = list()
162166
pbar = tqdm(total=len(paths), desc="Read SWCs")
163167
for path in paths:
164-
threads.append(
165-
executor.submit(self.load_from_local_path, path)
166-
)
168+
filename = os.path.basename(path)
169+
if self.confirm_load(filename):
170+
threads.append(
171+
executor.submit(self.load_from_local_path, path)
172+
)
167173

168174
# Store results
169175
swc_dicts = deque()
@@ -216,10 +222,14 @@ def load_from_local_zip(self, zip_path):
216222
# Assign threads
217223
threads = list()
218224
zipfile = ZipFile(zip_path, "r")
219-
for f in [f for f in zipfile.namelist() if f.endswith(".swc")]:
220-
threads.append(
221-
executor.submit(self.load_from_zipped_file, zipfile, f)
222-
)
225+
filesnames = [f for f in zipfile.namelist() if f.endswith(".swc")]
226+
for filename in filesnames:
227+
if self.confirm_load(filename):
228+
threads.append(
229+
executor.submit(
230+
self.load_from_zipped_file, zipfile, filename
231+
)
232+
)
223233

224234
# Store results
225235
swc_dicts = deque()
@@ -249,6 +259,13 @@ def load_from_zipped_file(self, zipfile, path):
249259
filename = os.path.basename(path)
250260
return self.parse(content, filename)
251261

262+
def confirm_load(self, filename):
263+
if len(self.selected_ids) > 0:
264+
segment_id = get_segment_id(filename)
265+
return True if segment_id in self.selected_ids else False
266+
else:
267+
return True
268+
252269
# -- Process Text ---
253270
def parse(self, content, filename):
254271
"""
@@ -335,6 +352,11 @@ def read_voxel(self, xyz_str, offset):
335352
xyz = [float(xyz_str[i]) + offset[i] for i in range(3)]
336353
return img_util.to_voxels(xyz, self.anisotropy)
337354

355+
356+
# --- Helpers ---
357+
def get_segment_id(filename):
358+
return int(filename.split(".")[0])
359+
338360

339361
def to_zipped_swc(zip_writer, graph, color=None):
340362
"""

0 commit comments

Comments
 (0)