Skip to content

Commit 339dc7c

Browse files
author
anna-grim
committed
feat: localize merges
1 parent 969a13b commit 339dc7c

File tree

1 file changed

+80
-31
lines changed

1 file changed

+80
-31
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 80 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,10 @@ def __init__(
6363
anisotropy=(1.0, 1.0, 1.0),
6464
connections_path=None,
6565
fragments_pointer=None,
66+
localize_merge=False,
6667
preexisting_merges=None,
6768
save_merges=False,
69+
save_projections=False,
6870
valid_labels=None,
6971
):
7072
"""
@@ -92,12 +94,18 @@ def __init__(
9294
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
9395
applied to these SWC files and (2) these SWC files are required
9496
for counting merges. The default is None.
97+
localize_merge : bool, optional
98+
Indication of whether to search for the approximate location of a
99+
merge. The default is False.
95100
preexisting_merges : List[int], optional
96101
List of segment IDs that are known to contain a merge mistake. The
97102
default is None.
98103
save_merges: bool, optional
99104
Indication of whether to save fragments with a merge mistake. The
100105
default is None.
106+
save_projections : bool, optional
107+
Indication of whether to save fragments that project onto each
108+
ground truth skeleton. The default is False.
101109
valid_labels : set[int], optional
102110
Segment IDs that can be assigned to nodes. This argument accounts
103111
for segments that were been removed due to some type of filtering.
@@ -111,9 +119,11 @@ def __init__(
111119
# Instance attributes
112120
self.anisotropy = anisotropy
113121
self.connections_path = connections_path
122+
self.localize_merge = localize_merge
114123
self.output_dir = output_dir
115124
self.preexisting_merges = preexisting_merges
116125
self.save_merges = save_merges
126+
self.save_projections = save_projections
117127

118128
# Label handler
119129
self.label_handler = gutil.LabelHandler(
@@ -129,6 +139,11 @@ def __init__(
129139
if self.save_merges:
130140
self.init_zip_writer()
131141

142+
# Initialize fragment projections directory
143+
if self.save_projections:
144+
self.projections_dir = os.path.join(output_dir, "projections")
145+
util.mkdir(self.projections_dir)
146+
132147
# --- Load Data ---
133148
def load_groundtruth(self, swc_pointer):
134149
"""
@@ -346,13 +361,13 @@ def init_zip_writer(self):
346361
347362
"""
348363
# Initialize output directory
349-
projections_dir = os.path.join(self.output_dir, "projections")
350-
util.mkdir(projections_dir)
364+
merged_fragments_dir = os.path.join(self.output_dir, "merged_fragments")
365+
util.mkdir(merged_fragments_dir)
351366

352367
# Save intial graphs
353368
self.zip_writer = dict()
354369
for key in self.graphs.keys():
355-
zip_path = f"{projections_dir}/{key}.zip"
370+
zip_path = f"{merged_fragments_dir}/{key}.zip"
356371
self.zip_writer[key] = ZipFile(zip_path, "w")
357372
self.graphs[key].to_zipped_swc(self.zip_writer[key])
358373

@@ -524,12 +539,21 @@ def count_merges(self, key, kdtree):
524539
None
525540
526541
"""
542+
# Initialize zip writer
543+
if self.save_projections:
544+
zip_path = os.path.join(self.projections_dir, key + ".zip")
545+
zip_writer = ZipFile(zip_path, "w")
546+
547+
# Iterate over fragments that intersect with GT skeleton
527548
for label in self.get_node_labels(key):
528549
nodes = self.graphs[key].nodes_with_label(label)
529550
if len(nodes) > 40:
530551
for label in self.label_handler.get_class(label):
531552
if label in self.fragment_ids:
532553
self.is_fragment_merge(key, label, kdtree)
554+
if self.save_projections:
555+
fragment_graph = self.find_graph_from_label(label)[0]
556+
fragment_graph.to_zipped_swc(zip_writer)
533557

