Skip to content

Commit cddf6e4

Browse files
anna-grimanna-grim
andauthored
Feat localize merge (#101)
* feat: localize merges * feat: save projections * bug: save fragments --------- Co-authored-by: anna-grim <[email protected]>
1 parent cf3285b commit cddf6e4

File tree

1 file changed

+38
-25
lines changed

1 file changed

+38
-25
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
localize_merge=False,
6767
preexisting_merges=None,
6868
save_merges=False,
69-
save_projections=False,
69+
save_fragments=False,
7070
valid_labels=None,
7171
):
7272
"""
@@ -103,7 +103,7 @@ def __init__(
103103
save_merges: bool, optional
104104
Indication of whether to save fragments with a merge mistake. The
105105
default is None.
106-
save_projections : bool, optional
106+
save_fragments : bool, optional
107107
Indication of whether to save fragments that project onto each
108108
ground truth skeleton. The default is False.
109109
valid_labels : set[int], optional
@@ -123,7 +123,7 @@ def __init__(
123123
self.output_dir = output_dir
124124
self.preexisting_merges = preexisting_merges
125125
self.save_merges = save_merges
126-
self.save_projections = save_projections
126+
self.save_fragments = save_fragments
127127

128128
# Label handler
129129
self.label_handler = gutil.LabelHandler(
@@ -135,9 +135,8 @@ def __init__(
135135
self.load_groundtruth(gt_pointer)
136136
self.load_fragments(fragments_pointer)
137137

138-
# Initialize writer
139-
if self.save_merges:
140-
self.init_zip_writer()
138+
# Initialize writers
139+
self.init_zip_writers()
141140

142141
# Initialize fragment projections directory
143142
if self.save_projections:
@@ -346,9 +345,9 @@ def get_node_labels(self, key, inverse_bool=False):
346345
else:
347346
return self.graphs[key].get_labels()
348347

349-
def init_zip_writer(self):
348+
def init_zip_writers(self):
350349
"""
351-
Initializes "self.zip_writer" attribute by setting up a directory for
350+
Initializes "self.merge_writer" attribute by setting up a directory for
352351
output files and creating ZIP files for each graph in "self.graphs".
353352
354353
Parameters
@@ -360,16 +359,31 @@ def init_zip_writer(self):
360359
None
361360
362361
"""
363-
# Initialize output directory
364-
merged_fragments_dir = os.path.join(self.output_dir, "merged_fragments")
365-
util.mkdir(merged_fragments_dir)
366-
367-
# Save intial graphs
368-
self.zip_writer = dict()
369-
for key in self.graphs.keys():
370-
zip_path = f"{merged_fragments_dir}/{key}.zip"
371-
self.zip_writer[key] = ZipFile(zip_path, "w")
372-
self.graphs[key].to_zipped_swc(self.zip_writer[key])
362+
# Merged fragments zip writer
363+
if self.save_merges:
364+
# Initialize directory
365+
merges_dir = os.path.join(self.output_dir, "merged_fragments")
366+
util.mkdir(merged_fragments_dir)
367+
368+
# Initialize zip writer
369+
self.merge_writer = dict()
370+
for key in self.graphs.keys():
371+
zip_path = f"{merged_fragments_dir}/{key}.zip"
372+
self.merge_writer[key] = ZipFile(zip_path, "w")
373+
self.graphs[key].to_zipped_swc(self.merge_writer[key])
374+
375+
# Fragments zip writer
376+
if self.save_fragments:
377+
# Initialize direction
378+
fragments_dir = os.path.join(self.output_dir, "fragments")
379+
util.mkdir(fragments_dir)
380+
381+
# Initialize zip writer
382+
self.fragment_writer = dict()
383+
for key in self.graphs.keys():
384+
zip_path = f"{fragments_dir}/{key}.zip"
385+
self.fragment_writer[key] = ZipFile(zip_path, "w")
386+
self.graphs[key].to_zipped_swc(self.fragment_writer[key])
373387

374388
# -- Main Routine --
375389
def run(self):
@@ -539,12 +553,6 @@ def count_merges(self, key, kdtree):
539553
None
540554
541555
"""
542-
# Initialize zip writer
543-
if self.save_projections:
544-
zip_path = os.path.join(self.projections_dir, key + ".zip")
545-
zip_writer = ZipFile(zip_path, "w")
546-
#self.graphs[key].to_zipped_swc(zip_writer)
547-
548556
# Iterate over fragments that intersect with GT skeleton
549557
for label in self.get_node_labels(key):
550558
nodes = self.graphs[key].nodes_with_label(label)
@@ -580,6 +588,7 @@ def is_fragment_merge(self, key, label, kdtree):
580588
"""
581589
# Search graphs
582590
for fragment_graph in self.find_graph_from_label(label):
591+
# Search for merge
583592
max_dist = 0
584593
min_dist = np.inf
585594
for voxel in fragment_graph.voxels:
@@ -601,11 +610,15 @@ def is_fragment_merge(self, key, label, kdtree):
601610

602611
# Save merged fragment (if applicable)
603612
if self.save_merges:
604-
fragment_graph.to_zipped_swc(self.zip_writer[key])
613+
fragment_graph.to_zipped_swc(self.merge_writer[key])
605614
if self.localize_merge:
606615
self.find_merge_site(key, fragment_graph, kdtree)
607616
break
608617

618+
# Save fragment (if applicable)
619+
if self.save_fragments and min_dist < 3:
620+
fragment_graph.to_zipped_swc(self.fragment_writer[key])
621+
609622
def adjust_metrics(self, key):
610623
"""
611624
Adjusts the metrics of the graph associated with the given key by

0 commit comments

Comments
 (0)