2323import numpy as np
2424import os
2525
26- from segmentation_skeleton_metrics import graph_segmentation_alignment as gsa
26+ from segmentation_skeleton_metrics import split_detection
2727from segmentation_skeleton_metrics .utils import (
2828 graph_util as gutil ,
2929 img_util ,
@@ -112,26 +112,20 @@ def __init__(
112112 self .output_dir = output_dir
113113 self .preexisting_merges = preexisting_merges
114114
115- # Load ground truth
116- print ("\n (1) Load Ground Truth" )
117- assert type (valid_labels ) is set if valid_labels else True
115+ # Load data
116+ assert isinstance (valid_labels , set ) if valid_labels else True
118117 self .label_mask = pred_labels
119118 self .valid_labels = valid_labels
120119 self .init_label_map (connections_path )
121- self .init_graphs (gt_pointer )
122-
123- print ("\n (2) Load Prediction" )
124- if fragments_pointer :
125- self .load_fragments (fragments_pointer )
126- else :
127- self .fragment_graphs = None
120+ self .load_groundtruth (gt_pointer )
121+ self .load_fragments (fragments_pointer )
128122
129123 # Initialize writer
130124 self .save_projections = save_projections
131125 if self .save_projections :
132126 self .init_zip_writer ()
133127
134- # -- Initialize and Label Graphs --
128+ # --- Load Data - --
135129 def init_label_map (self , path ):
136130 """
137131 Initializes a dictionary that maps a label to its equivalent label in
@@ -157,7 +151,7 @@ def init_label_map(self, path):
157151 self .label_map = None
158152 self .inverse_label_map = None
159153
160- def init_graphs (self , paths ):
154+ def load_groundtruth (self , swc_pointer ):
161155 """
162156 Initializes "self.graphs" by iterating over "paths" which corresponds
163157 to neurons in the ground truth.
@@ -174,8 +168,13 @@ def init_graphs(self, paths):
174168
175169 """
176170 # Build graphs
177- swc_dicts = swc_util .Reader ().load (paths )
178- self .graphs = self .build_graphs (swc_dicts )
171+ print ("\n (1) Load Ground Truth" )
172+ graph_builder = gutil .GraphBuilder (
173+ anisotropy = self .anisotropy ,
174+ label_mask = self .label_mask ,
175+ use_anisotropy = False ,
176+ )
177+ self .graphs = graph_builder .run (swc_pointer )
179178
180179 # Label nodes
181180 self .key_to_label_to_nodes = dict () # {id: {label: nodes}}
@@ -185,23 +184,18 @@ def init_graphs(self, paths):
185184 self .graphs [key ]
186185 )
187186
188- def build_graphs (self , swc_dicts ):
189- graphs = dict ()
190- with ProcessPoolExecutor () as executor :
191- # Assign processes
192- processes = list ()
193- for swc_dict in swc_dicts :
194- processes .append (
195- executor .submit (gutil .to_graph , swc_dict )
196- )
187+ def load_fragments (self , swc_pointer ):
188+ print ("\n (2) Load Fragments" )
189+ if swc_pointer :
190+ graph_builder = gutil .GraphBuilder (
191+ anisotropy = self .anisotropy ,
192+ selected_ids = self .get_all_node_labels (),
193+ use_anisotropy = True ,
194+ )
195+ self .fragment_graphs = graph_builder .run (swc_pointer )
196+ else :
197+ self .fragment_graphs = None
197198
198- # Store results
199- pbar = tqdm (total = len (processes ), desc = "Build Graphs" )
200- for process in as_completed (processes ):
201- graphs .update (process .result ())
202- pbar .update (1 )
203- return graphs
204-
205199 def label_graphs (self , key , batch_size = 128 ):
206200 """
207201 Iterates over nodes in "graph" and stores the corresponding label from
@@ -259,7 +253,7 @@ def get_patch_labels(self, key, nodes):
259253 # Get bounding box
260254 bbox = {"min" : [np .inf , np .inf , np .inf ], "max" : [0 , 0 , 0 ]}
261255 for i in nodes :
262- voxel = deepcopy ( self .graphs [key ].graph ["voxel" ][i ])
256+ voxel = self .graphs [key ].graph ["voxel" ][i ]
263257 for idx in range (3 ):
264258 if voxel [idx ] < bbox ["min" ][idx ]:
265259 bbox ["min" ][idx ] = voxel [idx ]
@@ -363,39 +357,6 @@ def get_node_labels(self, key, inverse_bool=False):
363357 else :
364358 return set (self .key_to_label_to_nodes [key ].keys ())
365359
366- # -- Load Fragments --
367- def load_fragments (self , fragments_pointer ):
368- """
369- Loads and filters swc files from a local zip. These swc files are
370- assumed to be fragments from a predicted segmentation.
371-
372- Parameters
373- ----------
374- zip_path : str
375- Path to the local zip file containing the fragments
376-
377- Returns
378- -------
379- dict
380- Dictionary that maps an swc id to the fragment graph.
381-
382- """
383- # Read SWC files
384- reader = swc_util .Reader (anisotropy = self .anisotropy )
385- swc_dicts = deque (reader .load (fragments_pointer ))
386-
387- # Filter SWC files
388- filtered_swc_dicts = list ()
389- labels = self .get_all_node_labels ()
390- while len (swc_dicts ) > 0 :
391- swc_dict = swc_dicts .popleft ()
392- swc_id = int (swc_dict ["swc_id" ])
393- if swc_id in labels :
394- swc_dict ["swc_id" ] = swc_id
395- filtered_swc_dicts .append (swc_dict )
396- self .fragment_graphs = self .build_graphs (filtered_swc_dicts )
397- print ("# Fragments:" , len (self .fragment_graphs ))
398-
399360 def init_zip_writer (self ):
400361 """
401362 Initializes "self.zip_writer" attribute by setting up a directory for
@@ -513,7 +474,7 @@ def detect_splits(self):
513474 for key , graph in self .graphs .items ():
514475 processes .append (
515476 executor .submit (
516- gsa . correct_graph_misalignments ,
477+ split_detection . run ,
517478 key ,
518479 graph ,
519480 )
0 commit comments