Skip to content

Commit 4e3763e

Browse files
author
anna-grim
committed
bug: save fragments
1 parent fb61d52 commit 4e3763e

File tree

2 files changed

+40
-35
lines changed

2 files changed

+40
-35
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(
6666
localize_merge=False,
6767
preexisting_merges=None,
6868
save_merges=False,
69-
save_projections=False,
69+
save_fragments=False,
7070
valid_labels=None,
7171
):
7272
"""
@@ -103,7 +103,7 @@ def __init__(
103103
save_merges: bool, optional
104104
Indication of whether to save fragments with a merge mistake. The
105105
default is None.
106-
save_projections : bool, optional
106+
save_fragments : bool, optional
107107
Indication of whether to save fragments that project onto each
108108
ground truth skeleton. The default is False.
109109
valid_labels : set[int], optional
@@ -123,7 +123,7 @@ def __init__(
123123
self.output_dir = output_dir
124124
self.preexisting_merges = preexisting_merges
125125
self.save_merges = save_merges
126-
self.save_projections = save_projections
126+
self.save_fragments = save_fragments
127127

128128
# Label handler
129129
self.label_handler = gutil.LabelHandler(
@@ -135,14 +135,8 @@ def __init__(
135135
self.load_groundtruth(gt_pointer)
136136
self.load_fragments(fragments_pointer)
137137

138-
# Initialize writer
139-
if self.save_merges:
140-
self.init_zip_writer()
141-
142-
# Initialize fragment projections directory
143-
if self.save_projections:
144-
self.projections_dir = os.path.join(output_dir, "projections")
145-
util.mkdir(self.projections_dir)
138+
# Initialize writers
139+
self.init_zip_writers()
146140

147141
# --- Load Data ---
148142
def load_groundtruth(self, swc_pointer):
@@ -346,9 +340,9 @@ def get_node_labels(self, key, inverse_bool=False):
346340
else:
347341
return self.graphs[key].get_labels()
348342

349-
def init_zip_writer(self):
343+
def init_zip_writers(self):
350344
"""
351-
Initializes "self.zip_writer" attribute by setting up a directory for
345+
Initializes "self.merge_writer" attribute by setting up a directory for
352346
output files and creating ZIP files for each graph in "self.graphs".
353347
354348
Parameters
@@ -360,16 +354,31 @@ def init_zip_writer(self):
360354
None
361355
362356
"""
363-
# Initialize output directory
364-
merged_fragments_dir = os.path.join(self.output_dir, "merged_fragments")
365-
util.mkdir(merged_fragments_dir)
366-
367-
# Save intial graphs
368-
self.zip_writer = dict()
369-
for key in self.graphs.keys():
370-
zip_path = f"{merged_fragments_dir}/{key}.zip"
371-
self.zip_writer[key] = ZipFile(zip_path, "w")
372-
self.graphs[key].to_zipped_swc(self.zip_writer[key])
357+
# Merged fragments zip writer
358+
if self.save_merges:
359+
# Initialize directory
360+
merges_dir = os.path.join(self.output_dir, "merged_fragments")
361+
util.mkdir(merged_fragments_dir)
362+
363+
# Initialize zip writer
364+
self.merge_writer = dict()
365+
for key in self.graphs.keys():
366+
zip_path = f"{merged_fragments_dir}/{key}.zip"
367+
self.merge_writer[key] = ZipFile(zip_path, "w")
368+
self.graphs[key].to_zipped_swc(self.merge_writer[key])
369+
370+
# Fragments zip writer
371+
if self.save_fragments:
372+
# Initialize direction
373+
fragments_dir = os.path.join(self.output_dir, "fragments")
374+
util.mkdir(fragments_dir)
375+
376+
# Initialize zip writer
377+
self.fragment_writer = dict()
378+
for key in self.graphs.keys():
379+
zip_path = f"{fragments_dir}/{key}.zip"
380+
self.fragment_writer[key] = ZipFile(zip_path, "w")
381+
self.graphs[key].to_zipped_swc(self.fragment_writer[key])
373382

374383
# -- Main Routine --
375384
def run(self):
@@ -539,22 +548,13 @@ def count_merges(self, key, kdtree):
539548
None
540549
541550
"""
542-
# Initialize zip writer
543-
if self.save_projections:
544-
zip_path = os.path.join(self.projections_dir, key + ".zip")
545-
zip_writer = ZipFile(zip_path, "w")
546-
#self.graphs[key].to_zipped_swc(zip_writer)
547-
548551
# Iterate over fragments that intersect with GT skeleton
549552
for label in self.get_node_labels(key):
550553
nodes = self.graphs[key].nodes_with_label(label)
551554
if len(nodes) > 40:
552555
for label in self.label_handler.get_class(label):
553556
if label in self.fragment_ids:
554557
self.is_fragment_merge(key, label, kdtree)
555-
if self.save_projections:
556-
fragment_graph = self.find_graph_from_label(label)[0]
557-
fragment_graph.to_zipped_swc(zip_writer)
558558

559559
def is_fragment_merge(self, key, label, kdtree):
560560
"""
@@ -580,6 +580,7 @@ def is_fragment_merge(self, key, label, kdtree):
580580
"""
581581
# Search graphs
582582
for fragment_graph in self.find_graph_from_label(label):
583+
# Search for merge
583584
max_dist = 0
584585
min_dist = np.inf
585586
for voxel in fragment_graph.voxels:
@@ -601,11 +602,15 @@ def is_fragment_merge(self, key, label, kdtree):
601602

602603
# Save merged fragment (if applicable)
603604
if self.save_merges:
604-
fragment_graph.to_zipped_swc(self.zip_writer[key])
605+
fragment_graph.to_zipped_swc(self.merge_writer[key])
605606
if self.localize_merge:
606607
self.find_merge_site(key, fragment_graph, kdtree)
607608
break
608609

610+
# Save fragment (if applicable)
611+
if self.save_fragments and min_dist < 3:
612+
fragment_graph.to_zipped_swc(self.fragment_writer[key])
613+
609614
def adjust_metrics(self, key):
610615
"""
611616
Adjusts the metrics of the graph associated with the given key by

src/segmentation_skeleton_metrics/split_detection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,15 +46,15 @@ def run(process_id, graph):
4646
label_i = int(graph.labels[i])
4747
label_j = int(graph.labels[j])
4848
if is_split(label_i, label_j):
49-
graph.remove_edge(i, j) temp
49+
graph.remove_edge(i, j)
5050
split_cnt += 1
5151
elif label_j == 0:
5252
check_misalignment(graph, visited_edges, i, j)
5353
visited_edges.add(frozenset({i, j}))
5454

5555
# Finish
5656
split_percent = split_cnt / graph.graph["n_edges"]
57-
graph.remove_nodes_with_label(0) temp
57+
graph.remove_nodes_with_label(0)
5858
return process_id, graph, split_percent
5959

6060

0 commit comments

Comments
 (0)