Skip to content

Commit 1d44b54

Browse files
authored
major refactor
1 parent 21772b1 commit 1d44b54

File tree

3 files changed

+817
-147
lines changed

3 files changed

+817
-147
lines changed

src/segmentation_skeleton_metrics/data_handling/graph_loading.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
import networkx as nx
1818
import numpy as np
1919

20-
from segmentation_skeleton_metrics.data_handling.skeleton_graph import SkeletonGraph
21-
from segmentation_skeleton_metrics.utils import swc_util, util
20+
from segmentation_skeleton_metrics.data_handling import swc_loading
21+
from segmentation_skeleton_metrics.data_handling.skeleton_graph import (
22+
FragmentGraph, LabeledGraph
23+
)
24+
from segmentation_skeleton_metrics.utils import util
2225

2326

2427
class DataLoader:
@@ -93,6 +96,7 @@ def load_fragments(self, swc_pointer, gt_graphs):
9396
graph_loader = GraphLoader(
9497
anisotropy=self.anisotropy,
9598
is_groundtruth=False,
99+
label_handler=self.label_handler,
96100
selected_ids=selected_ids,
97101
use_anisotropy=self.use_anisotropy,
98102
)
@@ -112,10 +116,10 @@ def get_all_node_labels(self, graphs):
112116
labels : Set[int]
113117
Unique node labels across all graphs.
114118
"""
115-
labels = set()
119+
node_labels = set()
116120
for graph in graphs.values():
117-
labels |= self.label_handler.get_node_labels(graph)
118-
return labels
121+
node_labels |= self.label_handler.get_node_labels(graph)
122+
return node_labels
119123

120124

121125
class GraphLoader:
@@ -146,7 +150,7 @@ def __init__(
146150
label_mask : ImageReader, optional
147151
Predicted segmentation mask.
148152
selected_ids : Set[int], optional
149-
Only SWC files with an swc_id contained in this set are read.
153+
Only SWC files with a name contained in this set are read.
150154
Default is None.
151155
use_anisotropy : bool, optional
152156
Indication of whether coordinates in SWC files should be converted
@@ -161,7 +165,7 @@ def __init__(
161165

162166
# Reader
163167
anisotropy = anisotropy if use_anisotropy else (1.0, 1.0, 1.0)
164-
self.swc_reader = swc_util.Reader(
168+
self.swc_reader = swc_loading.Reader(
165169
anisotropy, selected_ids=selected_ids
166170
)
167171

@@ -181,11 +185,11 @@ def run(self, swc_pointer):
181185
Dictionary where the keys are unique identifiers (i.e. filenames
182186
of SWC files) and values are the corresponding SkeletonGraph.
183187
"""
184-
graph_dict = self._build_graphs_from_swcs(swc_pointer)
188+
graphs = self._build_graphs_from_swcs(swc_pointer)
185189
if self.label_mask:
186-
for key in graph_dict:
187-
self._label_graph(graph_dict[key])
188-
return graph_dict
190+
for name in graphs:
191+
self._label_graph(graphs[name])
192+
return graphs
189193

190194
# --- Build Graphs ---
191195
def _build_graphs_from_swcs(self, swc_pointer):
@@ -246,25 +250,40 @@ def to_graph(self, swc_dict):
246250
Graph built from an SWC file.
247251
"""
248252
# Initialize graph
249-
graph = SkeletonGraph(
250-
anisotropy=self.anisotropy, is_groundtruth=self.is_groundtruth
251-
)
252-
graph.init_voxels(swc_dict["voxel"])
253-
graph.set_filename(swc_dict["swc_id"] + ".swc")
254-
graph.set_nodes(len(swc_dict["id"]))
253+
graph = self._init_graph(swc_dict)
255254

256-
# Build graph
255+
# Build graph structure
257256
id_lookup = dict()
258257
for i, id_i in enumerate(swc_dict["id"]):
259258
id_lookup[id_i] = i
260259
if swc_dict["pid"][i] != -1:
261260
parent = id_lookup[swc_dict["pid"][i]]
262261
graph.add_edge(i, parent)
263262
graph.run_length += graph.dist(i, parent)
263+
graph.prune_branches()
264+
return {graph.name: graph}
265+
266+
def _init_graph(self, swc_dict):
267+
# Instantiate graph
268+
if self.is_groundtruth:
269+
graph = LabeledGraph(
270+
anisotropy=self.anisotropy, name=swc_dict["swc_name"]
271+
)
272+
else:
273+
segment_id = util.get_segment_id(swc_dict["swc_name"])
274+
label = self.label_handler.get(segment_id)
275+
graph = FragmentGraph(
276+
anisotropy=self.anisotropy,
277+
name=swc_dict["swc_name"],
278+
label=label,
279+
segment_id=segment_id
280+
)
264281

265-
# Set graph-level attributes
266-
graph.graph["n_initial_edges"] = graph.number_of_edges()
267-
return {swc_dict["swc_id"]: graph}
282+
# Set class attributes
283+
graph.init_voxels(swc_dict["voxel"])
284+
graph.set_filename(swc_dict["swc_name"] + ".swc")
285+
graph.set_nodes(len(swc_dict["id"]))
286+
return graph
268287

269288
# --- Label Graphs ---
270289
def _label_graph(self, graph):
@@ -311,11 +330,11 @@ def _label_graph(self, graph):
311330
)
312331

313332
# Store results
314-
graph.init_labels()
333+
graph.init_node_labels()
315334
for thread in as_completed(threads):
316335
node_to_label = thread.result()
317336
for i, label in node_to_label.items():
318-
graph.labels[i] = label
337+
graph.node_labels[i] = label
319338

320339
def get_patch_labels(self, graph, nodes):
321340
"""
@@ -366,6 +385,9 @@ def to_local_voxels(self, graph, i, offset):
366385
offset = np.array(offset)
367386
return tuple(voxel - offset)
368387

388+
def fix_label_misalignments(self, graph):
389+
pass
390+
369391

370392
class LabelHandler:
371393
"""
@@ -527,7 +549,7 @@ def get_node_labels(self, graph):
527549
labels : Set[int]
528550
Labels corresponding to nodes in the graph identified by "key".
529551
"""
530-
labels = graph.get_labels()
552+
labels = graph.get_node_labels()
531553
if self.use_mapping():
532554
labels = set().union(*(self.inverse_mapping[l] for l in labels))
533555
return labels

0 commit comments

Comments
 (0)