Skip to content

Commit ca87651

Browse files
authored
Merge branch 'main' into run_updates
2 parents a8bb53e + e21e29c commit ca87651

File tree

5 files changed

+151
-42
lines changed

5 files changed

+151
-42
lines changed

src/segmentation_skeleton_metrics/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
Package to evaluate a predicted segmentation.
33
"""
44

5-
__version__ = "4.16.31"
5+
__version__ = "5.0.3"

src/segmentation_skeleton_metrics/skeleton_graph.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,18 @@ class SkeletonGraph(nx.Graph):
3939
A 3D array that contains a voxel coordinate for each node.
4040
4141
"""
42-
43-
def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
42+
colors = [
43+
"# COLOR 1.0 0.0 1.0", # pink
44+
"# COLOR 0.0 1.0 1.0", # cyan
45+
"# COLOR 1.0 1.0 0.0", # yellow
46+
"# COLOR 0.0 0.5 1.0", # blue
47+
"# COLOR 1.0 0.5 0.0", # orange
48+
"# COLOR 0.5 0.0 1.0", # purple
49+
"# COLOR 0.0 0.8 0.8", # teal
50+
"# COLOR 0.6 0.0 0.6", # plum
51+
]
52+
53+
def __init__(self, anisotropy=(1.0, 1.0, 1.0), is_groundtruth=False):
4454
"""
4555
Initializes a SkeletonGraph, including setting the anisotropy and
4656
initializing the run length attributes.
@@ -50,6 +60,9 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
5060
anisotropy : ArrayLike, optional
5161
Image to physical coordinates scaling factors to account for the
5262
anisotropy of the microscope. The default is (1.0, 1.0, 1.0).
63+
is_groundtruth : bool, optional
64+
Indication of whether this graph corresponds to a ground truth
65+
tracing. The default is False.
5366
5467
Returns
5568
-------
@@ -62,6 +75,7 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
6275
# Instance attributes
6376
self.anisotropy = np.array(anisotropy)
6477
self.filename = None
78+
self.is_groundtruth = is_groundtruth
6579
self.labels = None
6680
self.run_length = 0
6781
self.voxels = None
@@ -317,7 +331,7 @@ def upd_labels(self, nodes, label):
317331
for i in nodes:
318332
self.labels[i] = label
319333

320-
def to_zipped_swc(self, zip_writer, color=None):
334+
def to_zipped_swc(self, zip_writer):
321335
"""
322336
Writes the graph to an SWC file format, which is then stored in a ZIP
323337
archive.
@@ -337,12 +351,12 @@ def to_zipped_swc(self, zip_writer, color=None):
337351
"""
338352
with StringIO() as text_buffer:
339353
# Preamble
340-
text_buffer.write("# COLOR " + color) if color else None
341-
text_buffer.write("# id, type, z, y, x, r, pid")
354+
text_buffer.write(self.get_color())
355+
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
342356

343357
# Write entries
344358
node_to_idx = dict()
345-
r = 6 if color else 3
359+
r = 2 if self.is_groundtruth else 3
346360
for i, j in nx.dfs_edges(self):
347361
# Special Case: Root
348362
x, y, z = tuple(self.voxels[i] * self.anisotropy)
@@ -359,3 +373,10 @@ def to_zipped_swc(self, zip_writer, color=None):
359373

360374
# Finish
361375
zip_writer.writestr(self.filename, text_buffer.getvalue())
376+
377+
def get_color(self):
378+
if self.is_groundtruth:
379+
return "# COLOR 1.0 1.0 1.0"
380+
else:
381+
return util.sample_once(SkeletonGraph.colors)
382+

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 113 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def load_groundtruth(self, swc_pointer):
174174
print("\n(1) Load Ground Truth")
175175
graph_builder = gutil.GraphBuilder(
176176
anisotropy=self.anisotropy,
177+
is_groundtruth=True,
177178
label_mask=self.label_mask,
178179
use_anisotropy=False,
179180
)
@@ -203,6 +204,7 @@ def load_fragments(self, swc_pointer):
203204
if swc_pointer:
204205
graph_builder = gutil.GraphBuilder(
205206
anisotropy=self.anisotropy,
207+
is_groundtruth=False,
206208
selected_ids=self.get_all_node_labels(),
207209
use_anisotropy=self.use_anisotropy,
208210
)
@@ -464,12 +466,13 @@ def detect_splits(self):
464466
n_missing = n_before - n_after
465467
p_omit = 100 * (n_missing + n_split_edges) / n_before
466468
p_split = 100 * n_split_edges / n_before
469+
gt_rl = graph.run_length
467470

468471
self.graphs[key] = graph
469-
self.metrics.at[key, "% Omit"] = p_omit
472+
self.metrics.at[key, "% Omit"] = round(p_omit, 2)
470473
self.metrics.at[key, "# Splits"] = gutil.count_splits(graph)
471-
self.metrics.loc[key, "% Split"] = p_split
472-
self.metrics.loc[key, "GT Run Length"] = graph.run_length
474+
self.metrics.loc[key, "% Split"] = round(p_split, 2)
475+
self.metrics.loc[key, "GT Run Length"] = round(gt_rl, 2)
473476
pbar.update(1)
474477

475478
# -- Merge Detection --
@@ -571,8 +574,8 @@ def is_fragment_merge(self, key, label, kdtree):
571574
for leaf in gutil.get_leafs(fragment_graph):
572575
voxel = fragment_graph.voxels[leaf]
573576
gt_voxel = util.kdtree_query(kdtree, voxel)
574-
if self.physical_dist(gt_voxel, voxel) > 50:
575-
visited = self.find_merge_site(
577+
if self.physical_dist(gt_voxel, voxel) > 60:
578+
self.find_merge_site(
576579
key, kdtree, fragment_graph, leaf, visited
577580
)
578581

@@ -589,7 +592,7 @@ def is_fragment_merge(self, key, label, kdtree):
589592
else:
590593
segment_id = util.get_segment_id(fragment_graph.filename)
591594
self.merged_labels.add((key, segment_id, -1))
592-
print(f"Skipping {fragment_graph.filename} - run_length={fragment_graph.run_length}")
595+
print(f"Skipping {segment_id} - run_length={fragment_graph.run_length}")
593596

594597
def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
595598
for _, node in nx.dfs_edges(fragment_graph, source=source):
@@ -599,28 +602,60 @@ def find_merge_site(self, key, kdtree, fragment_graph, source, visited):
599602
voxel = fragment_graph.voxels[node]
600603
gt_voxel = util.kdtree_query(kdtree, voxel)
601604
if self.physical_dist(gt_voxel, voxel) < 3:
602-
# Log merge mistake
603-
segment_id = util.get_segment_id(fragment_graph.filename)
604-
xyz = img_util.to_physical(voxel, self.anisotropy)
605-
self.merged_labels.add((key, segment_id, xyz))
606-
self.merge_sites.append(
607-
{
608-
"Segment_ID": segment_id,
609-
"GroundTruth_ID": key,
610-
"Voxel": tuple([int(t) for t in voxel]),
611-
"World": tuple([float(t) for t in xyz]),
612-
}
613-
)
605+
# Local search
606+
node = self.branch_search(fragment_graph, kdtree, node)
607+
voxel = fragment_graph.voxels[node]
614608

615-
# Save merged fragment (if applicable)
616-
if self.save_merges:
617-
gutil.write_graph(fragment_graph, self.merge_writer)
618-
gutil.write_graph(
619-
self.gt_graphs[key], self.merge_writer
620-
)
621-
return visited
622-
return visited
609+
# Log merge mistake
610+
if self.is_valid_merge(fragment_graph, kdtree, node):
611+
filename = fragment_graph.filename
612+
segment_id = util.get_segment_id(filename)
613+
xyz = img_util.to_physical(voxel, self.anisotropy)
614+
self.merged_labels.add((key, segment_id, xyz))
615+
self.merge_sites.append(
616+
{
617+
"Segment_ID": segment_id,
618+
"GroundTruth_ID": key,
619+
"Voxel": tuple([int(t) for t in voxel]),
620+
"World": tuple([float(t) for t in xyz]),
621+
}
622+
)
623623

624+
# Save merged fragment (if applicable)
625+
if self.save_merges:
626+
gutil.write_graph(
627+
fragment_graph, self.merge_writer
628+
)
629+
gutil.write_graph(
630+
self.gt_graphs[key], self.merge_writer
631+
)
632+
return
633+
634+
def is_valid_merge(self, graph, kdtree, root):
635+
n_hits = 0
636+
queue = list([(root, 0)])
637+
visited = set({root})
638+
while queue:
639+
# Visit node
640+
i, d_i = queue.pop()
641+
voxel_i = graph.voxels[i]
642+
gt_voxel = util.kdtree_query(kdtree, voxel_i)
643+
if self.physical_dist(gt_voxel, voxel_i) < 5:
644+
n_hits += 1
645+
646+
# Check whether to break
647+
if n_hits > 16:
648+
break
649+
650+
# Update queue
651+
for j in graph.neighbors(i):
652+
voxel_j = graph.voxels[j]
653+
d_j = d_i + self.physical_dist(voxel_i, voxel_j)
654+
if j not in visited and d_j < 30:
655+
queue.append((j, d_j))
656+
visited.add(j)
657+
return True if n_hits > 16 else False
658+
624659
def process_merge_sites(self):
625660
if self.merge_sites:
626661
# Remove duplicates
@@ -632,10 +667,13 @@ def process_merge_sites(self):
632667

633668
# Save merge sites
634669
if self.save_merges:
670+
row_names = list()
635671
for i in range(len(self.merge_sites)):
636672
filename = f"merge-{i + 1}.swc"
637673
xyz = self.merge_sites.iloc[i]["World"]
638674
swc_util.to_zipped_point(self.merge_writer, filename, xyz)
675+
row_names.append(filename)
676+
self.merge_sites.index = row_names
639677
self.merge_writer.close()
640678

641679
# Update counter
@@ -645,8 +683,7 @@ def process_merge_sites(self):
645683

646684
# Save results
647685
path = os.path.join(self.output_dir, "merge_sites.csv")
648-
self.merge_sites.to_csv(path, index=False)
649-
686+
self.merge_sites.to_csv(path, index=True)
650687

651688
def adjust_metrics(self, key):
652689
"""
@@ -761,7 +798,7 @@ def quantify_merges(self):
761798
for key in self.graphs:
762799
n_edges = max(self.graphs[key].graph["n_edges"], 1)
763800
p = self.n_merged_edges[key] / n_edges
764-
self.metrics.loc[key, "% Merged"] = 100 * p
801+
self.metrics.loc[key, "% Merged"] = round(100 * p, 2)
765802

