Skip to content

Commit c8d7b85

Browse files
anna-grimanna-grim
andauthored
Feat save merge upds (#129)
* refactor: updated saved fragments * feat: adjust merge site * filter via local search * removed testing blocks --------- Co-authored-by: anna-grim <[email protected]>
1 parent 6394996 commit c8d7b85

File tree

4 files changed

+149
-39
lines changed

4 files changed

+149
-39
lines changed

src/segmentation_skeleton_metrics/skeleton_graph.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,18 @@ class SkeletonGraph(nx.Graph):
3939
A 3D array that contains a voxel coordinate for each node.
4040
4141
"""
42-
43-
def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
42+
colors = [
43+
"# COLOR 1.0 0.0 1.0", # pink
44+
"# COLOR 0.0 1.0 1.0", # cyan
45+
"# COLOR 1.0 1.0 0.0", # yellow
46+
"# COLOR 0.0 0.5 1.0", # blue
47+
"# COLOR 1.0 0.5 0.0", # orange
48+
"# COLOR 0.5 0.0 1.0", # purple
49+
"# COLOR 0.0 0.8 0.8", # teal
50+
"# COLOR 0.6 0.0 0.6", # plum
51+
]
52+
53+
def __init__(self, anisotropy=(1.0, 1.0, 1.0), is_groundtruth=False):
4454
"""
4555
Initializes a SkeletonGraph, including setting the anisotropy and
4656
initializing the run length attributes.
@@ -50,6 +60,9 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
5060
anisotropy : ArrayLike, optional
5161
Image to physical coordinates scaling factors to account for the
5262
anisotropy of the microscope. The default is (1.0, 1.0, 1.0).
63+
is_groundtruth : bool, optional
64+
Indication of whether this graph corresponds to a ground truth
65+
tracing. The default is False.
5366
5467
Returns
5568
-------
@@ -62,6 +75,7 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
6275
# Instance attributes
6376
self.anisotropy = np.array(anisotropy)
6477
self.filename = None
78+
self.is_groundtruth = is_groundtruth
6579
self.labels = None
6680
self.run_length = 0
6781
self.voxels = None
@@ -317,7 +331,7 @@ def upd_labels(self, nodes, label):
317331
for i in nodes:
318332
self.labels[i] = label
319333

320-
def to_zipped_swc(self, zip_writer, color=None):
334+
def to_zipped_swc(self, zip_writer):
321335
"""
322336
Writes the graph to an SWC file format, which is then stored in a ZIP
323337
archive.
@@ -337,12 +351,12 @@ def to_zipped_swc(self, zip_writer, color=None):
337351
"""
338352
with StringIO() as text_buffer:
339353
# Preamble
340-
text_buffer.write("# COLOR " + color) if color else None
341-
text_buffer.write("# id, type, z, y, x, r, pid")
354+
text_buffer.write(self.get_color())
355+
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
342356

343357
# Write entries
344358
node_to_idx = dict()
345-
r = 6 if color else 3
359+
r = 2 if self.is_groundtruth else 3
346360
for i, j in nx.dfs_edges(self):
347361
# Special Case: Root
348362
x, y, z = tuple(self.voxels[i] * self.anisotropy)
@@ -359,3 +373,10 @@ def to_zipped_swc(self, zip_writer, color=None):
359373

360374
# Finish
361375
zip_writer.writestr(self.filename, text_buffer.getvalue())
376+
377+
def get_color(self):
378+
if self.is_groundtruth:
379+
return "# COLOR 1.0 1.0 1.0"
380+
else:
381+
return util.sample_once(SkeletonGraph.colors)
382+

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 113 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(
136136
self.load_fragments(fragments_pointer)
137137

138138
# Initialize metrics
139-
util.mkdir(output_dir, delete=True)
139+
util.mkdir(output_dir)
140140
self.init_writers()
141141
self.merge_sites = list()
142142

@@ -174,6 +174,7 @@ def load_groundtruth(self, swc_pointer):
174174
print("\n(1) Load Ground Truth")
175175
graph_builder = gutil.GraphBuilder(
176176
anisotropy=self.anisotropy,
177+
is_groundtruth=True,
177178
label_mask=self.label_mask,
178179
use_anisotropy=False,
179180
)
@@ -203,6 +204,7 @@ def load_fragments(self, swc_pointer):
203204
if swc_pointer:
204205
graph_builder = gutil.GraphBuilder(
205206
anisotropy=self.anisotropy,
207+
is_groundtruth=False,
206208
selected_ids=self.get_all_node_labels(),
207209
use_anisotropy=self.use_anisotropy,
208210
)
@@ -464,12 +466,13 @@ def detect_splits(self):
464466
n_missing = n_before - n_after
465467
p_omit = 100 * (n_missing + n_split_edges) / n_before
466468
p_split = 100 * n_split_edges / n_before
469+
gt_rl = graph.run_length
467470

468471
self.graphs[key] = graph
469-
self.metrics.at[key, "% Omit"] = p_omit
472+
self.metrics.at[key, "% Omit"] = round(p_omit, 2)
470473
self.metrics.at[key, "# Splits"] = gutil.count_splits(graph)
471-
self.metrics.loc[key, "% Split"] = p_split
472-
self.metrics.loc[key, "GT Run Length"] = graph.run_length
474+
self.metrics.loc[key, "% Split"] = round(p_split, 2)
475+
self.metrics.loc[key, "GT Run Length"] = round(gt_rl, 2)
473476
pbar.update(1)
474477

475478
# -- Merge Detection --
@@ -571,8 +574,8 @@ def is_fragment_merge(self, key, label, kdtree):
571574
for leaf in gutil.get_leafs(fragment_graph):
572575
voxel = fragment_graph.voxels[leaf]
573576
gt_voxel = util.kdtree_query(kdtree, voxel)
574-
if self.physical_dist(gt_voxel, voxel) > 50:
575-
visited = self.find_merge_site(
577+
if self.physical_dist(gt_voxel, voxel) > 60:
578+
self.find_merge_site(
576579
key, kdtree, fragment_graph, leaf, visited
577580
)
578581

@@ -599,28 +602,60 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
599602
voxel = fragment_graph.voxels[node]
600603
gt_voxel = util.kdtree_query(kdtree, voxel)
601604
if self.physical_dist(gt_voxel, voxel) < 3:
602-
# Log merge mistake
603-
segment_id = util.get_segment_id(fragment_graph.filename)
604-
xyz = img_util.to_physical(voxel, self.anisotropy)
605-
self.merged_labels.add((key, segment_id, xyz))
606-
self.merge_sites.append(
607-
{
608-
"Segment_ID": segment_id,
609-
"GroundTruth_ID": key,
610-
"Voxel": tuple([int(t) for t in voxel]),
611-
"World": tuple([float(t) for t in xyz]),
612-
}
613-
)
605+
# Local search
606+
node = self.branch_search(fragment_graph, kdtree, node)
607+
voxel = fragment_graph.voxels[node]
614608

615-
# Save merged fragment (if applicable)
616-
if self.save_merges:
617-
gutil.write_graph(fragment_graph, self.merge_writer)
618-
gutil.write_graph(
619-
self.gt_graphs[key], self.merge_writer
620-
)
621-
return visited
622-
return visited
609+
# Log merge mistake
610+
if self.is_valid_merge(fragment_graph, kdtree, node):
611+
filename = fragment_graph.filename
612+
segment_id = util.get_segment_id(filename)
613+
xyz = img_util.to_physical(voxel, self.anisotropy)
614+
self.merged_labels.add((key, segment_id, xyz))
615+
self.merge_sites.append(
616+
{
617+
"Segment_ID": segment_id,
618+
"GroundTruth_ID": key,
619+
"Voxel": tuple([int(t) for t in voxel]),
620+
"World": tuple([float(t) for t in xyz]),
621+
}
622+
)
623623

624+
# Save merged fragment (if applicable)
625+
if self.save_merges:
626+
gutil.write_graph(
627+
fragment_graph, self.merge_writer
628+
)
629+
gutil.write_graph(
630+
self.gt_graphs[key], self.merge_writer
631+
)
632+
return
633+
634+
def is_valid_merge(self, graph, kdtree, root):
635+
n_hits = 0
636+
queue = list([(root, 0)])
637+
visited = set({root})
638+
while queue:
639+
# Visit node
640+
i, d_i = queue.pop()
641+
voxel_i = graph.voxels[i]
642+
gt_voxel = util.kdtree_query(kdtree, voxel_i)
643+
if self.physical_dist(gt_voxel, voxel_i) < 5:
644+
n_hits += 1
645+
646+
# Check whether to break
647+
if n_hits > 16:
648+
break
649+
650+
# Update queue
651+
for j in graph.neighbors(i):
652+
voxel_j = graph.voxels[j]
653+
d_j = d_i + self.physical_dist(voxel_i, voxel_j)
654+
if j not in visited and d_j < 30:
655+
queue.append((j, d_j))
656+
visited.add(j)
657+
return True if n_hits > 16 else False
658+
624659
def process_merge_sites(self):
625660
if self.merge_sites:
626661
# Remove duplicates
@@ -632,10 +667,13 @@ def process_merge_sites(self):
632667

633668
# Save merge sites
634669
if self.save_merges:
670+
row_names = list()
635671
for i in range(len(self.merge_sites)):
636672
filename = f"merge-{i + 1}.swc"
637673
xyz = self.merge_sites.iloc[i]["World"]
638674
swc_util.to_zipped_point(self.merge_writer, filename, xyz)
675+
row_names.append(filename)
676+
self.merge_sites.index = row_names
639677
self.merge_writer.close()
640678

641679
# Update counter
@@ -645,7 +683,7 @@ def process_merge_sites(self):
645683

646684
# Save results
647685
path = os.path.join(self.output_dir, "merge_sites.csv")
648-
self.merge_sites.to_csv(path, index=False)
686+
self.merge_sites.to_csv(path, index=True)
649687

650688

651689
def adjust_metrics(self, key):
@@ -757,7 +795,7 @@ def quantify_merges(self):
757795
"""
758796
for key in self.graphs:
759797
p = self.n_merged_edges[key] / self.graphs[key].graph["n_edges"]
760-
self.metrics.loc[key, "% Merged"] = 100 * p
798+
self.metrics.loc[key, "% Merged"] = round(100 * p, 2)
761799

762800
# -- Compute Metrics --
763801
def compute_edge_accuracy(self):
@@ -776,7 +814,8 @@ def compute_edge_accuracy(self):
776814
for key in self.graphs:
777815
p_omit = self.metrics.loc[key, "% Omit"]
778816
p_merged = self.metrics.loc[key, "% Merged"]
779-
self.metrics.loc[key, "Edge Accuracy"] = 100 - p_omit - p_merged
817+
edge_accuracy = round(100 - p_omit - p_merged, 2)
818+
self.metrics.loc[key, "Edge Accuracy"] = edge_accuracy
780819

781820
def compute_erl(self):
782821
"""
@@ -799,14 +838,57 @@ def compute_erl(self):
799838
wgt = run_lengths / max(np.sum(run_lengths), 1)
800839

801840
erl = np.sum(wgt * run_lengths)
802-
self.metrics.loc[key, "ERL"] = erl
803-
self.metrics.loc[key, "Normalized ERL"] = erl / max(run_length, 1)
841+
n_erl = round(erl / max(run_length, 1), 4)
842+
self.metrics.loc[key, "ERL"] = round(erl, 2)
843+
self.metrics.loc[key, "Normalized ERL"] = n_erl
804844

805845
def compute_weighted_avg(self, column_name):
806846
wgt = self.metrics["GT Run Length"]
807847
return (self.metrics[column_name] * wgt).sum() / wgt.sum()
808848

809849
# -- Helpers --
850+
def branch_search(self, graph, kdtree, root, radius=70):
851+
"""
852+
Searches for a branching node within distance "radius" from the given
853+
root node.
854+
855+
Parameters
856+
----------
857+
graph : networkx.Graph
858+
Graph to be searched.
859+
kdtree : ...
860+
KDTree containing voxel coordinates from a ground truth tracing.
861+
root : int
862+
Root of search.
863+
radius : float, optional
864+
Distance to search from root. The default is 70.
865+
866+
Returns
867+
-------
868+
int
869+
Root node or closest branching node within distance "radius".
870+
871+
"""
872+
queue = list([(root, 0)])
873+
visited = set({root})
874+
while queue:
875+
# Visit node
876+
i, d_i = queue.pop()
877+
voxel_i = graph.voxels[i]
878+
if graph.degree[i] > 2:
879+
gt_voxel = util.kdtree_query(kdtree, voxel_i)
880+
if self.physical_dist(gt_voxel, voxel_i) < 16:
881+
return i
882+
883+
# Update queue
884+
for j in graph.neighbors(i):
885+
voxel_j = graph.voxels[j]
886+
d_j = d_i + self.physical_dist(voxel_i, voxel_j)
887+
if j not in visited and d_j < radius:
888+
queue.append((j, d_j))
889+
visited.add(j)
890+
return root
891+
810892
def find_graph_from_label(self, label):
811893
graphs = list()
812894
for key in self.fragment_graphs:

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class GraphBuilder:
2727
def __init__(
2828
self,
2929
anisotropy=(1.0, 1.0, 1.0),
30+
is_groundtruth=False,
3031
label_mask=None,
3132
selected_ids=None,
3233
use_anisotropy=True,
@@ -39,6 +40,9 @@ def __init__(
3940
anisotropy : Tuple[int], optional
4041
Image to physical coordinates scaling factors to account for the
4142
anisotropy of the microscope. The default is [1.0, 1.0, 1.0].
43+
is_groundtruth : bool, optional
44+
Indication of whether this graph corresponds to a ground truth
45+
tracing. The default is False.
4246
label_mask : ImageReader, optional
4347
Predicted segmentation mask.
4448
selected_ids : Set[int], optional
@@ -55,6 +59,7 @@ def __init__(
5559
"""
5660
# Instance attributes
5761
self.anisotropy = anisotropy
62+
self.is_groundtruth = is_groundtruth
5863
self.label_mask = label_mask
5964

6065
# Reader
@@ -146,7 +151,9 @@ def to_graph(self, swc_dict):
146151
147152
"""
148153
# Initialize graph
149-
graph = SkeletonGraph(anisotropy=self.anisotropy)
154+
graph = SkeletonGraph(
155+
anisotropy=self.anisotropy, is_groundtruth=self.is_groundtruth
156+
)
150157
graph.init_voxels(swc_dict["voxel"])
151158
graph.set_filename(swc_dict["swc_id"] + ".swc")
152159
graph.set_nodes(len(swc_dict["id"]))

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def to_zipped_point(zip_writer, filename, xyz):
528528
with StringIO() as text_buffer:
529529
# Preamble
530530
text_buffer.write("# COLOR 1.0 0.0 0.0")
531-
text_buffer.write("# id, type, z, y, x, r, pid")
531+
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
532532

533533
# Write entry
534534
x, y, z = tuple(xyz)

0 commit comments

Comments
 (0)