Skip to content

Commit 2f7e50f

Browse files
author
anna-grim
committed
refactor: optimized memory consumption
1 parent 8ec7444 commit 2f7e50f

File tree

8 files changed

+85
-90
lines changed

8 files changed

+85
-90
lines changed

demo/evaluation_results.xls

0 Bytes
Binary file not shown.

demo/merged_ids-segmentation.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
11
Label - xyz
2+
0 - -1
3+
0 - -1
4+
0 - -1
5+
0 - -1
6+
0 - -1
7+
0 - -1
8+
0 - -1

demo/pred_labels-old.tif

-64.1 MB
Binary file not shown.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import networkx as nx
2+
import numpy as np
3+
4+
5+
class SkeletonGraph(nx.Graph):
    """
    Graph whose node labels are stored in a single array rather than as
    per-node attributes.

    The array is read from the graph-level attribute "label" (namely,
    self.graph["label"]) and is indexed by node ID, so node IDs must be
    integers that are valid indices into it. This class does not create the
    array; it must be assigned before "get_labels" or "nodes_with_label" is
    called, otherwise a KeyError is raised.

    """

    def __init__(self):
        """
        Instantiates an empty SkeletonGraph.

        Parameters
        ----------
        None

        Returns
        -------
        None

        """
        # Call parent class
        super().__init__()

    def _node_ids(self):
        """
        Gets the IDs of the nodes currently in the graph.

        Parameters
        ----------
        None

        Returns
        -------
        numpy.ndarray
            Node IDs sorted in ascending order.

        """
        node_ids = np.fromiter(self.nodes, dtype=np.intp, count=len(self))
        node_ids.sort()
        return node_ids

    def get_labels(self):
        """
        Gets the unique labels of the nodes currently in the graph.

        Parameters
        ----------
        None

        Returns
        -------
        numpy.ndarray
            Sorted array of unique labels.

        """
        # The label array keeps an entry for every node ID ever labeled, so
        # it must be restricted to the nodes that have not been removed.
        return np.unique(self.graph["label"][self._node_ids()])

    def nodes_with_label(self, label):
        """
        Gets the nodes currently in the graph whose label is "label".

        Parameters
        ----------
        label : int
            Label to be searched for.

        Returns
        -------
        numpy.ndarray
            Node IDs with the given label, sorted in ascending order.

        """
        # Removed nodes still have an entry in the label array; skip them so
        # that counts and membership checks only reflect existing nodes.
        node_ids = self._node_ids()
        return node_ids[self.graph["label"][node_ids] == label]

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 19 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
ProcessPoolExecutor,
1313
ThreadPoolExecutor,
1414
)
15-
from copy import deepcopy
1615
from scipy.spatial import distance, KDTree
17-
from time import time
1816
from tqdm import tqdm
1917
from zipfile import ZipFile
2018

