Skip to content

Commit e40269e

Browse files
author
anna-grim
committed
bug: fixed memory issue
1 parent 126e7e9 commit e40269e

File tree

4 files changed

+238
-259
lines changed

4 files changed

+238
-259
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 60 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
88
"""
99

10-
10+
from collections import deque
1111
from concurrent.futures import (
1212
as_completed,
1313
ProcessPoolExecutor,
@@ -31,7 +31,7 @@
3131
util
3232
)
3333

34-
MERGE_DIST_THRESHOLD = 100
34+
MERGE_DIST_THRESHOLD = 200
3535
MIN_CNT = 40
3636

3737

@@ -112,15 +112,19 @@ def __init__(
112112
self.output_dir = output_dir
113113
self.preexisting_merges = preexisting_merges
114114

115-
# Load Data
116-
print("\n(1) Load Data")
115+
# Load ground truth
116+
print("\n(1) Load Ground Truth")
117117
assert type(valid_labels) is set if valid_labels else True
118-
self.label_mask = pred_labels
119118
self.valid_labels = valid_labels
120119
self.init_label_map(connections_path)
121120
self.init_graphs(gt_pointer)
121+
122+
print("\n(2) Load Prediction")
123+
self.label_mask = pred_labels
122124
if fragments_pointer:
123125
self.load_fragments(fragments_pointer)
126+
else:
127+
self.fragment_graphs = None
124128

125129
# Initialize writer
126130
self.save_projections = save_projections
@@ -160,7 +164,7 @@ def init_graphs(self, paths):
160164
161165
Parameters
162166
----------
163-
paths : list[str]
167+
paths : List[str]
164168
List of paths to swc files which correspond to neurons in the
165169
ground truth.
166170
@@ -170,18 +174,35 @@ def init_graphs(self, paths):
170174
171175
"""
172176
# Build graphs
173-
self.graphs = swc_util.Reader().load(paths)
174-
self.fragment_graphs = None
177+
swc_dicts = swc_util.Reader().load(paths)
178+
self.graphs = self.build_graphs(swc_dicts)
175179

176180
# Label nodes
177181
self.key_to_label_to_nodes = dict() # {id: {label: nodes}}
178182
for key in tqdm(self.graphs, desc="Labeling Graphs"):
179-
self.set_node_labels(key)
183+
self.label_graphs(key)
180184
self.key_to_label_to_nodes[key] = gutil.init_label_to_nodes(
181185
self.graphs[key]
182186
)
183187

184-
def set_node_labels(self, key, batch_size=128):
188+
def build_graphs(self, swc_dicts):
189+
graphs = dict()
190+
with ProcessPoolExecutor() as executor:
191+
# Assign processes
192+
processes = list()
193+
for swc_dict in swc_dicts:
194+
processes.append(
195+
executor.submit(gutil.to_graph, swc_dict)
196+
)
197+
198+
# Store results
199+
pbar = tqdm(total=len(processes), desc="Build Graphs")
200+
for process in as_completed(processes):
201+
graphs.update(process.result())
202+
pbar.update(1)
203+
return graphs
204+
205+
def label_graphs(self, key, batch_size=128):
185206
"""
186207
Iterates over nodes in "graph" and stores the corresponding label from
187208
predicted segmentation mask (i.e. "self.label_mask") as a node-level
@@ -238,7 +259,7 @@ def get_patch_labels(self, key, nodes):
238259
# Get bounding box
239260
bbox = {"min": [np.inf, np.inf, np.inf], "max": [0, 0, 0]}
240261
for i in nodes:
241-
voxel = deepcopy(self.graphs[key].nodes[i]["voxel"])
262+
voxel = deepcopy(self.graphs[key].graph["voxel"][i])
242263
for idx in range(3):
243264
if voxel[idx] < bbox["min"][idx]:
244265
bbox["min"][idx] = voxel[idx]
@@ -359,20 +380,20 @@ def load_fragments(self, fragments_pointer):
359380
Dictionary that maps an swc id to the fragment graph.
360381
361382
"""
362-
# Read fragments
363-
reader = swc_util.Reader(anisotropy=self.anisotropy, min_size=40)
364-
fragment_graphs = reader.load(fragments_pointer)
365-
self.fragment_ids = set(fragment_graphs.keys())
366-
367-
# Filter fragments
368-
self.fragment_graphs = dict()
369-
for label in self.get_all_node_labels():
370-
if label in fragment_graphs:
371-
self.fragment_graphs[label] = fragment_graphs[label]
372-
else:
373-
self.fragment_graphs[label] = nx.Graph(
374-
filename=f"{label}.swc", run_length=0, n_edges=1
375-
)
383+
# Read SWC files
384+
reader = swc_util.Reader(anisotropy=self.anisotropy)
385+
swc_dicts = deque(reader.load(fragments_pointer))
386+
387+
# Filter SWC files
388+
filtered_swc_dicts = list()
389+
labels = self.get_all_node_labels()
390+
while len(swc_dicts) > 0:
391+
swc_dict = swc_dicts.popleft()
392+
swc_id = int(swc_dict["swc_id"])
393+
if swc_id in labels:
394+
swc_dict["swc_id"] = swc_id
395+
filtered_swc_dicts.append(swc_dict)
396+
self.fragment_graphs = self.build_graphs(filtered_swc_dicts)
376397
print("# Fragments:", len(self.fragment_graphs))
377398

