77
88"""
99
10- import os
11- from collections import defaultdict
10+
1211from concurrent .futures import ThreadPoolExecutor , as_completed
1312from copy import deepcopy
13+ from scipy .spatial import KDTree
1414from time import time
15+ from tqdm import tqdm
1516from zipfile import ZipFile
1617
1718import networkx as nx
1819import numpy as np
20+ import os
1921import tensorstore as ts
20- from scipy .spatial import KDTree
21- from tqdm import tqdm
2222
23- from segmentation_skeleton_metrics import graph_utils as gutils
24- from segmentation_skeleton_metrics import split_detection , swc_utils , utils
25- from segmentation_skeleton_metrics .graph_utils import to_xyz_array
23+ from segmentation_skeleton_metrics import split_detection
24+ from segmentation_skeleton_metrics .utils import (
25+ graph_util as gutil ,
26+ swc_util ,
27+ util
28+ )
29+ from segmentation_skeleton_metrics .utils .graph_util import to_array
2630
2731MERGE_DIST_THRESHOLD = 100
2832MIN_CNT = 40
@@ -62,7 +66,7 @@ def __init__(
6266 Parameters
6367 ----------
6468 gt_pointer : dict/str/list[str]
65- Pointer to ground truth swcs, see "swc_utils .Reader" for further
69+ Pointer to ground truth swcs, see "swc_util .Reader" for further
6670 documentation. Note these swc files are assumed to be stored in
6771 image coordinates.
6872 pred_labels : numpy.ndarray or tensorstore.TensorStore
@@ -75,12 +79,12 @@ def __init__(
7579 that were merged into a single segment. The default is None.
7680 fragments_pointer : dict/str/list[str], optional
7781 Pointer to fragments (i.e. swcs) corresponding to "pred_labels",
78- see "swc_utils .Reader" for further documentation. Note these swc
82+ see "swc_util .Reader" for further documentation. Note these swc
7983 files may be stored in either world or image coordinates. If the
8084 swcs are stored in world coordinates, then provide the world to
8185 image coordinates anisotropy factor. Note the filename of each swc
8286 is assumed to "segment_id.swc" where segment_id cooresponds to the
83- segment id from "pred_labels". The default is None.
87+ segment id from "pred_labels". The default is None.
8488 output_dir : str, optional
8589 Path to directory that mistake sites are written to. The default
8690 is None.
@@ -103,7 +107,7 @@ def __init__(
103107
104108 """
105109 # Instance attributes
106- self .anisotropy = [ 1.0 / a for a in anisotropy ]
110+ self .anisotropy = anisotropy
107111 self .connections_path = connections_path
108112 self .output_dir = output_dir
109113 self .preexisting_merges = preexisting_merges
@@ -142,7 +146,7 @@ def init_label_map(self, path):
142146 """
143147 if path :
144148 assert self .valid_labels is not None , "Must provide valid labels!"
145- self .label_map , self .inverse_label_map = utils .init_label_map (
149+ self .label_map , self .inverse_label_map = util .init_label_map (
146150 path , self .valid_labels
147151 )
148152 else :
@@ -166,14 +170,14 @@ def init_graphs(self, paths):
166170
167171 """
168172 # Read graphs
169- self .graphs = swc_utils .Reader ().load (paths )
173+ self .graphs = swc_util .Reader ().load (paths )
170174 self .fragment_graphs = None
171175
172176 # Label nodes
173177 self .key_to_label_to_nodes = dict () # {id: {label: nodes}}
174178 for key in tqdm (self .graphs , desc = "Labeling Graphs" ):
175179 self .set_node_labels (key )
176- self .key_to_label_to_nodes [key ] = gutils .init_label_to_nodes (
180+ self .key_to_label_to_nodes [key ] = gutil .init_label_to_nodes (
177181 self .graphs [key ]
178182 )
179183
@@ -318,7 +322,7 @@ def load_fragments(self, fragments_pointer):
318322
319323 """
320324 # Read fragments
321- reader = swc_utils .Reader (anisotropy = self .anisotropy , min_size = 40 )
325+ reader = swc_util .Reader (anisotropy = self .anisotropy , min_size = 40 )
322326 fragment_graphs = reader .load (fragments_pointer )
323327 self .fragment_ids = set (fragment_graphs .keys ())
324328
@@ -349,13 +353,13 @@ def init_zip_writer(self):
349353 """
350354 # Initialize output directory
351355 output_dir = os .path .join (self .output_dir , "projections" )
352- utils .mkdir (output_dir )
356+ util .mkdir (output_dir )
353357
354358 # Save intial graphs
355359 self .zip_writer = dict ()
356360 for key in self .graphs .keys ():
357361 self .zip_writer [key ] = ZipFile (f"{ output_dir } /{ key } .zip" , "w" )
358- swc_utils .to_zipped_swc (
362+ swc_util .to_zipped_swc (
359363 self .zip_writer [key ], self .graphs [key ],
360364 )
361365
@@ -419,7 +423,7 @@ def adjust_metrics(self, key):
419423
420424 # Adjust metrics
421425 n_edges = subgraph .number_of_edges ()
422- rls = gutils .compute_run_lengths (subgraph )
426+ rls = gutil .compute_run_lengths (subgraph )
423427 self .graphs [key ].graph ["run_length" ] -= np .sum (rls )
424428 self .graphs [key ].graph ["n_edges" ] -= n_edges
425429
@@ -448,13 +452,13 @@ def detect_splits(self):
448452 graph = split_detection .run (graph , self .graphs [key ])
449453
450454 # Update graph by removing omits (i.e. nodes labeled 0)
451- self .graphs [key ] = gutils .delete_nodes (graph , 0 )
452- self .key_to_label_to_nodes [key ] = gutils .init_label_to_nodes (
455+ self .graphs [key ] = gutil .delete_nodes (graph , 0 )
456+ self .key_to_label_to_nodes [key ] = gutil .init_label_to_nodes (
453457 self .graphs [key ]
454458 )
455459
456460 # Report runtime
457- t , unit = utils .time_writer (time () - t0 )
461+ t , unit = util .time_writer (time () - t0 )
458462 print (f"Runtime: { round (t , 2 )} { unit } \n " )
459463
460464 def quantify_splits (self ):
@@ -478,7 +482,7 @@ def quantify_splits(self):
478482 n_pred_edges = self .graphs [key ].number_of_edges ()
479483 n_target_edges = self .graphs [key ].graph ["n_edges" ]
480484
481- self .split_cnt [key ] = gutils .count_splits (self .graphs [key ])
485+ self .split_cnt [key ] = gutil .count_splits (self .graphs [key ])
482486 self .omit_cnts [key ] = n_target_edges - n_pred_edges
483487 self .omit_percent [key ] = 1 - n_pred_edges / n_target_edges
484488
@@ -507,7 +511,7 @@ def detect_merges(self):
507511 if self .fragment_graphs :
508512 for key , graph in self .graphs .items ():
509513 if graph .number_of_nodes () > 0 :
510- kdtree = KDTree (gutils . to_xyz_array (graph ))
514+ kdtree = KDTree (gutil . to_array (graph ))
511515 self .count_merges (key , kdtree )
512516
513517 # Process merges
@@ -574,20 +578,20 @@ def is_fragment_merge(self, key, label, kdtree):
574578 None
575579
576580 """
577- for xyz in to_xyz_array (self .fragment_graphs [label ])[::5 ]:
578- if kdtree .query (xyz )[0 ] > MERGE_DIST_THRESHOLD :
581+ for voxel in to_array (self .fragment_graphs [label ])[::2 ]:
582+ if kdtree .query (voxel )[0 ] > MERGE_DIST_THRESHOLD :
579583 # Check whether to take inverse of label
580584 if self .inverse_label_map :
581585 equivalent_label = self .label_map [label ]
582586 else :
583587 equivalent_label = label
584588
585589 # Record merge mistake
586- xyz = utils .to_world (xyz )
590+ xyz = util .to_world (voxel )
587591 self .merge_cnt [key ] += 1
588592 self .merged_labels .add ((key , equivalent_label , tuple (xyz )))
589593 if self .save_projections :
590- swc_utils .to_zipped_swc (
594+ swc_util .to_zipped_swc (
591595 self .zip_writer [key ], self .fragment_graphs [label ]
592596 )
593597 return
@@ -871,7 +875,7 @@ def compute_erl(self):
871875 total_run_length = 0
872876 for key in self .graphs :
873877 run_length = self .get_run_length (key )
874- run_lengths = gutils .compute_run_lengths (self .graphs [key ])
878+ run_lengths = gutil .compute_run_lengths (self .graphs [key ])
875879 total_run_length += run_length
876880 wgt = run_lengths / max (np .sum (run_lengths ), 1 )
877881
@@ -924,7 +928,7 @@ def list_metrics(self):
924928 ]
925929 return metrics
926930
927- # -- Utils --
931+ # -- util --
928932 def init_counter (self ):
929933 """
930934 Initializes a dictionary that is used to count some type of mistake
@@ -943,7 +947,7 @@ def init_counter(self):
943947 return {key : 0 for key in self .graphs }
944948
945949
946- # -- utils --
950+ # -- util --
947951def find_sites (graphs , get_labels ):
948952 """
949953 Detects merges between ground truth graphs which are considered to be
0 commit comments