Skip to content

Commit 0b1fe31

Browse files
author
anna-grim
committed
feat: adjust merge site
1 parent 78ad541 commit 0b1fe31

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(
136136
self.load_fragments(fragments_pointer)
137137

138138
# Initialize metrics
139-
util.mkdir(output_dir, delete=True)
139+
util.mkdir(output_dir)
140140
self.init_writers()
141141
self.merge_sites = list()
142142

@@ -602,6 +602,10 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
602602
voxel = fragment_graph.voxels[node]
603603
gt_voxel = util.kdtree_query(kdtree, voxel)
604604
if self.physical_dist(gt_voxel, voxel) < 3:
605+
# Local search
606+
node = self.branch_search(fragment_graph, kdtree, node)
607+
voxel = fragment_graph.voxels[node]
608+
605609
# Log merge mistake
606610
segment_id = util.get_segment_id(fragment_graph.filename)
607611
xyz = img_util.to_physical(voxel, self.anisotropy)
@@ -815,6 +819,48 @@ def compute_weighted_avg(self, column_name):
815819
return (self.metrics[column_name] * wgt).sum() / wgt.sum()
816820

817821
# -- Helpers --
822+
def branch_search(self, graph, kdtree, root, radius=70):
823+
"""
824+
Searches for a branching node within distance "radius" from the given
825+
root node.
826+
827+
Parameters
828+
----------
829+
graph : networkx.Graph
830+
Graph to be searched.
831+
kdtree : ...
832+
KDTree containing voxel coordinates from a ground truth tracing.
833+
root : int
834+
Root of search.
835+
radius : float, optional
836+
Distance to search from root. The default is 70.
837+
838+
Returns
839+
-------
840+
int
841+
Root node or closest branching node within distance "radius".
842+
843+
"""
844+
queue = list([(root, 0)])
845+
visited = set({root})
846+
while queue:
847+
# Visit node
848+
i, d_i = queue.pop()
849+
voxel_i = graph.voxels[i]
850+
if graph.degree[i] > 2:
851+
gt_voxel = util.kdtree_query(kdtree, voxel_i)
852+
if self.physical_dist(gt_voxel, voxel_i) < 16:
853+
return i
854+
855+
# Update queue
856+
for j in graph.neighbors(i):
857+
voxel_j = graph.voxels[j]
858+
d_j = d_i + self.physical_dist(voxel_i, voxel_j)
859+
if j not in visited and d_j < radius:
860+
queue.append((j, d_j))
861+
visited.add(j)
862+
return root
863+
818864
def find_graph_from_label(self, label):
819865
graphs = list()
820866
for key in self.fragment_graphs:

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def read_from_gcs_swcs(self, bucket_name, swc_paths):
322322
with ThreadPoolExecutor() as executor:
323323
# Assign threads
324324
threads = list()
325-
for path in swc_paths:
325+
for path in swc_paths[0:16]: # temp
326326
threads.append(
327327
executor.submit(self.read_from_gcs_swc, bucket_name, path)
328328
)
@@ -527,7 +527,7 @@ def to_zipped_point(zip_writer, filename, xyz):
527527
"""
528528
with StringIO() as text_buffer:
529529
# Preamble
530-
text_buffer.write("# COLOR 1.0 0.0 0.0")
530+
text_buffer.write("# COLOR 1.0 0.0 1.0")
531531
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
532532

533533
# Write entry

0 commit comments

Comments
 (0)