Skip to content

Commit 2b6d058

Browse files
anna-grimanna-grim
andauthored
Feat report summary (#158)
* doc: skeleton metrics and swc_loading * feat: fix label misalignments --------- Co-authored-by: anna-grim <[email protected]>
1 parent 379d24a commit 2b6d058

File tree

4 files changed

+81
-11
lines changed

4 files changed

+81
-11
lines changed

src/segmentation_skeleton_metrics/data_handling/graph_loading.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
1010
"""
1111

12+
from collections import deque
1213
from concurrent.futures import (
1314
as_completed, ProcessPoolExecutor, ThreadPoolExecutor
1415
)
@@ -346,6 +347,8 @@ def _label_graph(self, graph):
346347
for i, label in node_to_label.items():
347348
graph.node_labels[i] = label
348349

350+
GraphLoader.fix_label_misalignments(graph)
351+
349352
def get_patch_labels(self, graph, nodes):
350353
"""
351354
Gets the segment labels for a given set of nodes within a specified
@@ -395,8 +398,68 @@ def to_local_voxels(self, graph, i, offset):
395398
offset = np.array(offset)
396399
return tuple(voxel - offset)
397400

398-
def fix_label_misalignments(self, graph):
399-
pass
401+
@staticmethod
402+
def fix_label_misalignments(graph):
403+
"""
404+
Adjusts misalignments between the labeled graph and segmentation.
405+
406+
Parameters
407+
----------
408+
graph : LabeledGraph
409+
Graph to be searched.
410+
"""
411+
visited_edges = set()
412+
for i, j in deque(nx.dfs_edges(graph)):
413+
# Check whether to visit edge
414+
if frozenset({i, j}) in visited_edges:
415+
continue
416+
417+
# Visit edge
418+
if int(graph.node_labels[j]) == 0:
419+
GraphLoader.check_misalignment(graph, visited_edges, i, j)
420+
visited_edges.add(frozenset({i, j}))
421+
422+
@staticmethod
423+
def check_misalignment(graph, visited_edges, nb, root):
424+
"""
425+
Determines whether zero-valued label corresponds to a misalignment
426+
between the graph and segmentation mask.
427+
428+
Parameters
429+
----------
430+
graph : networkx.Graph
431+
Graph that represents a ground truth neuron.
432+
visited_edges : List[tuple]
433+
List of edges in "graph" that have been visited.
434+
nb : int
435+
Neighbor of "root".
436+
root : int
437+
Node where possible split starts (i.e. zero-valued label).
438+
"""
439+
# Search graph
440+
label_collisions = set()
441+
queue = deque([root])
442+
visited = set()
443+
while len(queue) > 0:
444+
# Visit node
445+
j = queue.popleft()
446+
label_j = int(graph.node_labels[j])
447+
if label_j != 0:
448+
label_collisions.add(label_j)
449+
visited.add(j)
450+
451+
# Update queue
452+
if label_j == 0:
453+
for k in graph.neighbors(j):
454+
if k not in visited:
455+
if frozenset({j, k}) not in visited_edges or k == nb:
456+
queue.append(k)
457+
visited_edges.add(frozenset({j, k}))
458+
459+
# Upd zero nodes
460+
if len(label_collisions) == 1:
461+
label = label_collisions.pop()
462+
graph.update_node_labels(visited, label)
400463

401464

402465
class LabelHandler:

src/segmentation_skeleton_metrics/data_handling/skeleton_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def run_length_from(self, root):
332332
visited.add(k)
333333
return run_length
334334

335-
def upd_node_labels(self, nodes, label):
335+
def update_node_labels(self, nodes, label):
336336
"""
337337
Updates the label of the given nodes with a specified label.
338338

src/segmentation_skeleton_metrics/evaluate.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def evaluate(
4747
verbose=True
4848
):
4949
"""
50-
...
50+
Loads data, calls an evaluator object that compute skeleton-based
51+
segmentation, and saves the results.
5152
5253
Parameters
5354
----------
@@ -164,19 +165,19 @@ def report_summary(self, results):
164165
# Averaged results
165166
filename = f"{self.results_filename}-overview.txt"
166167
path = os.path.join(self.output_dir, filename)
167-
util.update_txt(path, "Average Results...")
168+
util.update_txt(path, "Average Results...", self.verbose)
168169
for column in results.columns:
169170
if column != "SWC Run Length" and column != "SWC Name":
170171
avg = util.compute_weighted_avg(results, column)
171-
util.update_txt(path, f" {column}: {avg:.4f}")
172+
util.update_txt(path, f" {column}: {avg:.4f}", self.verbose)
172173

173174
# Total results
174175
n_splits = results["# Splits"].sum()
175-
util.update_txt(path, "\nTotal Results...")
176-
util.update_txt(path, f" # Splits: {n_splits}")
176+
util.update_txt(path, "\nTotal Results...", self.verbose)
177+
util.update_txt(path, f" # Splits: {n_splits}", self.verbose)
177178
if "# Merges" in results.columns:
178179
n_merges = results["# Merges"].sum()
179-
util.update_txt(path, f" # Merges: {n_merges}")
180+
util.update_txt(path, f" # Merges: {n_merges}", self.verbose)
180181

181182
# --- Writers ---
182183
def save_fragments(self):

src/segmentation_skeleton_metrics/utils/util.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def read_txt(path):
168168
return f.read().splitlines()
169169

170170

171-
def update_txt(path, text):
171+
def update_txt(path, text, verbose=True):
172172
"""
173173
Appends the given text to a specified text file and prints the text.
174174
@@ -178,8 +178,14 @@ def update_txt(path, text):
178178
Path to txt file where the text will be appended.
179179
text : str
180180
Text to be written to the file.
181+
verbose : bool, optional
182+
Indication of whether to printout text. Default is True.
181183
"""
182-
print(text)
184+
# Printout text (if applicable)
185+
if verbose:
186+
print(text)
187+
188+
# Update txt file
183189
with open(path, "a") as file:
184190
file.write(text + "\n")
185191

0 commit comments

Comments
 (0)