Skip to content

Commit a8f1869

Browse files
author
anna-grim
committed
feat: graph builder class
1 parent 584e5af commit a8f1869

File tree

4 files changed

+130
-205
lines changed

4 files changed

+130
-205
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 27 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import numpy as np
2424
import os
2525

26-
from segmentation_skeleton_metrics import graph_segmentation_alignment as gsa
26+
from segmentation_skeleton_metrics import split_detection
2727
from segmentation_skeleton_metrics.utils import (
2828
graph_util as gutil,
2929
img_util,
@@ -112,26 +112,20 @@ def __init__(
112112
self.output_dir = output_dir
113113
self.preexisting_merges = preexisting_merges
114114

115-
# Load ground truth
116-
print("\n(1) Load Ground Truth")
117-
assert type(valid_labels) is set if valid_labels else True
115+
# Load data
116+
assert isinstance(valid_labels, set) if valid_labels else True
118117
self.label_mask = pred_labels
119118
self.valid_labels = valid_labels
120119
self.init_label_map(connections_path)
121-
self.init_graphs(gt_pointer)
122-
123-
print("\n(2) Load Prediction")
124-
if fragments_pointer:
125-
self.load_fragments(fragments_pointer)
126-
else:
127-
self.fragment_graphs = None
120+
self.load_groundtruth(gt_pointer)
121+
self.load_fragments(fragments_pointer)
128122

129123
# Initialize writer
130124
self.save_projections = save_projections
131125
if self.save_projections:
132126
self.init_zip_writer()
133127

134-
# -- Initialize and Label Graphs --
128+
# --- Load Data ---
135129
def init_label_map(self, path):
136130
"""
137131
Initializes a dictionary that maps a label to its equivalent label in
@@ -157,7 +151,7 @@ def init_label_map(self, path):
157151
self.label_map = None
158152
self.inverse_label_map = None
159153

160-
def init_graphs(self, paths):
154+
def load_groundtruth(self, swc_pointer):
161155
"""
162156
Initializes "self.graphs" by iterating over "paths" which corresponds
163157
to neurons in the ground truth.
@@ -174,8 +168,13 @@ def init_graphs(self, paths):
174168
175169
"""
176170
# Build graphs
177-
swc_dicts = swc_util.Reader().load(paths)
178-
self.graphs = self.build_graphs(swc_dicts)
171+
print("\n(1) Load Ground Truth")
172+
graph_builder = gutil.GraphBuilder(
173+
anisotropy=self.anisotropy,
174+
label_mask=self.label_mask,
175+
use_anisotropy=False,
176+
)
177+
self.graphs = graph_builder.run(swc_pointer)
179178

180179
# Label nodes
181180
self.key_to_label_to_nodes = dict() # {id: {label: nodes}}
@@ -185,23 +184,18 @@ def init_graphs(self, paths):
185184
self.graphs[key]
186185
)
187186

188-
def build_graphs(self, swc_dicts):
189-
graphs = dict()
190-
with ProcessPoolExecutor() as executor:
191-
# Assign processes
192-
processes = list()
193-
for swc_dict in swc_dicts:
194-
processes.append(
195-
executor.submit(gutil.to_graph, swc_dict)
196-
)
187+
def load_fragments(self, swc_pointer):
188+
print("\n(2) Load Fragments")
189+
if swc_pointer:
190+
graph_builder = gutil.GraphBuilder(
191+
anisotropy=self.anisotropy,
192+
selected_ids=self.get_all_node_labels(),
193+
use_anisotropy=True,
194+
)
195+
self.fragment_graphs = graph_builder.run(swc_pointer)
196+
else:
197+
self.fragment_graphs = None
197198

198-
# Store results
199-
pbar = tqdm(total=len(processes), desc="Build Graphs")
200-
for process in as_completed(processes):
201-
graphs.update(process.result())
202-
pbar.update(1)
203-
return graphs
204-
205199
def label_graphs(self, key, batch_size=128):
206200
"""
207201
Iterates over nodes in "graph" and stores the corresponding label from
@@ -259,7 +253,7 @@ def get_patch_labels(self, key, nodes):
259253
# Get bounding box
260254
bbox = {"min": [np.inf, np.inf, np.inf], "max": [0, 0, 0]}
261255
for i in nodes:
262-
voxel = deepcopy(self.graphs[key].graph["voxel"][i])
256+
voxel = self.graphs[key].graph["voxel"][i]
263257
for idx in range(3):
264258
if voxel[idx] < bbox["min"][idx]:
265259
bbox["min"][idx] = voxel[idx]
@@ -363,39 +357,6 @@ def get_node_labels(self, key, inverse_bool=False):
363357
else:
364358
return set(self.key_to_label_to_nodes[key].keys())
365359

366-
# -- Load Fragments --
367-
def load_fragments(self, fragments_pointer):
368-
"""
369-
Loads and filters swc files from a local zip. These swc files are
370-
assumed to be fragments from a predicted segmentation.
371-
372-
Parameters
373-
----------
374-
zip_path : str
375-
Path to the local zip file containing the fragments
376-
377-
Returns
378-
-------
379-
dict
380-
Dictionary that maps an swc id to the fragment graph.
381-
382-
"""
383-
# Read SWC files
384-
reader = swc_util.Reader(anisotropy=self.anisotropy)
385-
swc_dicts = deque(reader.load(fragments_pointer))
386-
387-
# Filter SWC files
388-
filtered_swc_dicts = list()
389-
labels = self.get_all_node_labels()
390-
while len(swc_dicts) > 0:
391-
swc_dict = swc_dicts.popleft()
392-
swc_id = int(swc_dict["swc_id"])
393-
if swc_id in labels:
394-
swc_dict["swc_id"] = swc_id
395-
filtered_swc_dicts.append(swc_dict)
396-
self.fragment_graphs = self.build_graphs(filtered_swc_dicts)
397-
print("# Fragments:", len(self.fragment_graphs))
398-
399360
def init_zip_writer(self):
400361
"""
401362
Initializes "self.zip_writer" attribute by setting up a directory for
@@ -513,7 +474,7 @@ def detect_splits(self):
513474
for key, graph in self.graphs.items():
514475
processes.append(
515476
executor.submit(
516-
gsa.correct_graph_misalignments,
477+
split_detection.run,
517478
key,
518479
graph,
519480
)

src/segmentation_skeleton_metrics/graph_segmentation_alignment.py renamed to src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from segmentation_skeleton_metrics.utils import graph_util as gutil
1717

1818

19-
def correct_graph_misalignments(process_id, graph):
19+
def run(process_id, graph):
2020
"""
2121
Adjusts misalignments between ground truth graph and segmentation mask.
2222

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 101 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -9,62 +9,111 @@
99
from collections import defaultdict
1010
from concurrent.futures import ProcessPoolExecutor, as_completed
1111
from random import sample
12+
from tqdm import tqdm
1213

1314
import networkx as nx
1415
import numpy as np
1516
from scipy.spatial import distance
1617

17-
from segmentation_skeleton_metrics.utils import img_util
18+
from segmentation_skeleton_metrics.utils import img_util, swc_util
1819

1920
ANISOTROPY = np.array([0.748, 0.748, 1.0])
2021

2122

22-
def to_graph(swc_dict):
23-
"""
24-
Builds a graph from a dictionary that contains the contents of an SWC
25-
file.
26-
27-
Parameters
28-
----------
29-
swc_dict : dict
30-
...
31-
32-
Returns
33-
-------
34-
networkx.Graph
35-
Graph built from an SWC file.
36-
37-
"""
38-
# Initializations
39-
old_to_new = dict()
40-
run_length = 0
41-
voxels = np.zeros((len(swc_dict["id"]), 3), dtype=np.int32)
42-
43-
# Build graph
44-
graph = nx.Graph()
45-
for i in range(len(swc_dict["id"])):
46-
# Get node id
47-
old_id = swc_dict["id"][i]
48-
old_to_new[old_id] = i
49-
50-
# Update graph
51-
voxels[i] = swc_dict["voxel"][i]
52-
if swc_dict["pid"][i] != -1:
53-
# Add edge
54-
parent = old_to_new[swc_dict["pid"][i]]
55-
graph.add_edge(i, parent)
56-
57-
# Update run length
58-
xyz_i = voxels[i] * ANISOTROPY
59-
xyz_p = voxels[parent] * ANISOTROPY
60-
run_length += distance.euclidean(xyz_i, xyz_p)
61-
62-
# Set graph-level attributes
63-
graph.graph["n_edges"] = graph.number_of_edges()
64-
graph.graph["run_length"] = run_length
65-
graph.graph["voxel"] = voxels
66-
return {swc_dict["swc_id"]: graph}
67-
23+
class GraphBuilder:
24+
25+
def __init__(
26+
self,
27+
anisotropy=(1.0, 1.0, 1.0),
28+
label_mask=None,
29+
selected_ids=None,
30+
use_anisotropy=True,
31+
):
32+
# Instance attributes
33+
self.anisotropy = anisotropy
34+
self.label_mask = label_mask
35+
self.selected_ids = selected_ids
36+
37+
# Reader
38+
anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0)
39+
self.swc_reader = swc_util.Reader(anisotropy)
40+
41+
def run(self, swc_pointer):
42+
self._build_graphs_from_swcs(swc_pointer)
43+
self._label_graphs_with_segmentation()
44+
return self.graphs
45+
46+
# --- Build Graphs ---
47+
def _build_graphs_from_swcs(self, swc_pointer):
48+
with ProcessPoolExecutor() as executor:
49+
# Assign processes
50+
processes = list()
51+
for swc_dict in self.swc_reader.load(swc_pointer):
52+
if self._process_swc_dict(swc_dict["swc_id"]):
53+
processes.append(executor.submit(self.to_graph, swc_dict))
54+
55+
# Store results
56+
self.graphs = dict()
57+
pbar = tqdm(total=len(processes), desc="Build Graphs")
58+
for process in as_completed(processes):
59+
self.graphs.update(process.result())
60+
pbar.update(1)
61+
62+
def _process_swc_dict(self, swc_id):
63+
if self.selected_ids:
64+
segment_id = get_segment_id(swc_id)
65+
if segment_id not in self.selected_ids:
66+
return False
67+
return True
68+
69+
def to_graph(self, swc_dict):
70+
"""
71+
Builds a graph from a dictionary that contains the contents of an SWC
72+
file.
73+
74+
Parameters
75+
----------
76+
swc_dict : dict
77+
...
78+
79+
Returns
80+
-------
81+
networkx.Graph
82+
Graph built from an SWC file.
83+
84+
"""
85+
# Extract data from swc_dict
86+
ids = swc_dict["id"]
87+
voxels = np.array(swc_dict["voxel"], dtype=np.int32)
88+
89+
# Build graph
90+
graph = nx.Graph()
91+
id_lookup = dict()
92+
run_length = 0
93+
for i in range(len(swc_dict["id"])):
94+
id_lookup[ids[i]] = i
95+
if swc_dict["pid"][i] != -1:
96+
# Add edge
97+
parent = id_lookup[swc_dict["pid"][i]]
98+
graph.add_edge(i, parent)
99+
100+
# Update run length
101+
xyz_i = voxels[i] * self.anisotropy
102+
xyz_p = voxels[parent] * self.anisotropy
103+
run_length += distance.euclidean(xyz_i, xyz_p)
104+
105+
# Set graph-level attributes
106+
graph.graph["n_edges"] = graph.number_of_edges()
107+
graph.graph["run_length"] = run_length
108+
graph.graph["voxel"] = voxels
109+
return {swc_dict["swc_id"]: graph}
110+
111+
# --- Label Graphs ---
112+
def _label_graphs_with_segmentation(self):
113+
pass
114+
115+
def _label_graph(self, key):
116+
pass
68117

69118
# --- Update graph ---
70119
def delete_nodes(graph, target_label):
@@ -206,7 +255,11 @@ def compute_run_length(graph):
206255
return path_length
207256

208257

209-
# -- miscellaneous --
258+
# -- Miscellaneous --
259+
def get_segment_id(swc_id):
260+
return int(swc_id.split(".")[0])
261+
262+
210263
def sample_leaf(graph):
211264
"""
212265
Samples leaf node from "graph".

0 commit comments

Comments
 (0)