Skip to content

Commit cb71cac

Browse files
Merge pull request #42 from computational-cell-analytics/tonotopic_mapping
Initial mapping for IHCs and SGNs
2 parents 9984b6c + 856fd63 commit cb71cac

File tree

14 files changed

+834
-41
lines changed

14 files changed

+834
-41
lines changed
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
import math
2+
from typing import List, Tuple
3+
4+
import networkx as nx
5+
import numpy as np
6+
import pandas as pd
7+
from networkx.algorithms.approximation import steiner_tree
8+
from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
9+
from scipy.interpolate import interp1d
10+
11+
from flamingo_tools.segmentation.postprocessing import downscaled_centroids
12+
13+
14+
def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[object, object]:
    """Find the pair of nodes with the largest shortest-path distance in a graph.

    The distance between two nodes is the length of the weighted shortest path
    (Dijkstra) connecting them; unreachable pairs are not considered.

    Args:
        G: Input graph.
        weight: Name of the edge attribute used as edge length.

    Returns:
        Node 1.
        Node 2.
        Both entries are None if the graph has no nodes. For a graph with a
        single node (or only zero-length paths) the same node is returned twice.
    """
    # Start below every possible distance so that the zero-length "path" from a
    # node to itself is accepted; otherwise a single-node graph yields (None, None).
    max_dist = -math.inf
    farthest_pair = (None, None)

    # Consume the generator lazily instead of materialising all pairs in a dict.
    for u, dist_dict in nx.all_pairs_dijkstra_path_length(G, weight=weight):
        for v, d in dist_dict.items():
            # Strict comparison keeps the first pair that reaches the maximum.
            if d > max_dist:
                max_dist = d
                farthest_pair = (u, v)

    return farthest_pair