766803
# -- Compute Metrics --
767804
def compute_edge_accuracy(self):
@@ -780,7 +817,8 @@ def compute_edge_accuracy(self):
780817
for key in self.graphs:
781818
p_omit = self.metrics.loc[key, "% Omit"]
782819
p_merged = self.metrics.loc[key, "% Merged"]
783-
self.metrics.loc[key, "Edge Accuracy"] = 100 - p_omit - p_merged
820+
edge_accuracy = round(100 - p_omit - p_merged, 2)
821+
self.metrics.loc[key, "Edge Accuracy"] = edge_accuracy
784822

785823
def compute_erl(self):
786824
"""
@@ -803,14 +841,57 @@ def compute_erl(self):
803841
wgt = run_lengths / max(np.sum(run_lengths), 1)
804842

805843
erl = np.sum(wgt * run_lengths)
806-
self.metrics.loc[key, "ERL"] = erl
807-
self.metrics.loc[key, "Normalized ERL"] = erl / max(run_length, 1)
844+
n_erl = round(erl / max(run_length, 1), 4)
845+
self.metrics.loc[key, "ERL"] = round(erl, 2)
846+
self.metrics.loc[key, "Normalized ERL"] = n_erl
808847

809848
def compute_weighted_avg(self, column_name):
810849
wgt = self.metrics["GT Run Length"]
811850
return (self.metrics[column_name] * wgt).sum() / wgt.sum()
812851

813852
# -- Helpers --
853+
def branch_search(self, graph, kdtree, root, radius=70):
854+
"""
855+
Searches for a branching node within distance "radius" from the given
856+
root node.
857+
858+
Parameters
859+
----------
860+
graph : networkx.Graph
861+
Graph to be searched.
862+
kdtree : ...
863+
KDTree containing voxel coordinates from a ground truth tracing.
864+
root : int
865+
Root of search.
866+
radius : float, optional
867+
Distance to search from root. The default is 70.
868+
869+
Returns
870+
-------
871+
int
872+
Root node or closest branching node within distance "radius".
873+
874+
"""
875+
queue = list([(root, 0)])
876+
visited = set({root})
877+
while queue:
878+
# Visit node
879+
i, d_i = queue.pop()
880+
voxel_i = graph.voxels[i]
881+
if graph.degree[i] > 2:
882+
gt_voxel = util.kdtree_query(kdtree, voxel_i)
883+
if self.physical_dist(gt_voxel, voxel_i) < 16:
884+
return i
885+
886+
# Update queue
887+
for j in graph.neighbors(i):
888+
voxel_j = graph.voxels[j]
889+
d_j = d_i + self.physical_dist(voxel_i, voxel_j)
890+
if j not in visited and d_j < radius:
891+
queue.append((j, d_j))
892+
visited.add(j)
893+
return root
894+
814895
def find_graph_from_label(self, label):
815896
graphs = list()
816897
for key in self.fragment_graphs:

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class GraphBuilder:
2727
def __init__(
2828
self,
2929
anisotropy=(1.0, 1.0, 1.0),
30+
is_groundtruth=False,
3031
label_mask=None,
3132
selected_ids=None,
3233
use_anisotropy=True,
@@ -39,6 +40,9 @@ def __init__(
3940
anisotropy : Tuple[int], optional
4041
Image to physical coordinates scaling factors to account for the
4142
anisotropy of the microscope. The default is [1.0, 1.0, 1.0].
43+
is_groundtruth : bool, optional
44+
Indication of whether this graph corresponds to a ground truth
45+
tracing. The default is False.
4246
label_mask : ImageReader, optional
4347
Predicted segmentation mask.
4448
selected_ids : Set[int], optional
@@ -55,6 +59,7 @@ def __init__(
5559
"""
5660
# Instance attributes
5761
self.anisotropy = anisotropy
62+
self.is_groundtruth = is_groundtruth
5863
self.label_mask = label_mask
5964

6065
# Reader
@@ -146,7 +151,9 @@ def to_graph(self, swc_dict):
146151
147152
"""
148153
# Initialize graph
149-
graph = SkeletonGraph(anisotropy=self.anisotropy)
154+
graph = SkeletonGraph(
155+
anisotropy=self.anisotropy, is_groundtruth=self.is_groundtruth
156+
)
150157
graph.init_voxels(swc_dict["voxel"])
151158
graph.set_filename(swc_dict["swc_id"] + ".swc")
152159
graph.set_nodes(len(swc_dict["id"]))

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -527,8 +527,8 @@ def to_zipped_point(zip_writer, filename, xyz):
527527
"""
528528
with StringIO() as text_buffer:
529529
# Preamble
530-
text_buffer.write("# COLOR 1.0 0.0 1.0")
531-
text_buffer.write("# id, type, z, y, x, r, pid")
530+
text_buffer.write("# COLOR 1.0 0.0 0.0")
531+
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
532532

533533
# Write entry
534534
x, y, z = tuple(xyz)

0 commit comments

Comments
 (0)