Skip to content

Commit a648de4

Browse files
committed
Central path of Rosenthal's canal
1 parent c41743c commit a648de4

File tree

5 files changed

+181
-234
lines changed

5 files changed

+181
-234
lines changed

flamingo_tools/segmentation/cochlea_mapping.py

Lines changed: 159 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
import math
22
import warnings
3-
from typing import List, Optional, Tuple
3+
from typing import List, Tuple
44

55
import networkx as nx
66
import numpy as np
77
import pandas as pd
88
from networkx.algorithms.approximation import steiner_tree
9+
from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
910

11+
import flamingo_tools.segmentation.postprocessing as postprocessing
1012
from flamingo_tools.segmentation.postprocessing import graph_connected_components
11-
from flamingo_tools.segmentation.distance_weighted_steiner import distance_weighted_steiner_path
1213

1314

1415
def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
16+
"""Find the most distant nodes in a graph.
17+
18+
Args:
19+
G: Input graph
20+
21+
Returns:
22+
Node 1
23+
Node 2
24+
"""
1525
all_lengths = dict(nx.all_pairs_dijkstra_path_length(G, weight=weight))
1626
max_dist = 0
1727
farthest_pair = (None, None)
@@ -26,90 +36,155 @@ def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -
2636
return u, v
2737

2838

