Skip to content

Commit a1e6279

Browse files
anna-grimanna-grim
andauthored
feat: compute erl (#35)
Co-authored-by: anna-grim <[email protected]>
1 parent 8abb19e commit a1e6279

File tree

3 files changed

+38
-22
lines changed

3 files changed

+38
-22
lines changed

src/segmentation_skeleton_metrics/graph_utils.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
77
88
"""
9+
from scipy.spatial.distance import euclidean as dist
910
from random import sample
1011

12+
import math
1113
import networkx as nx
1214

1315

@@ -162,44 +164,50 @@ def count_splits(graph):
162164

163165

164166
def compute_run_lengths(graph):
165-
pass
167+
run_lengths = []
168+
for nodes in nx.connected_components(graph):
169+
subgraph = graph.subgraph(nodes)
170+
run_lengths.append(compute_path_length(subgraph))
171+
return run_lengths
166172

167173

168-
# -- miscellaneous --
169-
def sample_leaf(graph):
174+
def compute_path_length(graph):
170175
"""
171-
Samples leaf node from "graph".
176+
Computes path length of graph.
172177
173178
Parameters
174179
----------
175180
graph : networkx.Graph
176-
Graph to be sampled from.
181+
Graph to be parsed.
177182
178183
Returns
179184
-------
180-
int
181-
Leaf node of "graph"
185+
path_length : float
186+
Path length of graph.
182187
183188
"""
184-
leafs = [i for i in graph.nodes if graph.degree[i] == 1]
185-
return sample(leafs, 1)[0]
186-
189+
path_length = 0
190+
for i, j in nx.dfs_edges(graph):
191+
xyz_1 = graph.nodes[i]["xyz"]
192+
xyz_2 = graph.nodes[j]["xyz"]
193+
path_length += dist(xyz_1, xyz_2)
194+
return path_length
187195

188-
def empty_copy(graph):
196+
# -- miscellaneous --
197+
def sample_leaf(graph):
189198
"""
190-
Creates a copy of "graph" that does not contain the node level attributes.
199+
Samples leaf node from "graph".
191200
192201
Parameters
193202
----------
194203
graph : networkx.Graph
195-
Graph to be copied.
204+
Graph to be sampled from.
196205
197206
Returns
198207
-------
199-
graph : netowrkx.Graph
200-
Copy of "graph" that does not contain its node level attributes.
208+
int
209+
Leaf node of "graph"
210+
201211
"""
202-
graph_copy = nx.Graph(graph, pred_ids=set())
203-
for i in graph_copy.nodes():
204-
graph_copy.nodes[i].clear()
205-
return graph_copy
212+
leafs = [i for i in graph.nodes if graph.degree[i] == 1]
213+
return sample(leafs, 1)[0]

src/segmentation_skeleton_metrics/skeleton_metric.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,12 @@ def label_graph(self, target_graph):
126126
Updated graph with node-level attributes called "pred_id".
127127
128128
"""
129-
pred_graph = gutils.empty_copy(target_graph)
129+
pred_graph = nx.Graph(target_graph, pred_ids=set())
130130
with ThreadPoolExecutor() as executor:
131131
# Assign threads
132132
threads = []
133133
for i in pred_graph.nodes:
134-
img_coord = gutils.get_coord(target_graph, i)
134+
img_coord = gutils.get_coord(pred_graph, i)
135135
threads.append(executor.submit(self.get_label, img_coord, i))
136136
# Store results
137137
for thread in as_completed(threads):
@@ -431,7 +431,9 @@ def compute_edge_accuracy(self):
431431
def compute_erl(self):
432432
self.erl = dict()
433433
for swc_id in self.target_graphs.keys():
434-
None
434+
graph = self.pred_graphs[swc_id]
435+
path_lengths = gutils.compute_run_lengths(graph)
436+
self.erl[swc_id] = np.mean(path_lengths)
435437

436438

437439
# -- utils --

src/segmentation_skeleton_metrics/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,9 @@ def time_writer(t, unit="seconds"):
202202
unit = upd_unit[unit]
203203
t, unit = time_writer(t, unit=unit)
204204
return t, unit
205+
206+
207+
def progress_bar(current, total, bar_length=50):
208+
progress = int(current / total * bar_length)
209+
bar = f"[{'=' * progress}{' ' * (bar_length - progress)}] {current}/{total}"
210+
print(f"\r{bar}", end="", flush=True)

0 commit comments

Comments
 (0)