36+
37+
38+
def central_path_edt_graph(mask: np.ndarray, start: Tuple[int], end: Tuple[int]) -> np.ndarray:
    """Find the central path within a binary mask between a start and an end coordinate.

    Every foreground voxel becomes a graph node that is connected to its six
    face neighbours. An edge costs 1 / (1e-3 + min(dt_a, dt_b)), where dt is the
    Euclidean distance transform of the mask, so edges far away from the
    background are cheap and the shortest path runs through the centre.

    Args:
        mask: Binary mask of volume (3D).
        start: Starting coordinate (z, y, x); must be a foreground voxel.
        end: End coordinate (z, y, x); must be a foreground voxel.

    Returns:
        Coordinates of central path, integer array of shape (N, 3).
    """
    dt = distance_transform_edt(mask)
    graph = nx.Graph()
    shape = mask.shape

    def idx_to_node(z, y, x):
        # Flat (C-order) voxel index, used as node id.
        return int(z) * shape[1] * shape[2] + int(y) * shape[2] + int(x)

    neighbor_offsets = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]

    # argwhere walks the foreground in the same z, y, x order as nested loops
    # over the full volume would, but skips background voxels entirely.
    for z, y, x in np.argwhere(mask).tolist():
        u = idx_to_node(z, y, x)
        for dz, dy, dx in neighbor_offsets:
            nz, ny, nx_ = z + dz, y + dy, x + dx
            # All three axes must be checked: an index of -1 would silently wrap
            # around to the opposite face, and an index equal to the axis size
            # would raise an IndexError.
            inside = 0 <= nz < shape[0] and 0 <= ny < shape[1] and 0 <= nx_ < shape[2]
            if inside and mask[nz, ny, nx_]:
                v = idx_to_node(nz, ny, nx_)
                w = 1.0 / (1e-3 + min(dt[z, y, x], dt[nz, ny, nx_]))
                graph.add_edge(u, v, weight=w)

    s = idx_to_node(*start)
    t = idx_to_node(*end)
    path = nx.shortest_path(graph, source=s, target=t, weight="weight")

    # Convert the flat node ids back to (z, y, x) coordinates.
    coords = [(p // (shape[1] * shape[2]),
               (p // shape[2]) % shape[1],
               p % shape[2]) for p in path]
    return np.array(coords)
73+
74+
75+
def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
    """Smooth a path with a simple moving average filter.

    Each coordinate column is padded by repeating its first and last value
    and then averaged over a sliding window, so the output has as many points
    as the input.

    Args:
        path: ndarray of shape (N, D), e.g. (N, 3) for a 3D path.
        window: half-window size; actual window = 2*window + 1.

    Returns:
        smoothed path: ndarray of same shape. Floating point input keeps its
        dtype; any other dtype (e.g. integer voxel coordinates) is returned as
        float64 so that the averaged values are not truncated.
    """
    path = np.asarray(path)
    kernel_size = 2 * window + 1
    kernel = np.ones(kernel_size) / kernel_size

    # An integer output buffer would silently drop the fractional part of
    # every averaged coordinate.
    out_dtype = path.dtype if np.issubdtype(path.dtype, np.floating) else np.float64
    smooth_path = np.zeros(path.shape, dtype=out_dtype)

    # Nothing to smooth (and edge padding is undefined) for an empty path.
    if path.shape[0] == 0:
        return smooth_path

    for d in range(path.shape[1]):
        padded = np.pad(path[:, d], window, mode='edge')
        smooth_path[:, d] = np.convolve(padded, kernel, mode='valid')

    return smooth_path
95+
96+
97+
def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
    """Measure the run lengths of the SGN segmentation by finding a central path through Rosenthal's canal.

    1) Create a binary mask based on down-scaled centroids.
    2) Dilate the mask and close holes to ensure a filled structure.
    3) Determine the endpoints of the structure using the principal axis.
    4) Identify a central path based on the 3D Euclidean distance transform.
    5) The path is up-scaled and smoothed using a moving average filter.
    6) The points of the path are fed into a dictionary along with the fractional length.

    Args:
        centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3).
        scale_factor: Downscaling factor for finding the central path.

    Returns:
        Total distance of the path.
        Path as an nd.array of positions.
        A dictionary containing the position and the length fraction of each point in the path.
        The keys are 0 .. len(path) - 2 for all points but the last one, and len(path) for the last point.

    Raises:
        ValueError: If the mask derived from the centroids contains no voxels.
    """
    mask = downscaled_centroids(centroids, scale_factor=scale_factor, downsample_mode="capped")
    structure = np.ones((3, 3, 3))
    mask = binary_dilation(mask, structure, iterations=1)
    mask = binary_closing(mask, structure, iterations=1)
    pts = np.argwhere(mask == 1)
    if len(pts) == 0:
        raise ValueError("Cannot measure the SGN run length: the mask of down-scaled centroids is empty.")

    # find two endpoints: min/max along principal axis
    c_mean = pts.mean(axis=0)
    cov = np.cov((pts - c_mean).T)
    evals, evecs = np.linalg.eigh(cov)
    axis = evecs[:, np.argmax(evals)]
    proj = (pts - c_mean) @ axis
    start_voxel = tuple(pts[proj.argmin()])
    end_voxel = tuple(pts[proj.argmax()])

    # get central path and total distance
    path = central_path_edt_graph(mask, start_voxel, end_voxel)
    # The voxel path is an integer array. Convert to float before smoothing,
    # otherwise the averaged coordinates are truncated back to the integer grid.
    path = path.astype(float) * scale_factor
    path = moving_average_3d(path, window=5)
    total_distance = sum(math.dist(a, b) for a, b in zip(path[:-1], path[1:]))

    # assign relative distance to points on path
    path_dict = {0: {"pos": path[0], "length_fraction": 0}}
    accumulated = 0
    for num in range(1, len(path) - 1):
        accumulated += math.dist(path[num - 1], path[num])
        path_dict[num] = {"pos": path[num], "length_fraction": accumulated / total_distance}
    # NOTE: the last point is stored under len(path) (not len(path) - 1); kept for compatibility.
    path_dict[len(path)] = {"pos": path[-1], "length_fraction": 1}

    return total_distance, path, path_dict
147+
148+
149+
def measure_run_length_ihcs(centroids):
    """Measure the run lengths of the IHC segmentation
    by finding the shortest path between the most distant nodes in a Steiner Tree.

    All centroids are connected pairwise with their Euclidean distance as edge
    weight. The approximate Steiner tree over all centroids is then computed and
    the path between its two most distant nodes is used as the run of the IHCs.

    Args:
        centroids: Centroids of IHC segmentation, sequence of (x, y, z) positions.

    Returns:
        Total distance of the path.
        Path as an nd.array of positions.
        A dictionary containing the position and the length fraction of each point in the path.
        The keys are 0 .. len(path) - 2 for all points but the last one, and len(path) for the last point.

    Raises:
        ValueError: If no centroids are given.
    """
    positions = list(centroids)
    if len(positions) == 0:
        raise ValueError("Cannot measure the IHC run length without centroids.")

    graph = nx.Graph()
    for num, pos in enumerate(positions):
        graph.add_node(num, pos=pos)
    # The tree can only be built on a connected, weighted graph, so connect
    # every pair of centroids with their Euclidean distance.
    for i, pos_i in enumerate(positions):
        for j in range(i + 1, len(positions)):
            graph.add_edge(i, j, weight=math.dist(pos_i, positions[j]))

    if len(positions) == 1:
        # A single cell has no extent; there is no tree to build.
        node_path = [0]
        total_distance = 0.0
    else:
        # approximate Steiner tree and find shortest path between the two most distant nodes
        terminals = set(graph.nodes())  # All nodes are required
        T = steiner_tree(graph, terminals, weight="weight")
        u, v = find_most_distant_nodes(T)
        node_path = nx.shortest_path(T, source=u, target=v)
        total_distance = nx.path_weight(T, node_path, weight="weight")

    # assign relative distance to points on path
    path_dict = {0: {"pos": graph.nodes[node_path[0]]["pos"], "length_fraction": 0}}
    accumulated = 0
    for num in range(1, len(node_path) - 1):
        prev_pos = graph.nodes[node_path[num - 1]]["pos"]
        cur_pos = graph.nodes[node_path[num]]["pos"]
        accumulated += math.dist(prev_pos, cur_pos)
        path_dict[num] = {"pos": cur_pos, "length_fraction": accumulated / total_distance}
    # NOTE: the last point is stored under len(path) (not len(path) - 1); kept for compatibility.
    path_dict[len(node_path)] = {"pos": graph.nodes[node_path[-1]]["pos"], "length_fraction": 1}

    # Return positions rather than node ids, as documented and as done for the SGNs.
    path = np.array([graph.nodes[node]["pos"] for node in node_path])

    return total_distance, path, path_dict
184+
185+
186+
def map_frequency(
    table: pd.DataFrame,
    min_frequency: float = 1,
    max_frequency: float = 80,
    k: float = 0.88,
) -> pd.DataFrame:
    """Map the frequency range of SGNs in the cochlea
    using Greenwood function f(x) = A * (10 **(ax) - K).
    Values for humans: a=2.1, k=0.88, A = 165.4 [kHz].
    For mice: fit values between minimal (1kHz) and maximal (80kHz) values

    A and 10**a are fitted such that f(0) = min_frequency and f(1) = max_frequency:
    A = min_frequency / (1 - k) and 10**a = (max_frequency + A * k) / A.

    Args:
        table: Dataframe containing the segmentation.
            It must provide the columns 'offset' and 'length_fraction'. It is modified in-place.
        min_frequency: Frequency [kHz] assigned to a length fraction of 0.
        max_frequency: Frequency [kHz] assigned to a length fraction of 1.
        k: Constant K of the Greenwood function. Must not be 1.

    Returns:
        Dataframe containing frequency in an additional column 'frequency[kHz]'.
        Rows with a negative 'offset' get a frequency of 0.
    """
    var_A = min_frequency / (1 - k)
    var_exp = (max_frequency + var_A * k) / var_A

    on_path = table['offset'] >= 0
    table.loc[on_path, 'frequency[kHz]'] = var_A * (var_exp ** table["length_fraction"] - k)
    # A negative offset marks rows without a mapping; they get no frequency.
    table.loc[table['offset'] < 0, 'frequency[kHz]'] = 0

    return table
207+
208+
209+
def equidistant_centers(
    table: pd.DataFrame,
    component_label: List[int] = [1],
    cell_type: str = "sgn",
    n_blocks: int = 10,
    offset_blocks: bool = True,
) -> np.ndarray:
    """Find equidistant centers within the central path of the Rosenthal's canal.

    Args:
        table: Dataframe containing centroids of SGN segmentation.
            The columns 'component_labels', 'anchor_x', 'anchor_y' and 'anchor_z' are used.
        component_label: List of components for centroid subset.
        cell_type: Cell type of the segmentation. The IHC method is used for "ihc", the SGN method otherwise.
        n_blocks: Number of equidistant centers for block creation.
        offset_blocks: Centers are shifted by half a length if True. Avoid centers at the start/end of the path.

    Returns:
        Equidistant centers as float values
    """
    # subset of centroids for given component label(s)
    new_subset = table[table["component_labels"].isin(component_label)]
    centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))

    if cell_type == "ihc":
        _, path, _ = measure_run_length_ihcs(centroids)
    else:
        _, path, _ = measure_run_length_sgns(centroids)

    path = np.asarray(path)
    if path.ndim == 1 and np.issubdtype(path.dtype, np.integer):
        # The path was given as indices into the centroids instead of positions;
        # the interpolation below needs coordinates.
        path = np.asarray(centroids, dtype=float)[path]

    seg_lens = np.linalg.norm(np.diff(path, axis=0), axis=1)
    cum_len = np.insert(np.cumsum(seg_lens), 0, 0)
    # Use the end of the cumulative length (rather than a separately summed total)
    # so that the last target can never fall outside the interpolation range
    # because of floating point rounding.
    path_length = cum_len[-1]

    if offset_blocks:
        # Split the path into 2 * n_blocks pieces and keep every second border,
        # i.e. the mid-points of n_blocks equally long blocks.
        target_s = np.linspace(0, path_length, n_blocks * 2 + 1)[1::2]
    else:
        target_s = np.linspace(0, path_length, n_blocks)

    f = interp1d(cum_len, path, axis=0)
    centers = f(target_s)
    return centers