378399
def init_zip_writer(self):
@@ -416,7 +437,7 @@ def run(self):
416437
...
417438
418439
"""
419-
print("\n(2) Evaluation")
440+
print("\n(3) Evaluation")
420441

421442
# Split evaluation
422443
self.detect_splits()
@@ -564,16 +585,14 @@ def detect_merges(self):
564585
pbar = tqdm(total=len(self.graphs), desc="Count Merges:")
565586
for key, graph in self.graphs.items():
566587
if graph.number_of_nodes() > 0:
567-
kdtree = KDTree(gutil.to_array(graph))
588+
kdtree = KDTree(graph.graph["voxel"])
568589
self.count_merges(key, kdtree)
569590
pbar.update(1)
570591

571592
# Process merges
572-
pbar = tqdm(total=len(self.graphs), desc="Compute Percent Merged:")
573593
for (key_1, key_2), label in self.find_label_intersections():
574594
self.process_merge(key_1, label, -1)
575595
self.process_merge(key_2, label, -1)
576-
pbar.update(1)
577596

578597
for key, label, xyz in self.merged_labels:
579598
self.process_merge(key, label, xyz, update_merged_labels=False)
@@ -610,8 +629,8 @@ def count_merges(self, key, kdtree):
610629

611630
# Check if fragment is a merge mistake
612631
for label in labels:
613-
rl = self.fragment_graphs[label].graph["run_length"]
614-
self.is_fragment_merge(key, label, kdtree)
632+
if label in self.fragment_graphs:
633+
self.is_fragment_merge(key, label, kdtree)
615634

616635
def is_fragment_merge(self, key, label, kdtree):
617636
"""
@@ -634,7 +653,7 @@ def is_fragment_merge(self, key, label, kdtree):
634653
None
635654
636655
"""
637-
for voxel in gutil.to_array(self.fragment_graphs[label])[::2]:
656+
for voxel in self.fragment_graphs[label].graph["voxel"]:
638657
if kdtree.query(voxel)[0] > MERGE_DIST_THRESHOLD:
639658
# Check whether to get inverse of label
640659
if self.inverse_label_map:
@@ -643,10 +662,10 @@ def is_fragment_merge(self, key, label, kdtree):
643662
equivalent_label = label
644663

645664
# Record merge mistake
646-
xyz = img_util.to_physical(voxel)
665+
xyz = img_util.to_physical(voxel, self.anisotropy)
647666
self.merge_cnt[key] += 1
648667
self.merged_labels.add((key, equivalent_label, tuple(xyz)))
649-
if self.save_projections:
668+
if self.save_projections and label in self.fragment_graphs:
650669
swc_util.to_zipped_swc(
651670
self.zip_writer[key], self.fragment_graphs[label]
652671
)
@@ -768,13 +787,13 @@ def get_merged_label(self, label):
768787
Returns:
769788
-------
770789
str or list
771-
The first matching label found in "self.fragment_ids" or the
772-
original associated labels from "inverse_label_map" if no matches
773-
are found.
790+
The first matching label found in "self.fragment_graphs.keys()" or
791+
the original associated labels from "inverse_label_map" if no
792+
matches are found.
774793
775794
"""
776795
for l in self.inverse_label_map[label]:
777-
if l in self.fragment_ids:
796+
if l in self.fragment_graphs.keys():
778797
return l
779798
return self.inverse_label_map[label]
780799

@@ -988,8 +1007,8 @@ def list_metrics(self):
9881007

