Skip to content

Commit 515bcc6

Browse files
committed
Initial mapping for IHCs and SGNs
1 parent c091e54 commit 515bcc6

File tree

3 files changed

+183
-3
lines changed

3 files changed

+183
-3
lines changed
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import networkx as nx
2+
from networkx.algorithms.approximation import steiner_tree
3+
4+
from flamingo_tools.segmentation.postprocessing import graph_connected_components
5+
6+
7+
def find_most_distant_nodes(G, weight='weight'):
8+
all_lengths = dict(nx.all_pairs_dijkstra_path_length(G, weight=weight))
9+
max_dist = 0
10+
farthest_pair = (None, None)
11+
12+
for u, dist_dict in all_lengths.items():
13+
for v, d in dist_dict.items():
14+
if d > max_dist:
15+
max_dist = d
16+
farthest_pair = (u, v)
17+
18+
u, v = farthest_pair
19+
return u, v
20+
21+
22+
def steiner_path_between_distant_nodes(G, weight='weight'):
23+
# Step 1: Find the most distant pair of nodes
24+
u, v = find_most_distant_nodes(G, weight=weight)
25+
terminals = set(G.nodes()) # All nodes are required
26+
27+
# Step 2: Approximate Steiner Tree over all nodes
28+
T = steiner_tree(G, terminals, weight=weight)
29+
30+
# Step 3: Find the shortest path between u and v in the Steiner Tree
31+
path = nx.shortest_path(T, source=u, target=v, weight=weight)
32+
total_weight = nx.path_weight(T, path, weight=weight)
33+
34+
return {
35+
"start": u,
36+
"end": v,
37+
"path": path,
38+
"total_weight": total_weight,
39+
"steiner_tree": T
40+
}
41+
42+
43+
def nearest_node_on_path(G, main_path, query_node, weight='weight'):
44+
"""Find the nearest node in the connected component graph,
45+
which lies on the path between the two most distant nodes.
46+
"""
47+
if query_node in main_path:
48+
return {
49+
"nearest_node": query_node,
50+
"distance": 0
51+
}
52+
53+
min_dist = float('inf')
54+
nearest_node = None
55+
56+
for path_node in main_path:
57+
try:
58+
dist = nx.dijkstra_path_length(G, source=query_node, target=path_node, weight=weight)
59+
if dist < min_dist:
60+
min_dist = dist
61+
nearest_node = path_node
62+
except nx.NetworkXNoPath:
63+
continue # No path to this node
64+
65+
return {
66+
"nearest_node": nearest_node,
67+
"distance": min_dist if nearest_node is not None else None
68+
}
69+
70+
71+
def tonotopic_mapping(table, component_label=[1], min_edge_distance=30, min_component_length=50,
72+
cell_type="ihc"):
73+
"""Tonotopic mapping of IHCs by supplying a table with component labels.
74+
The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.
75+
"""
76+
# subset of centroids for given component label(s)
77+
new_subset = table[table["component_labels"].isin(component_label)]
78+
comp_label_ids = list(new_subset["label_id"])
79+
centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
80+
labels_subset = [int(i) for i in list(new_subset["label_id"])]
81+
82+
# create graph with connected components
83+
coords = {}
84+
for index, element in zip(labels_subset, centroids_subset):
85+
coords[index] = element
86+
87+
components, graph = graph_connected_components(coords, min_edge_distance, min_component_length)
88+
89+
# approximate Steiner tree and find shortest path between the two most distant nodes
90+
91+
u, v = find_most_distant_nodes(graph)
92+
if cell_type == "ihc":
93+
terminals = set(graph.nodes()) # All nodes are required
94+
# Approximate Steiner Tree over all nodes
95+
T = steiner_tree(graph, terminals)
96+
path = nx.shortest_path(T, source=u, target=v)
97+
total_distance = nx.path_weight(T, path)
98+
99+
else:
100+
path = nx.shortest_path(graph, source=u, target=v)
101+
total_distance = nx.path_weight(graph, path)
102+
103+
# assign relative distance to nodes on path
104+
path_list = []
105+
path_list.append({"label_id": path[0], "value": 0})
106+
accumulated = 0
107+
for num, p in enumerate(path[1:-1]):
108+
distance = graph.get_edge_data(path[num], p)["weight"]
109+
accumulated += distance
110+
rel_dist = accumulated / total_distance
111+
path_list.append({"label_id": p, "value": rel_dist})
112+
path_list.append({"label_id": path[-1], "value": 1})
113+
114+
# add missing nodes from component
115+
for c in comp_label_ids:
116+
if c not in path:
117+
nearest_node = nearest_node_on_path(graph, path, c)["nearest_node"]
118+
for label in path_list:
119+
if label["label_id"] == nearest_node:
120+
nearest_node_value = label["value"]
121+
continue
122+
path_list.append({"label_id": int(c), "value": nearest_node_value})
123+
124+
tonotopic = [0 for _ in range(len(table))]
125+
# be aware of 'label_id' of dataframe starting at 1
126+
for d in path_list:
127+
tonotopic[d["label_id"] - 1] = d["value"] * len(total_distance)
128+
129+
table.loc[:, "tonotopic_label"] = tonotopic
130+
131+
return table

