Skip to content

Commit a017dcd

Browse files
author
anna-grim
committed
refactor: optimized merge detection
1 parent 415ce5d commit a017dcd

File tree

3 files changed

+88
-112
lines changed

3 files changed

+88
-112
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 60 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ def __init__(
6666
anisotropy=(1.0, 1.0, 1.0),
6767
connections_path=None,
6868
fragments_pointer=None,
69-
localize_merges=False,
7069
preexisting_merges=None,
7170
save_merges=False,
7271
save_fragments=False,
@@ -97,9 +96,6 @@ def __init__(
9796
"swc_util.Reader" for documentation. Notes: (1) "anisotropy" is
9897
applied to these SWC files and (2) these SWC files are required
9998
for counting merges. The default is None.
100-
localize_merges : bool, optional
101-
Indication of whether to search for the approximate location of a
102-
merge. The default is False.
10399
preexisting_merges : List[int], optional
104100
List of segment IDs that are known to contain a merge mistake. The
105101
default is None.
@@ -122,7 +118,7 @@ def __init__(
122118
# Instance attributes
123119
self.anisotropy = anisotropy
124120
self.connections_path = connections_path
125-
self.localize_merges = localize_merges
121+
self.merge_sites = list()
126122
self.output_dir = output_dir
127123
self.preexisting_merges = preexisting_merges
128124
self.save_merges = save_merges
@@ -373,10 +369,9 @@ def init_writers(self):
373369
self.graphs[key].to_zipped_swc(self.fragment_writer[key])
374370

375371
# Merged fragments writer
376-
if self.save_merges or self.localize_merges:
372+
if self.save_merges:
377373
zip_path = os.path.join(self.output_dir, "merged_fragments.zip")
378374
self.merge_writer = ZipFile(zip_path, "a")
379-
self.merge_sites = list()
380375

381376
# -- Main Routine --
382377
def run(self):
@@ -410,12 +405,10 @@ def run(self):
410405
path = f"{self.output_dir}/{prefix}results.xls"
411406
util.save_results(path, full_results)
412407

413-
# Save merge sites (if applicable)
414-
if self.localize_merges:
415-
df = pd.DataFrame(self.merge_sites)
416-
df.to_csv(
417-
os.path.join(self.output_dir, "merge_sites.csv"), index=False
418-
)
408+
# Save merge sites
409+
df = pd.DataFrame(self.merge_sites)
410+
path = os.path.join(self.output_dir, "merge_sites.csv")
411+
df.to_csv(path, index=False)
419412

420413
# Report results overview
421414
path = os.path.join(self.output_dir, f"{prefix}results-overview.txt")
@@ -554,7 +547,7 @@ def count_merges(self, key, kdtree):
554547
555548
"""
556549
# Iterate over fragments that intersect with GT skeleton
557-
for label in self.get_node_labels(key):
550+
for label in tqdm(self.get_node_labels(key), desc="Merge Search"):
558551
nodes = self.graphs[key].nodes_with_label(label)
559552
if len(nodes) > 40:
560553
for label in self.label_handler.get_class(label):
@@ -583,46 +576,64 @@ def is_fragment_merge(self, key, label, kdtree):
583576
None
584577
585578
"""
586-
# Search graphs
587579
for fragment_graph in self.find_graph_from_label(label):
588-
# Search for merge
589-
max_dist = 0
590-
min_dist = np.inf
591-
for voxel in fragment_graph.voxels:
580+
if fragment_graph.run_length < 10**6:
581+
# Search for merge
582+
visited = set()
583+
for leaf in gutil.get_leafs(fragment_graph):
584+
# Check if leaf is far from ground truth
585+
voxel = fragment_graph.voxels[leaf]
586+
gt_voxel = util.kdtree_query(kdtree, voxel)
587+
if self.physical_dist(gt_voxel, voxel) > 50:
588+
has_merge, visited = self.find_merge_site(
589+
key, kdtree, fragment_graph, leaf, visited
590+
)
591+
if has_merge:
592+
break
593+
594+
# Save fragment (if applicable)
595+
if self.save_fragments:
596+
for node in fragment_graph.nodes:
597+
voxel = fragment_graph.voxels[node]
598+
gt_voxel = util.kdtree_query(kdtree, voxel)
599+
if self.physical_dist(gt_voxel, voxel) < 3:
600+
write_graph(fragment_graph, self.fragment_writer[key])
601+
break
602+
603+
def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
604+
for _, node in nx.dfs_edges(fragment_graph, source=source):
605+
if node not in visited:
592606
# Find closest point in ground truth
607+
visited.add(node)
608+
voxel = fragment_graph.voxels[node]
593609
gt_voxel = util.kdtree_query(kdtree, voxel)
594-
595-
# Compute projection distance
596-
dist = self.physical_dist(gt_voxel, voxel)
597-
min_dist = min(dist, min_dist)
598-
max_dist = max(dist, max_dist)
599-
600-
# Check if distances imply merge mistake
601-
if max_dist > 100 and min_dist < 3:
610+
if self.physical_dist(gt_voxel, voxel) < 2:
602611
# Log merge mistake
603-
equiv_label = self.label_handler.get(label)
612+
segment_id = util.get_segment_id(fragment_graph.filename)
604613
xyz = img_util.to_physical(voxel, self.anisotropy)
605614
self.merge_cnt[key] += 1
606-
self.merged_labels.add((key, equiv_label, tuple(xyz)))
615+
self.merged_labels.add((key, segment_id, xyz))
616+
self.merge_sites.append(
617+
{
618+
"Segment_ID": segment_id,
619+
"Voxel": voxel,
620+
"XYZ": xyz,
621+
}
622+
)
607623

608624
# Save merged fragment (if applicable)
609625
if self.save_merges:
610-
fragment_graph.to_zipped_swc(self.merge_writer)
611-
if f"{key}.swc" not in self.merge_writer.namelist():
612-
self.gt_graphs[key].to_zipped_swc(
613-
self.merge_writer
614-
)
626+
# Save graphs
627+
write_graph(self.gt_graphs[key], self.merge_writer)
628+
write_graph(fragment_graph, self.merge_writer)
615629

616-
# Find approximate merge site
617-
if self.localize_merges:
618-
self.find_merge_site(key, fragment_graph, kdtree)
619-
620-
break
621-
622-
# Save fragment (if applicable)
623-
if self.save_fragments and min_dist < 3:
624-
if fragment_graph.filename not in self.merge_writer.namelist():
625-
fragment_graph.to_zipped_swc(self.fragment_writer[key])
630+
# Save merge site
631+
merge_cnt = np.sum(list(self.merge_cnt.values()))
632+
swc_util.to_zipped_point(
633+
self.merge_writer, f"merge-{merge_cnt}.swc", xyz
634+
)
635+
return True, visited
636+
return False, visited
626637

627638
def adjust_metrics(self, key):
628639
"""
@@ -718,45 +729,6 @@ def process_merge(self, key, label, xyz, update_merged_labels=True):
718729
if update_merged_labels:
719730
self.merged_labels.add((key, label, -1))
720731

721-
def find_merge_site(self, key, fragment_graph, kdtree):
722-
visited = set()
723-
hit = False
724-
for i, voxel in enumerate(fragment_graph.voxels):
725-
# Find closest point in ground truth
726-
visited.add(i)
727-
gt_voxel = util.kdtree_query(kdtree, voxel)
728-
729-
# Compute projection distance
730-
if self.physical_dist(gt_voxel, voxel) > 100:
731-
for _, j in nx.dfs_edges(fragment_graph, source=i):
732-
visited.add(j)
733-
voxel_j = fragment_graph.voxels[j]
734-
gt_voxel = util.kdtree_query(kdtree, voxel_j)
735-
if self.physical_dist(gt_voxel, voxel_j) < 2:
736-
# Save merge swc
737-
hit = True
738-
merge_cnt = np.sum(list(self.merge_cnt.values()))
739-
filename = f"merge-{merge_cnt}.swc"
740-
xyz = img_util.to_physical(voxel_j, self.anisotropy)
741-
swc_util.to_zipped_point(
742-
self.merge_writer, filename, xyz
743-
)
744-
745-
# Save merge in list
746-
segment_id = util.get_segment_id(fragment_graph.filename)
747-
self.merge_sites.append(
748-
{
749-
"Segment_ID": segment_id,
750-
"Voxel": voxel_j,
751-
"XYZ": xyz,
752-
}
753-
)
754-
break
755-
756-
# Check whether to continue
757-
if hit:
758-
break
759-
760732
def quantify_merges(self):
761733
"""
762734
Computes the percentage of merged edges for each graph.
@@ -775,30 +747,6 @@ def quantify_merges(self):
775747
n_edges = self.graphs[key].graph["n_edges"]
776748
self.merged_percent[key] = self.merged_edges_cnt[key] / n_edges
777749

778-
def save_merged_labels(self):
779-
"""
780-
Saves merged labels and their corresponding coordinates to a text
781-
file.
782-
783-
Parameters
784-
----------
785-
None
786-
787-
Returns
788-
-------
789-
None
790-
791-
"""
792-
# Save detected merges
793-
prefix = "corrected_" if self.connections_path else ""
794-
filename = f"merged_{prefix}segment_ids.txt"
795-
with open(os.path.join(self.output_dir, filename), "w") as f:
796-
f.write(f" Label - Physical Coordinate\n")
797-
for _, label, xyz in self.merged_labels:
798-
if self.label_handler.use_mapping():
799-
label = self.get_merged_label(label)
800-
f.write(f" {label} - {xyz}\n")
801-
802750
def get_merged_label(self, label):
803751
"""
804752
Retrieves the label present in the corrected fragments that
@@ -1127,3 +1075,8 @@ def generate_result(keys, stats):
11271075
11281076
"""
11291077
return [stats[key] for key in keys]
1078+
1079+
1080+
def write_graph(graph, writer):
    """
    Writes the given graph to a zipped SWC archive, skipping graphs whose
    filename is already present in the archive.

    Parameters
    ----------
    graph : networkx.Graph
        Graph to be written; must expose "filename" and "to_zipped_swc".
    writer : zipfile.ZipFile
        Open zip archive that the graph is written into.

    Returns
    -------
    None

    """
    # Avoid duplicate entries in the archive
    already_written = graph.filename in writer.namelist()
    if not already_written:
        graph.to_zipped_swc(writer)

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,21 @@ def count_splits(graph):
290290
291291
"""
292292
return max(nx.number_connected_components(graph) - 1, 0)
293+
294+
295+
def get_leafs(graph):
    """
    Gets all leaf nodes (i.e. nodes with degree 1) in the given graph.

    Parameters
    ----------
    graph : networkx.Graph
        Graph to be searched.

    Returns
    -------
    List[int]
        Leaf nodes in the given graph.

    """
    leafs = list()
    for node in graph.nodes:
        if graph.degree[node] == 1:
            leafs.append(node)
    return leafs

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from zipfile import ZipFile
1818

1919
import os
20+
import pandas as pd
2021
import shutil
2122

2223

@@ -247,6 +248,13 @@ def list_gcs_subdirectories(bucket_name, prefix):
247248
return subdirs
248249

249250

251+
def read_txt_from_gcs(bucket_name, file_name):
    """
    Reads a text file stored in a GCS bucket.

    Parameters
    ----------
    bucket_name : str
        Name of the GCS bucket containing the file.
    file_name : str
        Path of the file within the bucket.

    Returns
    -------
    str
        Contents of the file decoded as text.

    """
    # NOTE(review): assumes "storage" (google.cloud) is imported at module
    # level, matching the other GCS helpers in this file.
    return storage.Client().bucket(bucket_name).blob(file_name).download_as_text()
256+
257+
250258
def upload_directory_to_gcs(bucket_name, source_dir, destination_dir):
251259
client = storage.Client()
252260
bucket = client.bucket(bucket_name)
@@ -299,11 +307,8 @@ def load_merged_labels(path):
299307
Segment IDs that are known to contain a merge mistake.
300308
301309
"""
302-
merged_ids = list()
303-
for i, txt in enumerate(read_txt(path)):
304-
if i > 0:
305-
merged_ids.append(int(txt.split("-")[0]))
306-
return merged_ids
310+
df = pd.read_csv(path)
311+
return list(df["Segment_ID"])
307312

308313

309314
def load_valid_labels(path):

0 commit comments

Comments
 (0)