1717import networkx as nx
1818import numpy as np
1919
20- from segmentation_skeleton_metrics .data_handling .skeleton_graph import SkeletonGraph
21- from segmentation_skeleton_metrics .utils import swc_util , util
20+ from segmentation_skeleton_metrics .data_handling import swc_loading
21+ from segmentation_skeleton_metrics .data_handling .skeleton_graph import (
22+ FragmentGraph , LabeledGraph
23+ )
24+ from segmentation_skeleton_metrics .utils import util
2225
2326
2427class DataLoader :
@@ -93,6 +96,7 @@ def load_fragments(self, swc_pointer, gt_graphs):
9396 graph_loader = GraphLoader (
9497 anisotropy = self .anisotropy ,
9598 is_groundtruth = False ,
99+ label_handler = self .label_handler ,
96100 selected_ids = selected_ids ,
97101 use_anisotropy = self .use_anisotropy ,
98102 )
@@ -112,10 +116,10 @@ def get_all_node_labels(self, graphs):
112116 labels : Set[int]
113117 Unique node labels across all graphs.
114118 """
115- labels = set ()
119+ node_labels = set ()
116120 for graph in graphs .values ():
117- labels |= self .label_handler .get_node_labels (graph )
118- return labels
121+ node_labels |= self .label_handler .get_node_labels (graph )
122+ return node_labels
119123
120124
121125class GraphLoader :
@@ -146,7 +150,7 @@ def __init__(
146150 label_mask : ImageReader, optional
147151 Predicted segmentation mask.
148152 selected_ids : Set[int], optional
149- Only SWC files with an swc_id contained in this set are read.
153+ Only SWC files with a name contained in this set are read.
150154 Default is None.
151155 use_anisotropy : bool, optional
152156 Indication of whether coordinates in SWC files should be converted
@@ -161,7 +165,7 @@ def __init__(
161165
162166 # Reader
163167 anisotropy = anisotropy if use_anisotropy else (1.0 , 1.0 , 1.0 )
164- self .swc_reader = swc_util .Reader (
168+ self .swc_reader = swc_loading .Reader (
165169 anisotropy , selected_ids = selected_ids
166170 )
167171
@@ -181,11 +185,11 @@ def run(self, swc_pointer):
181185 Dictionary where the keys are unique identifiers (i.e. filenames
182186 of SWC files) and values are the corresponding SkeletonGraph.
183187 """
184- graph_dict = self ._build_graphs_from_swcs (swc_pointer )
188+ graphs = self ._build_graphs_from_swcs (swc_pointer )
185189 if self .label_mask :
186- for key in graph_dict :
187- self ._label_graph (graph_dict [ key ])
188- return graph_dict
190+ for name in graphs :
191+ self ._label_graph (graphs [ name ])
192+ return graphs
189193
190194 # --- Build Graphs ---
191195 def _build_graphs_from_swcs (self , swc_pointer ):
@@ -246,25 +250,40 @@ def to_graph(self, swc_dict):
246250 Graph built from an SWC file.
247251 """
248252 # Initialize graph
249- graph = SkeletonGraph (
250- anisotropy = self .anisotropy , is_groundtruth = self .is_groundtruth
251- )
252- graph .init_voxels (swc_dict ["voxel" ])
253- graph .set_filename (swc_dict ["swc_id" ] + ".swc" )
254- graph .set_nodes (len (swc_dict ["id" ]))
253+ graph = self ._init_graph (swc_dict )
255254
256- # Build graph
255+ # Build graph structure
257256 id_lookup = dict ()
258257 for i , id_i in enumerate (swc_dict ["id" ]):
259258 id_lookup [id_i ] = i
260259 if swc_dict ["pid" ][i ] != - 1 :
261260 parent = id_lookup [swc_dict ["pid" ][i ]]
262261 graph .add_edge (i , parent )
263262 graph .run_length += graph .dist (i , parent )
263+ graph .prune_branches ()
264+ return {graph .name : graph }
265+
266+ def _init_graph (self , swc_dict ):
267+ # Instantiate graph
268+ if self .is_groundtruth :
269+ graph = LabeledGraph (
270+ anisotropy = self .anisotropy , name = swc_dict ["swc_name" ]
271+ )
272+ else :
273+ segment_id = util .get_segment_id (swc_dict ["swc_name" ])
274+ label = self .label_handler .get (segment_id )
275+ graph = FragmentGraph (
276+ anisotropy = self .anisotropy ,
277+ name = swc_dict ["swc_name" ],
278+ label = label ,
279+ segment_id = segment_id
280+ )
264281
265- # Set graph-level attributes
266- graph .graph ["n_initial_edges" ] = graph .number_of_edges ()
267- return {swc_dict ["swc_id" ]: graph }
282+ # Set class attributes
283+ graph .init_voxels (swc_dict ["voxel" ])
284+ graph .set_filename (swc_dict ["swc_name" ] + ".swc" )
285+ graph .set_nodes (len (swc_dict ["id" ]))
286+ return graph
268287
269288 # --- Label Graphs ---
270289 def _label_graph (self , graph ):
@@ -311,11 +330,11 @@ def _label_graph(self, graph):
311330 )
312331
313332 # Store results
314- graph .init_labels ()
333+ graph .init_node_labels ()
315334 for thread in as_completed (threads ):
316335 node_to_label = thread .result ()
317336 for i , label in node_to_label .items ():
318- graph .labels [i ] = label
337+ graph .node_labels [i ] = label
319338
320339 def get_patch_labels (self , graph , nodes ):
321340 """
@@ -366,6 +385,9 @@ def to_local_voxels(self, graph, i, offset):
366385 offset = np .array (offset )
367386 return tuple (voxel - offset )
368387
388+ def fix_label_misalignments (self , graph ):
389+ pass
390+
369391
370392class LabelHandler :
371393 """
@@ -527,7 +549,7 @@ def get_node_labels(self, graph):
527549 labels : Set[int]
528550 Labels corresponding to nodes in the graph identified by "key".
529551 """
530- labels = graph .get_labels ()
552+ labels = graph .get_node_labels ()
531553 if self .use_mapping ():
532554 labels = set ().union (* (self .inverse_mapping [l ] for l in labels ))
533555 return labels
0 commit comments