Skip to content

Commit 740caaa

Browse files
committed
Fixed function calls
1 parent 515bcc6 commit 740caaa

File tree

1 file changed

+24
-45
lines changed

1 file changed

+24
-45
lines changed

flamingo_tools/segmentation/cochlea_mapping.py

Lines changed: 24 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import math
2+
13
import networkx as nx
24
from networkx.algorithms.approximation import steiner_tree
35

@@ -40,36 +42,8 @@ def steiner_path_between_distant_nodes(G, weight='weight'):
4042
}
4143

4244

43-
def nearest_node_on_path(G, main_path, query_node, weight='weight'):
44-
"""Find the nearest node in the connected component graph,
45-
which lies on the path between the two most distant nodes.
46-
"""
47-
if query_node in main_path:
48-
return {
49-
"nearest_node": query_node,
50-
"distance": 0
51-
}
52-
53-
min_dist = float('inf')
54-
nearest_node = None
55-
56-
for path_node in main_path:
57-
try:
58-
dist = nx.dijkstra_path_length(G, source=query_node, target=path_node, weight=weight)
59-
if dist < min_dist:
60-
min_dist = dist
61-
nearest_node = path_node
62-
except nx.NetworkXNoPath:
63-
continue # No path to this node
64-
65-
return {
66-
"nearest_node": nearest_node,
67-
"distance": min_dist if nearest_node is not None else None
68-
}
69-
70-
7145
def tonotopic_mapping(table, component_label=[1], min_edge_distance=30, min_component_length=50,
72-
cell_type="ihc"):
46+
cell_type="ihc", weight='weight'):
7347
"""Tonotopic mapping of IHCs by supplying a table with component labels.
7448
The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.
7549
"""
@@ -92,39 +66,44 @@ def tonotopic_mapping(table, component_label=[1], min_edge_distance=30, min_comp
9266
if cell_type == "ihc":
9367
terminals = set(graph.nodes()) # All nodes are required
9468
# Approximate Steiner Tree over all nodes
95-
T = steiner_tree(graph, terminals)
96-
path = nx.shortest_path(T, source=u, target=v)
97-
total_distance = nx.path_weight(T, path)
69+
T = steiner_tree(graph, terminals, weight=weight)
70+
path = nx.shortest_path(T, source=u, target=v, weight=weight)
71+
total_distance = nx.path_weight(T, path, weight=weight)
9872

9973
else:
100-
path = nx.shortest_path(graph, source=u, target=v)
101-
total_distance = nx.path_weight(graph, path)
74+
path = nx.shortest_path(graph, source=u, target=v, weight=weight)
75+
total_distance = nx.path_weight(graph, path, weight=weight)
10276

10377
# assign relative distance to nodes on path
104-
path_list = []
105-
path_list.append({"label_id": path[0], "value": 0})
78+
path_list = {}
79+
path_list[path[0]] = {"label_id": path[0], "tonotopic": 0}
10680
accumulated = 0
10781
for num, p in enumerate(path[1:-1]):
10882
distance = graph.get_edge_data(path[num], p)["weight"]
10983
accumulated += distance
11084
rel_dist = accumulated / total_distance
111-
path_list.append({"label_id": p, "value": rel_dist})
112-
path_list.append({"label_id": path[-1], "value": 1})
85+
path_list[p] = {"label_id": p, "tonotopic": rel_dist}
86+
path_list[path[-1]] = {"label_id": path[-1], "tonotopic": 1}
11387

11488
# add missing nodes from component
89+
pos = nx.get_node_attributes(graph, 'pos')
11590
for c in comp_label_ids:
11691
if c not in path:
117-
nearest_node = nearest_node_on_path(graph, path, c)["nearest_node"]
118-
for label in path_list:
119-
if label["label_id"] == nearest_node:
120-
nearest_node_value = label["value"]
121-
continue
122-
path_list.append({"label_id": int(c), "value": nearest_node_value})
92+
min_dist = float('inf')
93+
nearest_node = None
94+
95+
for p in path:
96+
dist = math.dist(pos[c], pos[p])
97+
if dist < min_dist:
98+
min_dist = dist
99+
nearest_node = p
100+
101+
path_list[c] = {"label_id": c, "tonotopic": path_list[nearest_node]["tonotopic"]}
123102

124103
tonotopic = [0 for _ in range(len(table))]
125104
# be aware of 'label_id' of dataframe starting at 1
126105
for d in path_list:
127-
tonotopic[d["label_id"] - 1] = d["value"] * len(total_distance)
106+
tonotopic[d["label_id"] - 1] = d["value"] * total_distance
128107

129108
table.loc[:, "tonotopic_label"] = tonotopic
130109

0 commit comments

Comments
 (0)