29-
def voxel_subsample(G, factor=0.25, voxel_size=None, seed=1234):
30-
coords = np.asarray([G.nodes[n]["pos"] for n in G.nodes])
31-
nodes = np.asarray(list(G.nodes))
39+
def central_path_edt_graph(mask: np.ndarray, start: Tuple[int], end: Tuple[int]):
40+
"""Find the central path within a binary mask between a start and an end coordinate.
3241
33-
# choose a voxel edge length if the caller has not fixed one
34-
if voxel_size is None:
35-
bbox = np.ptp(coords, axis=0) # edge lengths
36-
voxel_size = (bbox.prod() / (len(G)/factor)) ** (1/3)
37-
38-
# integer voxel indices
39-
mins = coords.min(axis=0)
40-
vox = np.floor((coords - mins) / voxel_size).astype(np.int32)
42+
Args:
43+
mask: Binary mask of volume
44+
start: Starting coordinate
45+
end: End coordinate
46+
"""
47+
dt = distance_transform_edt(mask)
48+
G = nx.Graph()
49+
shape = mask.shape
50+
def idx_to_node(z, y, x): return z*shape[1]*shape[2] + y*shape[2] + x
51+
border_coords = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
52+
for z in range(shape[0]):
53+
for y in range(shape[1]):
54+
for x in range(shape[2]):
55+
if not mask[z, y, x]:
56+
continue
57+
u = idx_to_node(z, y, x)
58+
for dz, dy, dx in border_coords:
59+
nz, ny, nx_ = z+dz, y+dy, x+dx
60+
if nz >= 0 and nz < shape[0] and mask[nz, ny, nx_]:
61+
v = idx_to_node(nz, ny, nx_)
62+
w = 1.0 / (1e-3 + min(dt[z, y, x], dt[nz, ny, nx_]))
63+
G.add_edge(u, v, weight=w)
64+
s = idx_to_node(*start)
65+
t = idx_to_node(*end)
66+
path = nx.shortest_path(G, source=s, target=t, weight="weight")
67+
coords = [(p//(shape[1]*shape[2]),
68+
(p//shape[2]) % shape[1],
69+
p % shape[2]) for p in path]
70+
return np.array(coords)
71+
72+
73+
def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
74+
"""Smooth a 3D path with a simple moving average filter.
4175
42-
# bucket nodes per voxel
43-
from collections import defaultdict
44-
buckets = defaultdict(list)
45-
for idx, v in enumerate(map(tuple, vox)):
46-
buckets[v].append(idx)
76+
Args:
77+
path: ndarray of shape (N, 3)
78+
window: half-window size; actual window = 2*window + 1
4779
48-
rng = np.random.default_rng(seed)
49-
keep = []
50-
for bucket in buckets.values():
51-
k = max(1, int(round(len(bucket)*factor))) # local quota
52-
keep.extend(rng.choice(bucket, k, replace=False))
80+
Returns:
81+
smoothed path: ndarray of same shape
82+
"""
83+
kernel_size = 2 * window + 1
84+
kernel = np.ones(kernel_size) / kernel_size
5385

54-
sampled_nodes = nodes[keep]
55-
return G.subgraph(sampled_nodes).copy()
86+
smooth_path = np.zeros_like(path)
5687

88+
for d in range(3):
89+
pad = np.pad(path[:, d], window, mode='edge')
90+
smooth_path[:, d] = np.convolve(pad, kernel, mode='valid')
5791

58-
def measure_run_length_sgns(graph, centroids, label_ids, filter_factor, weight="weight"):
59-
if filter_factor is not None:
60-
if 0 <= filter_factor < 1:
61-
graph = voxel_subsample(graph, factor=filter_factor)
62-
centroid_labels = list(graph.nodes)
63-
centroids = [graph.nodes[n]["pos"] for n in graph.nodes]
64-
k_nn_thick = int(40 * filter_factor)
65-
# centroids = [centroids[label_ids.index(i)] for i in centroid_labels]
92+
return smooth_path
6693

67-
else:
68-
raise ValueError(f"Invalid filter factor {filter_factor}. Choose a filter factor between 0 and 1.")
69-
else:
70-
k_nn_thick = 40
71-
centroid_labels = label_ids
7294

73-
path_coords, path = distance_weighted_steiner_path(
74-
centroids, # (N,3) ndarray
75-
centroid_labels=centroid_labels, # (N,) ndarray
76-
k_nn_thick=k_nn_thick, # 20‒30 is robust for SGN clouds int(40 * (1 - filter_factor))
77-
lam=0.5, # 0.3‒1.0 : larger → stronger centripetal bias
78-
r_connect=50.0 # connect neighbours within 50 µm
79-
)
95+
def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
96+
"""Measure the run lengths of the SGN segmentation by finding a central path through Rosenthal's canal.
97+
1) Create a binary mask based on down-scaled centroids.
98+
2) Dilate the mask and close holes to ensure a filled structure.
99+
3) Determine the endpoints of the structure using the principal axis.
100+
4) Identify a central path based on the 3D Euclidean distance transform.
101+
5) The path is up-scaled and smoothed using a moving average filter.
102+
6) The points of the path are fed into a dictionary along with the fractional length.
80103
81-
for num, p in enumerate(path[:-1]):
82-
pos_i = centroids[centroid_labels.index(p)]
83-
pos_j = centroids[centroid_labels.index(path[num+1])]
84-
dist = math.dist(pos_i, pos_j)
85-
graph.add_edge(p, path[num+1], weight=dist)
104+
Args:
105+
centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3)
106+
scale_factor: Downscaling factor for finding the central path.
86107
87-
total_distance = nx.path_weight(graph, path, weight=weight)
108+
"""
109+
mask = postprocessing.downscaled_centroids(centroids, scale_factor=scale_factor, downsample_mode="capped")
110+
mask = binary_dilation(mask, np.ones((3, 3, 3)), iterations=1)
111+
mask = binary_closing(mask, np.ones((3, 3, 3)), iterations=1)
112+
pts = np.argwhere(mask == 1)
113+
114+
# find two endpoints: min/max along principal axis
115+
c_mean = pts.mean(axis=0)
116+
cov = np.cov((pts-c_mean).T)
117+
evals, evecs = np.linalg.eigh(cov)
118+
axis = evecs[:, np.argmax(evals)]
119+
proj = (pts - c_mean) @ axis
120+
start_voxel = tuple(pts[proj.argmin()])
121+
end_voxel = tuple(pts[proj.argmax()])
122+
123+
# get central path and total distance
124+
path = central_path_edt_graph(mask, start_voxel, end_voxel)
125+
path = path * scale_factor
126+
path = moving_average_3d(path, window=5)
127+
total_distance = sum([math.dist(path[num + 1], path[num]) for num in range(len(path) - 1)])
128+
129+
# assign relative distance to points on path
130+
path_dict = {}
131+
path_dict[0] = {"pos": path[0], "length_fraction": 0}
132+
accumulated = 0
133+
for num, p in enumerate(path[1:-1]):
134+
distance = math.dist(path[num], p)
135+
accumulated += distance
136+
rel_dist = accumulated / total_distance
137+
path_dict[num + 1] = {"pos": p, "length_fraction": rel_dist}
138+
path_dict[len(path)] = {"pos": path[-1], "length_fraction": 1}
88139

89-
return total_distance, path, graph
140+
return total_distance, path_dict
90141

91142

92143
def measure_run_length_ihcs(graph, weight="weight"):
144+
"""Measure the run lengths of the IHC segmentation
145+
by finding the shortest path between the most distant nodes in a Steiner Tree.
146+
147+
Args:
148+
graph: Input graph.
149+
"""
93150
u, v = find_most_distant_nodes(graph)
94151
# approximate Steiner tree and find shortest path between the two most distant nodes
95152
terminals = set(graph.nodes()) # All nodes are required
96153
# Approximate Steiner Tree over all nodes
97154
T = steiner_tree(graph, terminals, weight=weight)
98155
path = nx.shortest_path(T, source=u, target=v, weight=weight)
99156
total_distance = nx.path_weight(T, path, weight=weight)
100-
return total_distance, path
101157

158+
# assign relative distance to points on path
159+
path_dict = {}
160+
path_dict[0] = {"pos": graph.nodes[path[0]]["pos"], "length_fraction": 0}
161+
accumulated = 0
162+
for num, p in enumerate(path[1:-1]):
163+
distance = math.dist(graph.nodes[path[num]]["pos"], graph.nodes[p]["pos"])
164+
accumulated += distance
165+
rel_dist = accumulated / total_distance
166+
path_dict[num + 1] = {"pos": graph.nodes[p]["pos"], "length_fraction": rel_dist}
167+
path_dict[len(path)] = {"pos": graph.nodes[path[-1]]["pos"], "length_fraction": 1}
168+
169+
return total_distance, path_dict
170+
171+
172+
def map_frequency(table: pd.DataFrame):
173+
"""Map the frequency range of SGNs in the cochlea
174+
using Greenwood function f(x) = A * (10 **(ax) - K).
175+
Values for humans: a=2.1, k=0.88, A = 165.4 [kHz].
176+
For mice: fit values between minimal (1kHz) and maximal (80kHz) values
102177
103-
def map_frequency(table):
104-
# map frequency using Greenwood function f(x) = A * (10 **(ax) - K), for humans: a=2.1, k=0.88, A = 165.4 [kHz]
178+
Args:
179+
table:
180+
"""
105181
var_k = 0.88
106-
# calculate values to fit (assumed) minimal (1kHz) and maximal (80kHz) hearing range of mice at x=0, x=1
107182
fmin = 1
108183
fmax = 80
109184
var_A = fmin / (1 - var_k)
110185
var_exp = ((fmax + var_A * var_k) / var_A)
111-
table.loc[table['distance_to_path[µm]'] >= 0, 'tonotopic_value[kHz]'] = var_A * (var_exp ** table["length_fraction"] - var_k)
112-
table.loc[table['distance_to_path[µm]'] < 0, 'tonotopic_value[kHz]'] = 0
186+
table.loc[table['offset'] >= 0, 'frequency[kHz]'] = var_A * (var_exp ** table["length_fraction"] - var_k)
187+
table.loc[table['offset'] < 0, 'frequency[kHz]'] = 0
113188

114189
return table
115190

@@ -119,8 +194,7 @@ def tonotopic_mapping(
119194
component_label: List[int] = [1],
120195
max_edge_distance: float = 30,
121196
min_component_length: int = 50,
122-
cell_type: str = "ihc",
123-
filter_factor: Optional[float] = None
197+
cell_type: str = "ihc"
124198
) -> pd.DataFrame:
125199
"""Tonotopic mapping of IHCs by supplying a table with component labels.
126200
The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.
@@ -154,63 +228,43 @@ def tonotopic_mapping(
154228
unfiltered_graph = graph.copy()
155229

156230
if cell_type == "ihc":
157-
total_distance, path = measure_run_length_ihcs(graph)
231+
total_distance, path_dict = measure_run_length_ihcs(graph)
158232

159233
else:
160-
total_distance, path, graph = measure_run_length_sgns(graph, centroids, label_ids,
161-
filter_factor, weight="weight")
162-
163-
# measure_betweenness
164-
centrality = nx.betweenness_centrality(graph, k=100, normalized=True, weight='weight', seed=1234)
165-
score = sum(centrality[n] for n in path) / len(path)
166-
print(f"path distance: {total_distance}")
167-
print(f"centrality score: {score}")
168-
169-
# assign relative distance to nodes on path
170-
path_dict = {}
171-
path_dict[path[0]] = {"label_id": path[0], "length_fraction": 0}
172-
accumulated = 0
173-
for num, p in enumerate(path[1:-1]):
174-
distance = graph.get_edge_data(path[num], p)["weight"]
175-
accumulated += distance
176-
rel_dist = accumulated / total_distance
177-
path_dict[p] = {"label_id": p, "length_fraction": rel_dist}
178-
path_dict[path[-1]] = {"label_id": path[-1], "length_fraction": 1}
234+
total_distance, path_dict = measure_run_length_sgns(centroids)
179235

180236
# add missing nodes from component and compute distance to path
181237
pos = nx.get_node_attributes(unfiltered_graph, 'pos')
238+
node_dict = {}
182239
for c in label_ids:
183-
if c not in path:
184-
min_dist = float('inf')
185-
nearest_node = None
186-
187-
for p in path:
188-
dist = math.dist(pos[c], pos[p])
189-
if dist < min_dist:
190-
min_dist = dist
191-
nearest_node = p
192-
193-
path_dict[c] = {
194-
"label_id": c,
195-
"length_fraction": path_dict[nearest_node]["length_fraction"],
196-
"distance_to_path": min_dist,
197-
}
198-
else:
199-
path_dict[c]["distance_to_path"] = 0
200-
201-
distance_to_path = [-1 for _ in range(len(table))]
240+
min_dist = float('inf')
241+
nearest_node = None
242+
243+
for key in path_dict.keys():
244+
dist = math.dist(pos[c], path_dict[key]["pos"])
245+
if dist < min_dist:
246+
min_dist = dist
247+
nearest_node = key
248+
249+
node_dict[c] = {
250+
"label_id": c,
251+
"length_fraction": path_dict[nearest_node]["length_fraction"],
252+
"offset": min_dist,
253+
}
254+
255+
offset = [-1 for _ in range(len(table))]
202256
# 'label_id' of dataframe starting at 1
203-
for key in list(path_dict.keys()):
204-
distance_to_path[int(path_dict[key]["label_id"] - 1)] = path_dict[key]["distance_to_path"]
257+
for key in list(node_dict.keys()):
258+
offset[int(node_dict[key]["label_id"] - 1)] = node_dict[key]["offset"]
205259

206-
table.loc[:, "distance_to_path[µm]"] = distance_to_path
260+
table.loc[:, "offset"] = offset
207261

208262
length_fraction = [0 for _ in range(len(table))]
209-
for key in list(path_dict.keys()):
210-
length_fraction[int(path_dict[key]["label_id"] - 1)] = path_dict[key]["length_fraction"]
263+
for key in list(node_dict.keys()):
264+
length_fraction[int(node_dict[key]["label_id"] - 1)] = node_dict[key]["length_fraction"]
211265

212266
table.loc[:, "length_fraction"] = length_fraction
213-
table.loc[:, "run_length[µm]"] = table["length_fraction"] * total_distance
267+
table.loc[:, "length[µm]"] = table["length_fraction"] * total_distance
214268

215269
table = map_frequency(table)
216270

0 commit comments

Comments
 (0)