flamingo_tools/segmentation/postprocessing.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,7 @@ def graph_connected_components(coords: dict, min_edge_distance: float, min_compo
329329
330330
Returns:
331331
List of dictionary keys of connected components.
332+
Graph of connected components.
332333
"""
333334
graph = nx.Graph()
334335
for num, pos in coords.items():
@@ -351,7 +352,7 @@ def graph_connected_components(coords: dict, min_edge_distance: float, min_compo
351352
graph.remove_node(c)
352353

353354
components = [list(s) for s in nx.connected_components(graph)]
354-
return components
355+
return components, graph
355356

356357

357358
def components_sgn(
@@ -411,7 +412,7 @@ def components_sgn(
411412
for index, element in zip(labels_subset, centroids_subset):
412413
coords[index] = element
413414

414-
components = graph_connected_components(coords, min_edge_distance, min_component_length)
415+
components, _ = graph_connected_components(coords, min_edge_distance, min_component_length)
415416

416417
length_components = [len(c) for c in components]
417418
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
@@ -542,7 +543,7 @@ def components_ihc(
542543
for index, element in zip(labels, centroids):
543544
coords[index] = element
544545

545-
components = graph_connected_components(coords, min_edge_distance, min_component_length)
546+
components, _ = graph_connected_components(coords, min_edge_distance, min_component_length)
546547
return components
547548

548549

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import argparse
2+
3+
import pandas as pd
4+
5+
import flamingo_tools.s3_utils as s3_utils
6+
from flamingo_tools.segmentation.cochlea_mapping import tonotopic_mapping
7+
8+
9+
def main():
10+
11+
parser = argparse.ArgumentParser(
12+
description="Script for the tonotopic mapping of IHCs and SGNs. "
13+
"Either locally or on an S3 bucket.")
14+
15+
parser.add_argument("-i", "--input", required=True, help="Input table with IHC segmentation.")
16+
parser.add_argument("-o", "--output", required=True, help="Output path for post-processed table.")
17+
18+
parser.add_argument("-t", "--type", type=str, default="ihc", help="Cell type of segmentation.")
19+
parser.add_argument("--edge_distance", type=float, default=30, help="Maximal edge distance between nodes.")
20+
parser.add_argument("--component_length", type=int, default=50, help="Minimal number of nodes in component.")
21+
22+
parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
23+
parser.add_argument("--s3_credentials", type=str, default=None,
24+
help="Input file containing S3 credentials. "
25+
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
26+
parser.add_argument("--s3_bucket_name", type=str, default=None,
27+
help="S3 bucket name. Optional if BUCKET_NAME was exported.")
28+
parser.add_argument("--s3_service_endpoint", type=str, default=None,
29+
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")
30+
31+
args = parser.parse_args()
32+
33+
if args.s3:
34+
tsv_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name,
35+
service_endpoint=args.s3_service_endpoint,
36+
credential_file=args.s3_credentials)
37+
with fs.open(tsv_path, 'r') as f:
38+
tsv_table = pd.read_csv(f, sep="\t")
39+
else:
40+
with open(args.input, 'r') as f:
41+
tsv_table = pd.read_csv(f, sep="\t")
42+
43+
table = tonotopic_mapping(tsv_table)
44+
table.to_csv(args.output, sep="\t", index=False)
45+
46+
47+
if __name__ == "__main__":
48+
main()

0 commit comments

Comments
 (0)