534558
def is_fragment_merge(self, key, label, kdtree):
535559
"""
@@ -553,32 +577,33 @@ def is_fragment_merge(self, key, label, kdtree):
553577
None
554578
555579
"""
556-
fragment_graph = self.find_graph_from_label(label)
557-
558-
max_dist = 0
559-
min_dist = np.inf
560-
561-
for voxel in fragment_graph.voxels:
562-
# Find closest point in ground truth
563-
gt_voxel = util.kdtree_query(kdtree, voxel)
564-
565-
# Compute projection distance
566-
dist = self.physical_dist(gt_voxel, voxel)
567-
min_dist = min(dist, min_dist)
568-
max_dist = max(dist, max_dist)
569-
570-
# Check if distances imply merge mistake
571-
if max_dist > 100 and min_dist < 3:
572-
# Log merge mistake
573-
equiv_label = self.label_handler.get(label)
574-
xyz = img_util.to_physical(voxel, self.anisotropy)
575-
self.merge_cnt[key] += 1
576-
self.merged_labels.add((key, equiv_label, tuple(xyz)))
577-
578-
# Save merged fragment (if applicable)
579-
if self.save_merges:
580-
fragment_graph.to_zipped_swc(self.zip_writer[key])
581-
break
580+
# Search graphs
581+
for fragment_graph in self.find_graph_from_label(label):
582+
max_dist = 0
583+
min_dist = np.inf
584+
for voxel in fragment_graph.voxels:
585+
# Find closest point in ground truth
586+
gt_voxel = util.kdtree_query(kdtree, voxel)
587+
588+
# Compute projection distance
589+
dist = self.physical_dist(gt_voxel, voxel)
590+
min_dist = min(dist, min_dist)
591+
max_dist = max(dist, max_dist)
592+
593+
# Check if distances imply merge mistake
594+
if max_dist > 100 and min_dist < 3:
595+
# Log merge mistake
596+
equiv_label = self.label_handler.get(label)
597+
xyz = img_util.to_physical(voxel, self.anisotropy)
598+
self.merge_cnt[key] += 1
599+
self.merged_labels.add((key, equiv_label, tuple(xyz)))
600+
601+
# Save merged fragment (if applicable)
602+
if self.save_merges:
603+
fragment_graph.to_zipped_swc(self.zip_writer[key])
604+
if self.localize_merge:
605+
self.find_merge_site(key, fragment_graph, kdtree)
606+
break
582607

583608
def adjust_metrics(self, key):
584609
"""
@@ -674,6 +699,29 @@ def process_merge(self, key, label, xyz, update_merged_labels=True):
674699
if update_merged_labels:
675700
self.merged_labels.add((key, label, -1))
676701

702+
def find_merge_site(self, key, fragment_graph, kdtree):
703+
visited = set()
704+
hit = False
705+
for i, voxel in enumerate(fragment_graph.voxels):
706+
# Find closest point in ground truth
707+
visited.add(i)
708+
gt_voxel = util.kdtree_query(kdtree, voxel)
709+
710+
# Compute projection distance
711+
if self.physical_dist(gt_voxel, voxel) > 100:
712+
for _, j in nx.dfs_edges(fragment_graph, source=i):
713+
visited.add(j)
714+
voxel_j = fragment_graph.voxels[j]
715+
gt_voxel = util.kdtree_query(kdtree, voxel_j)
716+
if self.physical_dist(gt_voxel, voxel_j) < 2:
717+
hit = True
718+
print("Approximate Site:", img_util.to_physical(voxel_j, self.anisotropy))
719+
break
720+
721+
# Check whether to continue
722+
if hit:
723+
break
724+
677725
def quantify_merges(self):
678726
"""
679727
Computes the percentage of merged edges for each graph.
@@ -966,10 +1014,11 @@ def list_metrics(self):
9661014

9671015
# -- Helpers --
9681016
def find_graph_from_label(self, label):
1017+
graphs = list()
9691018
for key in self.fragment_graphs:
9701019
if label == util.get_segment_id(key):
971-
return self.fragment_graphs[key]
972-
return None
1020+
graphs.append(self.fragment_graphs[key])
1021+
return graphs
9731022

9741023
def physical_dist(self, voxel_1, voxel_2):
9751024
"""

0 commit comments

Comments
 (0)