@@ -24,7 +24,7 @@
 from segmentation_skeleton_metrics import split_detection, swc_utils, utils
 from segmentation_skeleton_metrics.graph_utils import to_xyz_array
 
-MERGE_DIST_THRESHOLD = 200
+MERGE_DIST_THRESHOLD = 100
 MIN_CNT = 40
 
 
@@ -47,7 +47,7 @@ def __init__(
         self,
         gt_pointer,
         pred_labels,
-        anisotropy=[1.0, 1.0, 1.0],
+        anisotropy=(1.0, 1.0, 1.0),
         connections_path=None,
         fragments_pointer=None,
         output_dir=None,
@@ -102,11 +102,10 @@ def __init__(
         None.
 
         """
-        # Options
-        self.anisotropy = [1.0 / a_i for a_i in anisotropy]
+        # Instance attributes
+        self.anisotropy = [1.0 / a for a in anisotropy]
         self.connections_path = connections_path
         self.output_dir = output_dir
-        self.fragments_pointer = fragments_pointer
         self.preexisting_merges = preexisting_merges
 
         # Load Labels, Graphs, Fragments
@@ -116,8 +115,8 @@ def __init__(
         self.valid_labels = valid_labels
         self.init_label_map(connections_path)
         self.init_graphs(gt_pointer)
-        if self.fragments_pointer:
-            self.load_fragments()
+        if fragments_pointer:
+            self.load_fragments(fragments_pointer)
 
         # Initialize writer
         self.save_projections = save_projections
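Note on the constructor hunks above: the default anisotropy becomes an immutable tuple (avoiding Python's mutable-default-argument pitfall), the stored value is the reciprocal of each component, and fragments_pointer is now passed through instead of being kept as an attribute. A minimal sketch of the stored-reciprocal convention, presumably used to scale physical swc coordinates into the voxel grid of pred_labels; the helper name and numeric values are illustrative assumptions, not code from this diff:

    # Hedged sketch of the reciprocal-anisotropy convention used in __init__.
    anisotropy = (0.748, 0.748, 1.0)        # assumed physical voxel size per axis
    inv = [1.0 / a for a in anisotropy]     # what __init__ now stores as self.anisotropy

    def to_voxels(xyz, inv_anisotropy):
        # Scale a physical xyz coordinate into integer voxel indices.
        return tuple(int(round(x * s)) for x, s in zip(xyz, inv_anisotropy))

    print(to_voxels((374.0, 74.8, 10.0), inv))  # -> (500, 100, 10)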
@@ -167,8 +166,7 @@ def init_graphs(self, paths):
 
         """
         # Read graphs
-        reader = swc_utils.Reader(return_graphs=True)
-        self.graphs = reader.load(paths)
+        self.graphs = swc_utils.Reader().load(paths)
         self.fragment_graphs = None
 
         # Label nodes
@@ -303,7 +301,7 @@ def get_node_labels(self, key, inverse_bool=False):
         return set(self.key_to_label_to_nodes[key].keys())
 
     # -- Load Fragments --
-    def load_fragments(self):
+    def load_fragments(self, fragments_pointer):
         """
         Loads and filters swc files from a local zip. These swc files are
         assumed to be fragments from a predicted segmentation.
@@ -320,10 +318,8 @@ def load_fragments(self):
 
         """
         # Read fragments
-        reader = swc_utils.Reader(
-            anisotropy=self.anisotropy, return_graphs=True
-        )
-        fragment_graphs = reader.load(self.fragments_pointer)
+        reader = swc_utils.Reader(anisotropy=self.anisotropy, min_size=40)
+        fragment_graphs = reader.load(fragments_pointer)
         self.fragment_ids = set(fragment_graphs.keys())
 
         # Filter fragments
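For context, the reader above now filters small fragments at load time and receives the pointer as an argument. A minimal standalone sketch of the same loading step; the zip path, the anisotropy values, and the exact meaning of min_size are assumptions for illustration:

    # Hedged sketch: load fragment graphs the way load_fragments now does.
    from segmentation_skeleton_metrics import swc_utils

    inv_anisotropy = [1.0 / a for a in (0.748, 0.748, 1.0)]  # assumed voxel size
    reader = swc_utils.Reader(anisotropy=inv_anisotropy, min_size=40)
    fragment_graphs = reader.load("pred_fragments.zip")  # assumed path; returns id -> graph
    print(f"kept {len(fragment_graphs)} fragments after the min_size filter")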
@@ -360,7 +356,7 @@ def init_zip_writer(self):
         for key in self.graphs.keys():
             self.zip_writer[key] = ZipFile(f"{output_dir}/{key}.zip", "w")
             swc_utils.to_zipped_swc(
-                self.zip_writer[key], self.graphs[key], color="1.0 0.0 0.0"
+                self.zip_writer[key], self.graphs[key],
             )
 
     # -- Main Routine --
@@ -391,7 +387,6 @@ def run(self):
 
         # Merge evaluation
         self.detect_merges()
-        self.compute_projected_run_lengths()
         self.quantify_merges()
 
         # Compute metrics
@@ -507,7 +502,6 @@ def detect_merges(self):
         self.merged_edges_cnt = self.init_counter()
         self.merged_percent = self.init_counter()
         self.merged_labels = set()
-        self.projected_run_length = defaultdict(int)
 
         # Count total merges
         if self.fragment_graphs:
@@ -557,7 +551,6 @@ def count_merges(self, key, kdtree):
         # Check if fragment is a merge mistake
         for label in labels:
             rl = self.fragment_graphs[label].graph["run_length"]
-            self.projected_run_length[key] += rl
             self.is_fragment_merge(key, label, kdtree)
 
     def is_fragment_merge(self, key, label, kdtree):
@@ -725,37 +718,6 @@ def get_merged_label(self, label):
                 return l
         return self.inverse_label_map[label]
 
-    # -- Projected Run Lengths --
-    def compute_projected_run_lengths(self):
-        """
-        Computes the projected run length for each graph in "self.graphs".
-        First, we detect fragments from "self.fragments_pointer" that are
-        sufficiently close (as determined by projection distances) to the
-        given graph. The projected run length is the sum of the path lengths
-        of fragments that were detected.
-
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        None
-
-        """
-        # Initializations
-        self.run_length_ratio = dict()
-        self.target_run_length = dict()
-
-        # Compute run lengths
-        for key in self.graphs:
-            target_rl = self.get_run_length(key)
-            projected_rl = self.projected_run_length[key]
-
-            self.projected_run_length[key] = projected_rl
-            self.target_run_length[key] = target_rl
-            self.run_length_ratio[key] = projected_rl / target_rl
-
     # -- Compute Metrics --
     def compile_results(self):
         """
@@ -816,9 +778,6 @@ def generate_full_results(self):
             "% omit": generate_result(keys, self.omit_percent),
             "% merged": generate_result(keys, self.merged_percent),
             "edge accuracy": generate_result(keys, self.edge_accuracy),
-            "projected_rl": generate_result(keys, self.projected_run_length),
-            "target_rl": generate_result(keys, self.target_run_length),
-            "rl_ratio": generate_result(keys, self.run_length_ratio),
             "erl": generate_result(keys, self.erl),
             "normalized erl": generate_result(keys, self.normalized_erl),
         }
@@ -844,9 +803,6 @@ def generate_avg_results(self):
             "% omit": self.avg_result(self.omit_percent),
             "% merged": self.avg_result(self.merged_percent),
             "edge accuracy": self.avg_result(self.edge_accuracy),
-            "projected_rl": self.avg_result(self.projected_run_length),
-            "target_rl": self.avg_result(self.target_run_length),
-            "rl_ratio": self.avg_result(self.run_length_ratio),
             "erl": self.avg_result(self.erl),
             "normalized erl": self.avg_result(self.normalized_erl),
         }
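With the projected/target run-length entries dropped, the reported metrics reduce to the split/omit/merge percentages, edge accuracy, and ERL. A hedged end-to-end sketch of driving the updated constructor and pipeline; the class name SkeletonMetric, the import path, the file paths, and the stand-in array are assumptions for illustration, only the keyword argument names come from this diff:

    # Hedged usage sketch; identifiers marked below are assumptions.
    import numpy as np
    from segmentation_skeleton_metrics.skeleton_metric import SkeletonMetric  # assumed import path

    pred_labels = np.zeros((64, 64, 64), dtype=np.uint32)  # stand-in predicted segmentation
    metric = SkeletonMetric(
        "gt_swcs.zip",                           # gt_pointer (assumed path)
        pred_labels,
        anisotropy=(0.748, 0.748, 1.0),          # physical voxel size; inverted internally
        fragments_pointer="pred_fragments.zip",  # triggers load_fragments(...)
        output_dir="results",
    )
    metric.run()  # merge evaluation followed by metric compilation, per the run() hunk above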