Skip to content

Commit 2fdb732

Browse files
anna-grimanna-grim
andauthored
refactor: optimized swc reader (#94)
Co-authored-by: anna-grim <[email protected]>
1 parent 432a9cf commit 2fdb732

File tree

2 files changed

+183
-209
lines changed

2 files changed

+183
-209
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 11 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from segmentation_skeleton_metrics import split_detection, swc_utils, utils
2525
from segmentation_skeleton_metrics.graph_utils import to_xyz_array
2626

27-
MERGE_DIST_THRESHOLD = 200
27+
MERGE_DIST_THRESHOLD = 100
2828
MIN_CNT = 40
2929

3030

@@ -47,7 +47,7 @@ def __init__(
4747
self,
4848
gt_pointer,
4949
pred_labels,
50-
anisotropy=[1.0, 1.0, 1.0],
50+
anisotropy=(1.0, 1.0, 1.0),
5151
connections_path=None,
5252
fragments_pointer=None,
5353
output_dir=None,
@@ -102,11 +102,10 @@ def __init__(
102102
None.
103103
104104
"""
105-
# Options
106-
self.anisotropy = [1.0 / a_i for a_i in anisotropy]
105+
# Instance attributes
106+
self.anisotropy = [1.0 / a for a in anisotropy]
107107
self.connections_path = connections_path
108108
self.output_dir = output_dir
109-
self.fragments_pointer = fragments_pointer
110109
self.preexisting_merges = preexisting_merges
111110

112111
# Load Labels, Graphs, Fragments
@@ -116,8 +115,8 @@ def __init__(
116115
self.valid_labels = valid_labels
117116
self.init_label_map(connections_path)
118117
self.init_graphs(gt_pointer)
119-
if self.fragments_pointer:
120-
self.load_fragments()
118+
if fragments_pointer:
119+
self.load_fragments(fragments_pointer)
121120

122121
# Initialize writer
123122
self.save_projections = save_projections
@@ -167,8 +166,7 @@ def init_graphs(self, paths):
167166
168167
"""
169168
# Read graphs
170-
reader = swc_utils.Reader(return_graphs=True)
171-
self.graphs = reader.load(paths)
169+
self.graphs = swc_utils.Reader().load(paths)
172170
self.fragment_graphs = None
173171

174172
# Label nodes
@@ -303,7 +301,7 @@ def get_node_labels(self, key, inverse_bool=False):
303301
return set(self.key_to_label_to_nodes[key].keys())
304302

305303
# -- Load Fragments --
306-
def load_fragments(self):
304+
def load_fragments(self, fragments_pointer):
307305
"""
308306
Loads and filters swc files from a local zip. These swc files are
309307
assumed to be fragments from a predicted segmentation.
@@ -320,10 +318,8 @@ def load_fragments(self):
320318
321319
"""
322320
# Read fragments
323-
reader = swc_utils.Reader(
324-
anisotropy=self.anisotropy, return_graphs=True
325-
)
326-
fragment_graphs = reader.load(self.fragments_pointer)
321+
reader = swc_utils.Reader(anisotropy=self.anisotropy, min_size=40)
322+
fragment_graphs = reader.load(fragments_pointer)
327323
self.fragment_ids = set(fragment_graphs.keys())
328324

329325
# Filter fragments
@@ -360,7 +356,7 @@ def init_zip_writer(self):
360356
for key in self.graphs.keys():
361357
self.zip_writer[key] = ZipFile(f"{output_dir}/{key}.zip", "w")
362358
swc_utils.to_zipped_swc(
363-
self.zip_writer[key], self.graphs[key], color="1.0 0.0 0.0"
359+
self.zip_writer[key], self.graphs[key],
364360
)
365361

366362
# -- Main Routine --
@@ -391,7 +387,6 @@ def run(self):
391387

392388
# Merge evaluation
393389
self.detect_merges()
394-
self.compute_projected_run_lengths()
395390
self.quantify_merges()
396391

397392
# Compute metrics
@@ -507,7 +502,6 @@ def detect_merges(self):
507502
self.merged_edges_cnt = self.init_counter()
508503
self.merged_percent = self.init_counter()
509504
self.merged_labels = set()
510-
self.projected_run_length = defaultdict(int)
511505

512506
# Count total merges
513507
if self.fragment_graphs:
@@ -557,7 +551,6 @@ def count_merges(self, key, kdtree):
557551
# Check if fragment is a merge mistake
558552
for label in labels:
559553
rl = self.fragment_graphs[label].graph["run_length"]
560-
self.projected_run_length[key] += rl
561554
self.is_fragment_merge(key, label, kdtree)
562555

563556
def is_fragment_merge(self, key, label, kdtree):
@@ -725,37 +718,6 @@ def get_merged_label(self, label):
725718
return l
726719
return self.inverse_label_map[label]
727720

728-
# -- Projected Run Lengths --
729-
def compute_projected_run_lengths(self):
730-
"""
731-
Computes the projected run length for each graph in "self.graphs".
732-
First, we detect fragments from "self.fragments_pointer" that are
733-
sufficiently close (as determined by projection distances) to the
734-
given graph. The projected run length is the sum of the path lengths
735-
of fragments that were detected.
736-
737-
Parameters
738-
----------
739-
None
740-
741-
Returns
742-
-------
743-
None
744-
745-
"""
746-
# Initializations
747-
self.run_length_ratio = dict()
748-
self.target_run_length = dict()
749-
750-
# Compute run lengths
751-
for key in self.graphs:
752-
target_rl = self.get_run_length(key)
753-
projected_rl = self.projected_run_length[key]
754-
755-
self.projected_run_length[key] = projected_rl
756-
self.target_run_length[key] = target_rl
757-
self.run_length_ratio[key] = projected_rl / target_rl
758-
759721
# -- Compute Metrics --
760722
def compile_results(self):
761723
"""
@@ -816,9 +778,6 @@ def generate_full_results(self):
816778
"% omit": generate_result(keys, self.omit_percent),
817779
"% merged": generate_result(keys, self.merged_percent),
818780
"edge accuracy": generate_result(keys, self.edge_accuracy),
819-
"projected_rl": generate_result(keys, self.projected_run_length),
820-
"target_rl": generate_result(keys, self.target_run_length),
821-
"rl_ratio": generate_result(keys, self.run_length_ratio),
822781
"erl": generate_result(keys, self.erl),
823782
"normalized erl": generate_result(keys, self.normalized_erl),
824783
}
@@ -844,9 +803,6 @@ def generate_avg_results(self):
844803
"% omit": self.avg_result(self.omit_percent),
845804
"% merged": self.avg_result(self.merged_percent),
846805
"edge accuracy": self.avg_result(self.edge_accuracy),
847-
"projected_rl": self.avg_result(self.projected_run_length),
848-
"target_rl": self.avg_result(self.target_run_length),
849-
"rl_ratio": self.avg_result(self.run_length_ratio),
850806
"erl": self.avg_result(self.erl),
851807
"normalized erl": self.avg_result(self.normalized_erl),
852808
}

0 commit comments

Comments
 (0)