249+
250+
251+
def tonotopic_mapping(
    table: pd.DataFrame,
    component_label: List[int] = [1],
    cell_type: str = "ihc"
) -> pd.DataFrame:
    """Tonotopic mapping of cells by supplying a table with component labels.
    The mapping assigns a tonotopic label to each cell according to the position along the length of the cochlea.

    Each cell of the selected component(s) is assigned to the nearest point of
    the run-length path. The distance to this point is stored as 'offset' and
    the length fraction of the point as 'length_fraction'. Cells outside of the
    selected component(s) get an offset of -1 and a length fraction of 0.

    Args:
        table: Dataframe of segmentation table. The columns 'component_labels', 'label_id',
            'anchor_x', 'anchor_y' and 'anchor_z' are used. It is modified in-place.
        component_label: List of component labels to evaluate.
        cell_type: Cell type of segmentation. The IHC method is used for "ihc", the SGN method otherwise.

    Returns:
        Table with tonotopic label for cells, i.e. the additional columns
        'offset', 'length_fraction', 'length[µm]' and the output of map_frequency.
    """
    # subset of centroids for given component label(s)
    new_subset = table[table["component_labels"].isin(component_label)]
    centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
    label_ids = [int(i) for i in new_subset["label_id"]]

    if cell_type == "ihc":
        total_distance, _, path_dict = measure_run_length_ihcs(centroids)
    else:
        total_distance, _, path_dict = measure_run_length_sgns(centroids)

    # positions and length fractions of the path points, in the order of the dictionary
    path_entries = list(path_dict.values())
    path_positions = np.array([entry["pos"] for entry in path_entries], dtype=float)
    path_fractions = [entry["length_fraction"] for entry in path_entries]

    # find the nearest path point for each cell of the component
    offset_by_label = {}
    fraction_by_label = {}
    for label_id, centroid in zip(label_ids, centroids):
        distances = np.linalg.norm(path_positions - np.asarray(centroid, dtype=float), axis=1)
        # argmin returns the first minimum, i.e. the earliest path point wins a tie.
        nearest = int(np.argmin(distances))
        offset_by_label[label_id] = float(distances[nearest])
        fraction_by_label[label_id] = path_fractions[nearest]

    # Look the values up by label id for every row instead of writing to the
    # list position 'label_id - 1', which is only valid if the label ids are
    # exactly 1..len(table) in row order.
    table_labels = list(table["label_id"])
    table.loc[:, "offset"] = [offset_by_label.get(label, -1) for label in table_labels]
    table.loc[:, "length_fraction"] = [fraction_by_label.get(label, 0) for label in table_labels]
    table.loc[:, "length[µm]"] = table["length_fraction"] * total_distance

    table = map_frequency(table)

    return table

0 commit comments

Comments
 (0)