Skip to content

Commit 78ad541

Browse files
author
anna-grim
committed
refactor: updated saved fragments
1 parent 6394996 commit 78ad541

File tree

4 files changed

+53
-17
lines changed

4 files changed

+53
-17
lines changed

src/segmentation_skeleton_metrics/skeleton_graph.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,18 @@ class SkeletonGraph(nx.Graph):
3939
A 3D array that contains a voxel coordinate for each node.
4040
4141
"""
42-
43-
def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
42+
colors = [
43+
"# COLOR 1.0 0.0 1.0", # pink
44+
"# COLOR 0.0 1.0 1.0", # cyan
45+
"# COLOR 1.0 1.0 0.0", # yellow
46+
"# COLOR 0.0 0.5 1.0", # blue
47+
"# COLOR 1.0 0.5 0.0", # orange
48+
"# COLOR 0.5 0.0 1.0", # purple
49+
"# COLOR 0.0 0.8 0.8", # teal
50+
"# COLOR 0.6 0.0 0.6", # plum
51+
]
52+
53+
def __init__(self, anisotropy=(1.0, 1.0, 1.0), is_groundtruth=False):
4454
"""
4555
Initializes a SkeletonGraph, including setting the anisotropy and
4656
initializing the run length attributes.
@@ -50,6 +60,9 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
5060
anisotropy : ArrayLike, optional
5161
Image to physical coordinates scaling factors to account for the
5262
anisotropy of the microscope. The default is (1.0, 1.0, 1.0).
63+
is_groundtruth : bool, optional
64+
Indication of whether this graph corresponds to a ground truth
65+
tracing. The default is False.
5366
5467
Returns
5568
-------
@@ -62,6 +75,7 @@ def __init__(self, anisotropy=(1.0, 1.0, 1.0)):
6275
# Instance attributes
6376
self.anisotropy = np.array(anisotropy)
6477
self.filename = None
78+
self.is_groundtruth = is_groundtruth
6579
self.labels = None
6680
self.run_length = 0
6781
self.voxels = None
@@ -317,7 +331,7 @@ def upd_labels(self, nodes, label):
317331
for i in nodes:
318332
self.labels[i] = label
319333

320-
def to_zipped_swc(self, zip_writer, color=None):
334+
def to_zipped_swc(self, zip_writer):
321335
"""
322336
Writes the graph to an SWC file format, which is then stored in a ZIP
323337
archive.
@@ -337,12 +351,12 @@ def to_zipped_swc(self, zip_writer, color=None):
337351
"""
338352
with StringIO() as text_buffer:
339353
# Preamble
340-
text_buffer.write("# COLOR " + color) if color else None
341-
text_buffer.write("# id, type, z, y, x, r, pid")
354+
text_buffer.write(self.get_color())
355+
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
342356

343357
# Write entries
344358
node_to_idx = dict()
345-
r = 6 if color else 3
359+
r = 2 if self.is_groundtruth else 3
346360
for i, j in nx.dfs_edges(self):
347361
# Special Case: Root
348362
x, y, z = tuple(self.voxels[i] * self.anisotropy)
@@ -359,3 +373,10 @@ def to_zipped_swc(self, zip_writer, color=None):
359373

360374
# Finish
361375
zip_writer.writestr(self.filename, text_buffer.getvalue())
376+
377+
def get_color(self):
378+
if self.is_groundtruth:
379+
return "# COLOR 1.0 1.0 1.0"
380+
else:
381+
return util.sample_once(SkeletonGraph.colors)
382+

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def load_groundtruth(self, swc_pointer):
174174
print("\n(1) Load Ground Truth")
175175
graph_builder = gutil.GraphBuilder(
176176
anisotropy=self.anisotropy,
177+
is_groundtruth=True,
177178
label_mask=self.label_mask,
178179
use_anisotropy=False,
179180
)
@@ -203,6 +204,7 @@ def load_fragments(self, swc_pointer):
203204
if swc_pointer:
204205
graph_builder = gutil.GraphBuilder(
205206
anisotropy=self.anisotropy,
207+
is_groundtruth=False,
206208
selected_ids=self.get_all_node_labels(),
207209
use_anisotropy=self.use_anisotropy,
208210
)
@@ -464,12 +466,13 @@ def detect_splits(self):
464466
n_missing = n_before - n_after
465467
p_omit = 100 * (n_missing + n_split_edges) / n_before
466468
p_split = 100 * n_split_edges / n_before
469+
gt_rl = graph.run_length
467470

468471
self.graphs[key] = graph
469-
self.metrics.at[key, "% Omit"] = p_omit
472+
self.metrics.at[key, "% Omit"] = round(p_omit, 2)
470473
self.metrics.at[key, "# Splits"] = gutil.count_splits(graph)
471-
self.metrics.loc[key, "% Split"] = p_split
472-
self.metrics.loc[key, "GT Run Length"] = graph.run_length
474+
self.metrics.loc[key, "% Split"] = round(p_split, 2)
475+
self.metrics.loc[key, "GT Run Length"] = round(gt_rl, 2)
473476
pbar.update(1)
474477