9891008
# -- util --
9901009
def dist(self, key, i, j):
991-
xyz_i = self.graphs[key].nodes[i]["voxel"]
992-
xyz_j = self.graphs[key].nodes[j]["voxel"]
1010+
xyz_i = self.graphs[key].graph["voxel"][i]
1011+
xyz_j = self.graphs[key].graph["voxel"][j]
9931012
return distance.euclidean(xyz_i, xyz_j)
9941013

9951014
def init_counter(self):
@@ -1010,7 +1029,7 @@ def init_counter(self):
10101029
return {key: 0 for key in self.graphs}
10111030

10121031
def to_local_voxels(self, key, i, offset):
1013-
voxel = np.array(self.graphs[key].nodes[i]["voxel"])
1032+
voxel = np.array(self.graphs[key].graph["voxel"][i])
10141033
offset = np.array(offset)
10151034
return tuple(voxel - offset)
10161035

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 52 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,58 @@
1212

1313
import networkx as nx
1414
import numpy as np
15-
from scipy.spatial.distance import euclidean as get_dist
15+
from scipy.spatial import distance
1616

1717
from segmentation_skeleton_metrics.utils import img_util
1818

19-
ANISOTROPY = (0.748, 0.748, 1.0)
19+
ANISOTROPY = np.array([0.748, 0.748, 1.0])
20+
21+
22+
def to_graph(swc_dict):
23+
"""
24+
Builds a graph from a dictionary that contains the contents of an SWC
25+
file.
26+
27+
Parameters
28+
----------
29+
swc_dict : dict
30+
...
31+
32+
Returns
33+
-------
34+
networkx.Graph
35+
Graph built from an SWC file.
36+
37+
"""
38+
# Initializations
39+
old_to_new = dict()
40+
run_length = 0
41+
voxels = np.zeros((len(swc_dict["id"]), 3), dtype=np.int32)
42+
43+
# Build graph
44+
graph = nx.Graph()
45+
for i in range(len(swc_dict["id"])):
46+
# Get node id
47+
old_id = swc_dict["id"][i]
48+
old_to_new[old_id] = i
49+
50+
# Update graph
51+
voxels[i] = swc_dict["voxel"][i]
52+
if swc_dict["pid"][i] != -1:
53+
# Add edge
54+
parent = old_to_new[swc_dict["pid"][i]]
55+
graph.add_edge(i, parent)
56+
57+
# Update run length
58+
xyz_i = voxels[i] * ANISOTROPY
59+
xyz_p = voxels[parent] * ANISOTROPY
60+
run_length += distance.euclidean(xyz_i, xyz_p)
61+
62+
# Set graph-level attributes
63+
graph.graph["n_edges"] = graph.number_of_edges()
64+
graph.graph["run_length"] = run_length
65+
graph.graph["voxel"] = voxels
66+
return {swc_dict["swc_id"]: graph}
2067

2168

2269
# --- Update graph ---
@@ -153,32 +200,13 @@ def compute_run_length(graph):
153200
"""
154201
path_length = 0
155202
for i, j in nx.dfs_edges(graph):
156-
xyz_1 = img_util.to_physical(graph.nodes[i]["voxel"], ANISOTROPY)
157-
xyz_2 = img_util.to_physical(graph.nodes[j]["voxel"], ANISOTROPY)
158-
path_length += get_dist(xyz_1, xyz_2)
203+
xyz_i = img_util.to_physical(graph.graph["voxel"][i], ANISOTROPY)
204+
xyz_j = img_util.to_physical(graph.graph["voxel"][j], ANISOTROPY)
205+
path_length += distance.euclidean(xyz_i, xyz_j)
159206
return path_length
160207

161208

162209
# -- miscellaneous --
163-
def to_array(graph):
164-
"""
165-
Converts node coordinates from a graph into a NumPy array.
166-
167-
Parameters
168-
----------
169-
graph : networkx.Graph
170-
Graph that contains nodes with "voxel" attributes.
171-
172-
Returns
173-
-------
174-
numpy.ndarray
175-
Array where each row represents the 3D coordinates of a node.
176-
177-
"""
178-
voxels = nx.get_node_attributes(graph, "voxel")
179-
return np.array([voxels[i] for i in graph.nodes])
180-
181-
182210
def sample_leaf(graph):
183211
"""
184212
Samples leaf node from "graph".

0 commit comments

Comments
 (0)