@@ -118,7 +116,7 @@ def __init__(
118116

119117
# Load data
120118
self.label_mask = pred_labels
121-
self.load_groundtruth(gt_pointer)
119+
self.load_groundtruth(gt_pointer, valid_labels)
122120
self.load_fragments(fragments_pointer)
123121

124122
# Initialize writer
@@ -127,7 +125,7 @@ def __init__(
127125
self.init_zip_writer()
128126

129127
# --- Load Data ---
130-
def load_groundtruth(self, swc_pointer):
128+
def load_groundtruth(self, swc_pointer, valid_labels):
131129
"""
132130
Initializes "self.graphs" by iterating over "paths" which corresponds
133131
to neurons in the ground truth.
@@ -149,16 +147,13 @@ def load_groundtruth(self, swc_pointer):
149147
anisotropy=self.anisotropy,
150148
label_mask=self.label_mask,
151149
use_anisotropy=False,
150+
valid_labels=valid_labels,
152151
)
153152
self.graphs = graph_builder.run(swc_pointer)
154153

155154
# Label nodes
156-
self.key_to_label_to_nodes = dict() # {id: {label: nodes}}
157155
for key in tqdm(self.graphs, desc="Labeling Graphs"):
158156
self.label_graphs(key)
159-
self.key_to_label_to_nodes[key] = gutil.init_label_to_nodes(
160-
self.graphs[key]
161-
)
162157

163158
def load_fragments(self, swc_pointer):
164159
print("\n(2) Load Fragments")
@@ -220,10 +215,12 @@ def label_graphs(self, key, batch_size=128):
220215
threads.append(executor.submit(self.get_patch_labels, key, batch))
221216

222217
# Process results
218+
n_nodes = self.graphs[key].number_of_nodes()
219+
self.graphs[key].graph["label"] = np.zeros((n_nodes), dtype=int)
223220
for thread in as_completed(threads):
224221
node_to_label = thread.result()
225222
for i, label in node_to_label.items():
226-
self.graphs[key].nodes[i].update({"label": label})
223+
self.graphs[key].graph["label"][i] = label
227224

228225
def get_patch_labels(self, key, nodes):
229226
# Get bounding box
@@ -287,11 +284,11 @@ def get_node_labels(self, key, inverse_bool=False):
287284
"""
288285
if inverse_bool:
289286
output = set()
290-
for l in self.key_to_label_to_nodes[key].keys():
287+
for l in self.graphs[key].get_labels():
291288
output = output.union(self.label_handler.inverse_mapping[l])
292289
return output
293290
else:
294-
return set(self.key_to_label_to_nodes[key].keys())
291+
return self.graphs[key].get_labels()
295292

296293
def init_zip_writer(self):
297294
"""
@@ -372,9 +369,9 @@ def adjust_metrics(self, key):
372369
"""
373370
for label in self.preexisting_merges:
374371
label = self.label_map[label] if self.label_map else label
375-
if label in self.key_to_label_to_nodes[key].keys():
372+
if label in self.graphs[key].get_labels():
376373
# Extract subgraph
377-
nodes = deepcopy(self.key_to_label_to_nodes[key][label])
374+
nodes = self.graphs[key].nodes_with_label(label)
378375
subgraph = self.graphs[key].subgraph(nodes)
379376

380377
# Adjust metrics
@@ -385,7 +382,6 @@ def adjust_metrics(self, key):
385382

386383
# Update graph
387384
self.graphs[key].remove_nodes_from(nodes)
388-
del self.key_to_label_to_nodes[key][label]
389385

390386
# -- Split Detection --
391387
def detect_splits(self):
@@ -402,7 +398,6 @@ def detect_splits(self):
402398
None
403399
404400
"""
405-
t0 = time()
406401
pbar = tqdm(total=len(self.graphs), desc="Split Detection")
407402
with ProcessPoolExecutor() as executor:
408403
# Assign processes
@@ -420,17 +415,10 @@ def detect_splits(self):
420415
self.split_percent = dict()
421416
for process in as_completed(processes):
422417
key, graph, split_percent = process.result()
423-
self.graphs[key] = gutil.delete_nodes(graph, 0)
424-
self.key_to_label_to_nodes[key] = gutil.init_label_to_nodes(
425-
self.graphs[key]
426-
)
418+
self.graphs[key] = gutil.remove_nodes(graph, 0)
427419
self.split_percent[key] = split_percent
428420
pbar.update(1)
429421

430-
# Report runtime
431-
t, unit = util.time_writer(time() - t0)
432-
print(f"Runtime: {round(t, 2)} {unit}\n")
433-
434422
def quantify_splits(self):
435423
"""
436424
Counts the number of splits, number of omit edges, and percent of omit
@@ -449,9 +437,11 @@ def quantify_splits(self):
449437
self.omit_cnts = dict()
450438
self.omit_percent = dict()
451439
for key in self.graphs:
440+
# Get counts
452441
n_pred_edges = self.graphs[key].number_of_edges()
453442
n_target_edges = self.graphs[key].graph["n_edges"]
454443

444+
# Compute stats
455445
self.split_cnt[key] = gutil.count_splits(self.graphs[key])
456446
self.omit_cnts[key] = n_target_edges - n_pred_edges
457447
self.omit_percent[key] = 1 - n_pred_edges / n_target_edges
@@ -517,7 +507,8 @@ def count_merges(self, key, kdtree):
517507
518508
"""
519509
for label in self.get_node_labels(key):
520-
if len(self.key_to_label_to_nodes[key][label]) > MIN_CNT:
510+
nodes = self.graphs[key].nodes_with_label(label)
511+
if len(nodes) > MIN_CNT:
521512
for label in self.label_handler.get_class(label):
522513
if label in self.fragment_graphs:
523514
self.is_fragment_merge(key, label, kdtree)
@@ -581,8 +572,8 @@ def find_label_intersections(self):
581572
keys = frozenset((key_1, key_2))
582573
if key_1 != key_2 and keys not in visited:
583574
visited.add(keys)
584-
labels_1 = self.get_node_labels(key_1)
585-
labels_2 = self.get_node_labels(key_2)
575+
labels_1 = set(self.graphs[key_1].get_labels())
576+
labels_2 = set(self.graphs[key_2].get_labels())
586577
for label in labels_1.intersection(labels_2):
587578
label_intersections.add((keys, label))
588579
return label_intersections
@@ -605,15 +596,14 @@ def process_merge(self, key, label, xyz, update_merged_labels=True):
605596
None
606597
607598
"""
608-
if label in self.key_to_label_to_nodes[key]:
599+
if label in self.graphs[key].get_labels():
609600
# Compute metrics
610-
nodes = list(self.key_to_label_to_nodes[key][label])
601+
nodes = self.graphs[key].nodes_with_label(label)
611602
subgraph = self.graphs[key].subgraph(nodes)
612603
self.merged_edges_cnt[key] += subgraph.number_of_edges()
613604

614605
# Update self
615606
self.graphs[key].remove_nodes_from(nodes)
616-
del self.key_to_label_to_nodes[key][label]
617607
if update_merged_labels:
618608
self.merged_labels.add((key, label, -1))
619609

src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ def run(process_id, graph):
4545
continue
4646

4747
# Visit edge
48-
label_i = graph.nodes[i]["label"]
49-
label_j = graph.nodes[j]["label"]
48+
label_i = int(graph.graph["label"][i])
49+
label_j = int(graph.graph["label"][j])
5050
if is_split(label_i, label_j):
5151
graph.remove_edge(i, j)
5252
split_cnt += 1
@@ -91,12 +91,13 @@ def check_misalignment(graph, visited_edges, nb, root):
9191
while len(queue) > 0:
9292
# Visit node
9393
j = queue.popleft()
94-
if graph.nodes[j]["label"] != 0:
95-
label_collisions.add(graph.nodes[j]["label"])
94+
label_j = int(graph.graph["label"][j])
95+
if label_j != 0:
96+
label_collisions.add(label_j)
9697
visited.add(j)
9798

9899
# Update queue
99-
if graph.nodes[j]["label"] == 0:
100+
if label_j == 0:
100101
for k in graph.neighbors(j):
101102
if k not in visited:
102103
if frozenset({j, k}) not in visited_edges or k == nb:
@@ -106,7 +107,7 @@ def check_misalignment(graph, visited_edges, nb, root):
106107
# Upd zero nodes
107108
if len(label_collisions) == 1:
108109
label = label_collisions.pop()
109-
graph = gutil.upd_labels(graph, visited, label)
110+
upd_labels(graph, visited, label)
110111

111112

112113
# -- Helpers --
@@ -128,3 +129,26 @@ def is_split(a, b):
128129
129130
"""
130131
return (a > 0 and b > 0) and (a != b)
132+
133+
134+
def upd_labels(graph, nodes, label):
    """
    Updates the label of each node in "nodes" with "label".

    Labels are stored in the graph-level attribute "label", which is an array
    indexed by node ID. The array is modified in-place, so nothing is
    returned.

    Parameters
    ----------
    graph : networkx.Graph
        Graph with a graph-level attribute called "label".
    nodes : iterable of int
        Nodes to be updated.
    label : int
        New label of each node in "nodes".

    Returns
    -------
    None

    """
    # Materialize so that sets and other iterables can be used as an index
    node_ids = list(nodes)
    if node_ids:
        graph.graph["label"][node_ids] = label

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 13 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
from scipy.spatial import distance
1717

18+
from segmentation_skeleton_metrics.skeleton_graph import SkeletonGraph
1819
from segmentation_skeleton_metrics.utils import img_util, swc_util, util
1920

2021
ANISOTROPY = np.array([0.748, 0.748, 1.0])
@@ -28,11 +29,13 @@ def __init__(
2829
label_mask=None,
2930
selected_ids=None,
3031
use_anisotropy=True,
32+
valid_labels=None,
3133
):
3234
# Instance attributes
3335
self.anisotropy = anisotropy
3436
self.label_mask = label_mask
3537
self.selected_ids = selected_ids
38+
self.valid_labels = valid_labels
3639

3740
# Reader
3841
anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0)
@@ -62,9 +65,9 @@ def _build_graphs_from_swcs(self, swc_pointer):
6265
def _process_swc_dict(self, swc_id):
    """
    Determines whether "swc_id" passes the "selected_ids" filter.

    Parameters
    ----------
    swc_id : str
        ID of an SWC file; its segment ID is extracted with
        "get_segment_id".

    Returns
    -------
    bool
        True if "self.selected_ids" is empty or None, or if the segment ID
        of "swc_id" is contained in "self.selected_ids". Otherwise, False.

    """
    # No filter was provided, so every SWC is accepted
    if not self.selected_ids:
        return True
    return get_segment_id(swc_id) in self.selected_ids
6871

6972
def to_graph(self, swc_dict):
7073
"""
@@ -87,7 +90,7 @@ def to_graph(self, swc_dict):
8790
voxels = np.array(swc_dict["voxel"], dtype=np.int32)
8891

8992
# Build graph
90-
graph = nx.Graph()
93+
graph = SkeletonGraph()
9194
id_lookup = dict()
9295
run_length = 0
9396
for i in range(len(swc_dict["id"])):
@@ -202,23 +205,7 @@ def build_labels_graph(self, connections_path):
202205

203206
# --- Main ---
204207
def get(self, label):
205-
"""
206-
Gets label of voxel in "self.label_mask".
207-
208-
Parameters
209-
----------
210-
i : int
211-
Node ID.
212-
voxel : numpy.ndarray
213-
Image coordinate of voxel to be read.
214-
215-
Returns
216-
-------
217-
int
218-
Label of voxel.
219-
220-
"""
221-
if len(self.mapping) > 0:
208+
if self.use_mapping():
222209
return self.mapping.get(label, 0)
223210
elif self.valid_labels:
224211
return 0 if label not in self.valid_labels else label
@@ -232,14 +219,14 @@ def use_mapping(self):
232219

233220

234221
# --- Update graph ---
235-
def delete_nodes(graph, target_label):
222+
def remove_nodes(graph, target_label):
236223
"""
237-
Deletes nodes in "graph" whose label is "target_label".
224+
Deletes nodes in the given graph whose label is "target_label".
238225
239226
Parameters
240227
----------
241228
graph : networkx.Graph
242-
Graph to be searched and edited.
229+
Graph with a graph-level attribute called "label".
243230
target_label : int
244231
Label to be deleted from graph.
245232
@@ -249,36 +236,8 @@ def delete_nodes(graph, target_label):
249236
Updated graph.
250237
251238
"""
252-
delete_nodes = []
253-
for i in graph.nodes:
254-
label = graph.nodes[i]["label"]
255-
if label == target_label:
256-
delete_nodes.append(i)
257-
graph.remove_nodes_from(delete_nodes)
258-
return graph
259-
260-
261-
def upd_labels(graph, nodes, label):
262-
"""
263-
Updates the label of each node in "nodes" with "label".
264-
265-
Parameters
266-
----------
267-
graph : networkx.Graph
268-
Graph to be updated.
269-
nodes : list
270-
List of nodes to be updated.
271-
label : int
272-
New label of each node in "nodes".
273-
274-
Returns
275-
-------
276-
networkx.Graph
277-
Updated graph.
278-
279-
"""
280-
for i in nodes:
281-
graph.nodes[i].update({"label": label})
239+
nodes = np.where(graph.graph["label"] == target_label)[0]
240+
graph.remove_nodes_from(nodes)
282241
return graph
283242

284243

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def load_from_local_zips(self, zip_dir):
181181
with ProcessPoolExecutor() as executor:
182182
# Assign threads
183183
processes = list()
184-
for f in zip_names:
184+
for f in zip_names[0:250]: # TEMP
185185
zip_path = os.path.join(zip_dir, f)
186186
processes.append(
187187
executor.submit(self.load_from_local_zip, zip_path)

0 commit comments

Comments
 (0)