475478
# -- Merge Detection --
@@ -571,7 +574,7 @@ def is_fragment_merge(self, key, label, kdtree):
571574
for leaf in gutil.get_leafs(fragment_graph):
572575
voxel = fragment_graph.voxels[leaf]
573576
gt_voxel = util.kdtree_query(kdtree, voxel)
574-
if self.physical_dist(gt_voxel, voxel) > 50:
577+
if self.physical_dist(gt_voxel, voxel) > 60:
575578
visited = self.find_merge_site(
576579
key, kdtree, fragment_graph, leaf, visited
577580
)
@@ -632,10 +635,13 @@ def process_merge_sites(self):
632635

633636
# Save merge sites
634637
if self.save_merges:
638+
row_names = list()
635639
for i in range(len(self.merge_sites)):
636640
filename = f"merge-{i + 1}.swc"
637641
xyz = self.merge_sites.iloc[i]["World"]
638642
swc_util.to_zipped_point(self.merge_writer, filename, xyz)
643+
row_names.append(filename)
644+
self.merge_sites.index = row_names
639645
self.merge_writer.close()
640646

641647
# Update counter
@@ -645,7 +651,7 @@ def process_merge_sites(self):
645651

646652
# Save results
647653
path = os.path.join(self.output_dir, "merge_sites.csv")
648-
self.merge_sites.to_csv(path, index=False)
654+
self.merge_sites.to_csv(path, index=True)
649655

650656

651657
def adjust_metrics(self, key):
@@ -757,7 +763,7 @@ def quantify_merges(self):
757763
"""
758764
for key in self.graphs:
759765
p = self.n_merged_edges[key] / self.graphs[key].graph["n_edges"]
760-
self.metrics.loc[key, "% Merged"] = 100 * p
766+
self.metrics.loc[key, "% Merged"] = round(100 * p, 2)
761767

762768
# -- Compute Metrics --
763769
def compute_edge_accuracy(self):
@@ -776,7 +782,8 @@ def compute_edge_accuracy(self):
776782
for key in self.graphs:
777783
p_omit = self.metrics.loc[key, "% Omit"]
778784
p_merged = self.metrics.loc[key, "% Merged"]
779-
self.metrics.loc[key, "Edge Accuracy"] = 100 - p_omit - p_merged
785+
edge_accuracy = round(100 - p_omit - p_merged, 2)
786+
self.metrics.loc[key, "Edge Accuracy"] = edge_accuracy
780787

781788
def compute_erl(self):
782789
"""
@@ -799,8 +806,9 @@ def compute_erl(self):
799806
wgt = run_lengths / max(np.sum(run_lengths), 1)
800807

801808
erl = np.sum(wgt * run_lengths)
802-
self.metrics.loc[key, "ERL"] = erl
803-
self.metrics.loc[key, "Normalized ERL"] = erl / max(run_length, 1)
809+
n_erl = round(erl / max(run_length, 1), 4)
810+
self.metrics.loc[key, "ERL"] = round(erl, 2)
811+
self.metrics.loc[key, "Normalized ERL"] = n_erl
804812

805813
def compute_weighted_avg(self, column_name):
806814
wgt = self.metrics["GT Run Length"]

src/segmentation_skeleton_metrics/utils/graph_util.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class GraphBuilder:
2727
def __init__(
2828
self,
2929
anisotropy=(1.0, 1.0, 1.0),
30+
is_groundtruth=False,
3031
label_mask=None,
3132
selected_ids=None,
3233
use_anisotropy=True,
@@ -39,6 +40,9 @@ def __init__(
3940
anisotropy : Tuple[int], optional
4041
Image to physical coordinates scaling factors to account for the
4142
anisotropy of the microscope. The default is [1.0, 1.0, 1.0].
43+
is_groundtruth : bool, optional
44+
Indication of whether this graph corresponds to a ground truth
45+
tracing. The default is False.
4246
label_mask : ImageReader, optional
4347
Predicted segmentation mask.
4448
selected_ids : Set[int], optional
@@ -55,6 +59,7 @@ def __init__(
5559
"""
5660
# Instance attributes
5761
self.anisotropy = anisotropy
62+
self.is_groundtruth = is_groundtruth
5863
self.label_mask = label_mask
5964

6065
# Reader
@@ -146,7 +151,9 @@ def to_graph(self, swc_dict):
146151
147152
"""
148153
# Initialize graph
149-
graph = SkeletonGraph(anisotropy=self.anisotropy)
154+
graph = SkeletonGraph(
155+
anisotropy=self.anisotropy, is_groundtruth=self.is_groundtruth
156+
)
150157
graph.init_voxels(swc_dict["voxel"])
151158
graph.set_filename(swc_dict["swc_id"] + ".swc")
152159
graph.set_nodes(len(swc_dict["id"]))

src/segmentation_skeleton_metrics/utils/swc_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def to_zipped_point(zip_writer, filename, xyz):
528528
with StringIO() as text_buffer:
529529
# Preamble
530530
text_buffer.write("# COLOR 1.0 0.0 0.0")
531-
text_buffer.write("# id, type, z, y, x, r, pid")
531+
text_buffer.write("\n" + "# id, type, z, y, x, r, pid")
532532

533533
# Write entry
534534
x, y, z = tuple(xyz)

0 commit comments

Comments
 (0)