
Commit c41743c

Refactoring; Added weighted Steiner tree
1 parent 7623a96 commit c41743c

3 files changed: +244 -52 lines changed


flamingo_tools/segmentation/cochlea_mapping.py

Lines changed: 127 additions & 51 deletions
@@ -8,6 +8,7 @@
 from networkx.algorithms.approximation import steiner_tree
 
 from flamingo_tools.segmentation.postprocessing import graph_connected_components
+from flamingo_tools.segmentation.distance_weighted_steiner import distance_weighted_steiner_path
 
 
 def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
@@ -25,6 +26,94 @@ def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
     return u, v
 
 
+def voxel_subsample(G, factor=0.25, voxel_size=None, seed=1234):
+    coords = np.asarray([G.nodes[n]["pos"] for n in G.nodes])
+    nodes = np.asarray(list(G.nodes))
+
+    # choose a voxel edge length if the caller has not fixed one
+    if voxel_size is None:
+        bbox = np.ptp(coords, axis=0)  # edge lengths
+        voxel_size = (bbox.prod() / (len(G)/factor)) ** (1/3)
+
+    # integer voxel indices
+    mins = coords.min(axis=0)
+    vox = np.floor((coords - mins) / voxel_size).astype(np.int32)
+
+    # bucket nodes per voxel
+    from collections import defaultdict
+    buckets = defaultdict(list)
+    for idx, v in enumerate(map(tuple, vox)):
+        buckets[v].append(idx)
+
+    rng = np.random.default_rng(seed)
+    keep = []
+    for bucket in buckets.values():
+        k = max(1, int(round(len(bucket)*factor)))  # local quota
+        keep.extend(rng.choice(bucket, k, replace=False))
+
+    sampled_nodes = nodes[keep]
+    return G.subgraph(sampled_nodes).copy()
+
+
+def measure_run_length_sgns(graph, centroids, label_ids, filter_factor, weight="weight"):
+    if filter_factor is not None:
+        if 0 <= filter_factor < 1:
+            graph = voxel_subsample(graph, factor=filter_factor)
+            centroid_labels = list(graph.nodes)
+            centroids = [graph.nodes[n]["pos"] for n in graph.nodes]
+            k_nn_thick = int(40 * filter_factor)
+            # centroids = [centroids[label_ids.index(i)] for i in centroid_labels]
+
+        else:
+            raise ValueError(f"Invalid filter factor {filter_factor}. Choose a filter factor between 0 and 1.")
+    else:
+        k_nn_thick = 40
+        centroid_labels = label_ids
+
+    path_coords, path = distance_weighted_steiner_path(
+        centroids,                        # (N,3) ndarray
+        centroid_labels=centroid_labels,  # (N,) ndarray
+        k_nn_thick=k_nn_thick,            # 20–30 is robust for SGN clouds
+        lam=0.5,                          # 0.3–1.0: larger → stronger centripetal bias
+        r_connect=50.0                    # connect neighbours within 50 µm
+    )
+
+    for num, p in enumerate(path[:-1]):
+        pos_i = centroids[centroid_labels.index(p)]
+        pos_j = centroids[centroid_labels.index(path[num+1])]
+        dist = math.dist(pos_i, pos_j)
+        graph.add_edge(p, path[num+1], weight=dist)
+
+    total_distance = nx.path_weight(graph, path, weight=weight)
+
+    return total_distance, path, graph
+
+
+def measure_run_length_ihcs(graph, weight="weight"):
+    u, v = find_most_distant_nodes(graph)
+    # approximate Steiner tree and find shortest path between the two most distant nodes
+    terminals = set(graph.nodes())  # All nodes are required
+    # Approximate Steiner Tree over all nodes
+    T = steiner_tree(graph, terminals, weight=weight)
+    path = nx.shortest_path(T, source=u, target=v, weight=weight)
+    total_distance = nx.path_weight(T, path, weight=weight)
+    return total_distance, path
+
+
+def map_frequency(table):
+    # map frequency using Greenwood function f(x) = A * (10 ** (a * x) - k); for humans a = 2.1, k = 0.88, A = 165.4 Hz
+    var_k = 0.88
+    # calculate values to fit (assumed) minimal (1 kHz) and maximal (80 kHz) hearing range of mice at x=0, x=1
+    fmin = 1
+    fmax = 80
+    var_A = fmin / (1 - var_k)
+    var_exp = ((fmax + var_A * var_k) / var_A)
+    table.loc[table['distance_to_path[µm]'] >= 0, 'tonotopic_value[kHz]'] = var_A * (var_exp ** table["length_fraction"] - var_k)
+    table.loc[table['distance_to_path[µm]'] < 0, 'tonotopic_value[kHz]'] = 0
+
+    return table
+
+
 def tonotopic_mapping(
     table: pd.DataFrame,
     component_label: List[int] = [1],
@@ -47,16 +136,14 @@ def tonotopic_mapping(
     Returns:
         Table with tonotopic label for cells.
     """
-    weight = "weight"
     # subset of centroids for given component label(s)
     new_subset = table[table["component_labels"].isin(component_label)]
-    comp_label_ids = list(new_subset["label_id"])
-    centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
-    labels_subset = [int(i) for i in list(new_subset["label_id"])]
+    centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
+    label_ids = [int(i) for i in list(new_subset["label_id"])]
 
     # create graph with connected components
     coords = {}
-    for index, element in zip(labels_subset, centroids_subset):
+    for index, element in zip(label_ids, centroids):
         coords[index] = element
 
     components, graph = graph_connected_components(coords, max_edge_distance, min_component_length)
@@ -66,45 +153,33 @@ def tonotopic_mapping(
 
     unfiltered_graph = graph.copy()
 
-    if filter_factor is not None:
-        if 0 <= filter_factor < 1:
-            rng = np.random.default_rng(seed=1234)
-            original_array = np.array(comp_label_ids)
-            target_length = int(len(original_array) * filter_factor)
-            filtered_list = list(rng.choice(original_array, size=target_length, replace=False))
-            for filter_id in filtered_list:
-                graph.remove_node(filter_id)
-        else:
-            raise ValueError(f"Invalid filter factor {filter_factor}. Choose a filter factor between 0 and 1.")
-
-    u, v = find_most_distant_nodes(graph)
-
-    if not nx.has_path(graph, source=u, target=v) or cell_type == "ihc":
-        # approximate Steiner tree and find shortest path between the two most distant nodes
-        terminals = set(graph.nodes())  # All nodes are required
-        # Approximate Steiner Tree over all nodes
-        T = steiner_tree(graph, terminals, weight=weight)
-        path = nx.shortest_path(T, source=u, target=v, weight=weight)
-        total_distance = nx.path_weight(T, path, weight=weight)
+    if cell_type == "ihc":
+        total_distance, path = measure_run_length_ihcs(graph)
 
     else:
-        path = nx.shortest_path(graph, source=u, target=v, weight=weight)
-        total_distance = nx.path_weight(graph, path, weight=weight)
+        total_distance, path, graph = measure_run_length_sgns(graph, centroids, label_ids,
+                                                              filter_factor, weight="weight")
+
+    # measure_betweenness
+    centrality = nx.betweenness_centrality(graph, k=100, normalized=True, weight='weight', seed=1234)
+    score = sum(centrality[n] for n in path) / len(path)
+    print(f"path distance: {total_distance}")
+    print(f"centrality score: {score}")
 
     # assign relative distance to nodes on path
-    path_list = {}
-    path_list[path[0]] = {"label_id": path[0], "tonotopic": 0}
+    path_dict = {}
+    path_dict[path[0]] = {"label_id": path[0], "length_fraction": 0}
     accumulated = 0
     for num, p in enumerate(path[1:-1]):
         distance = graph.get_edge_data(path[num], p)["weight"]
         accumulated += distance
         rel_dist = accumulated / total_distance
-        path_list[p] = {"label_id": p, "tonotopic": rel_dist}
-    path_list[path[-1]] = {"label_id": path[-1], "tonotopic": 1}
+        path_dict[p] = {"label_id": p, "length_fraction": rel_dist}
+    path_dict[path[-1]] = {"label_id": path[-1], "length_fraction": 1}
 
-    # add missing nodes from component
+    # add missing nodes from component and compute distance to path
     pos = nx.get_node_attributes(unfiltered_graph, 'pos')
-    for c in comp_label_ids:
+    for c in label_ids:
         if c not in path:
             min_dist = float('inf')
             nearest_node = None
@@ -115,27 +190,28 @@ def tonotopic_mapping(
                     min_dist = dist
                     nearest_node = p
 
-            path_list[c] = {"label_id": c, "tonotopic": path_list[nearest_node]["tonotopic"]}
+            path_dict[c] = {
+                "label_id": c,
+                "length_fraction": path_dict[nearest_node]["length_fraction"],
+                "distance_to_path": min_dist,
+            }
+        else:
+            path_dict[c]["distance_to_path"] = 0
 
-    # label in micrometer
-    tonotopic = [0 for _ in range(len(table))]
-    # be aware of 'label_id' of dataframe starting at 1
-    for key in list(path_list.keys()):
-        tonotopic[int(path_list[key]["label_id"] - 1)] = path_list[key]["tonotopic"] * total_distance
+    distance_to_path = [-1 for _ in range(len(table))]
+    # 'label_id' of dataframe starting at 1
+    for key in list(path_dict.keys()):
+        distance_to_path[int(path_dict[key]["label_id"] - 1)] = path_dict[key]["distance_to_path"]
 
-    table.loc[:, "tonotopic_label"] = tonotopic
+    table.loc[:, "distance_to_path[µm]"] = distance_to_path
 
-    # map frequency using Greenwood function f(x) = A * (10 **(ax) - K), for humans: a=2.1, k=0.88, A = 165.4 [kHz]
-    tonotopic_map = [0 for _ in range(len(table))]
-    var_k = 0.88
-    # calculate values to fit (assumed) minimal (1kHz) and maximal (80kHz) hearing range of mice at x=0, x=1
-    fmin = 1
-    fmax = 80
-    var_A = fmin / (1 - var_k)
-    var_exp = ((fmax + var_A * var_k) / var_A)
-    for key in list(path_list.keys()):
-        tonotopic_map[int(path_list[key]["label_id"] - 1)] = var_A * (var_exp ** path_list[key]["tonotopic"] - var_k)
+    length_fraction = [0 for _ in range(len(table))]
+    for key in list(path_dict.keys()):
+        length_fraction[int(path_dict[key]["label_id"] - 1)] = path_dict[key]["length_fraction"]
+
+    table.loc[:, "length_fraction"] = length_fraction
+    table.loc[:, "run_length[µm]"] = table["length_fraction"] * total_distance
 
-    table.loc[:, "tonotopic_value[kHz]"] = tonotopic_map
+    table = map_frequency(table)
 
     return table
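
Not part of the commit: a minimal sketch of how the refactored mapping might be driven from Python. Only table and component_label are visible in the signature above; cell_type and filter_factor mirror variables used inside tonotopic_mapping and are assumed to be keyword parameters, and the input path is hypothetical. The table needs the anchor_x/anchor_y/anchor_z, component_labels and label_id columns referenced in the diff.

import pandas as pd
from flamingo_tools.segmentation.cochlea_mapping import tonotopic_mapping

table = pd.read_csv("sgn_table.tsv", sep="\t")  # hypothetical input table
table = tonotopic_mapping(
    table,
    component_label=[1],
    cell_type="sgn",     # any value other than "ihc" takes the weighted-Steiner branch
    filter_factor=0.25,  # keep roughly a quarter of the nodes via voxel subsampling
)
# columns written by the mapping, per the diff:
print(table[["length_fraction", "run_length[µm]", "distance_to_path[µm]", "tonotopic_value[kHz]"]].head())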

flamingo_tools/segmentation/distance_weighted_steiner.py

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+"""
+distance_weighted_steiner.py
+Variant-B: centre-seeking Steiner path for cochlear run-length extraction
+"""
+
+from __future__ import annotations
+import numpy as np
+import networkx as nx
+from scipy.spatial import cKDTree
+from typing import Tuple, Sequence, Optional
+
+
+def estimate_local_thickness(points: np.ndarray,
+                             k_nn: int = 20) -> np.ndarray:
+    """
+    Return a per-point scalar proportional to local canal thickness.
+    We use the *k*-th NN distance as a cheap proxy.
+    """
+    tree = cKDTree(points)
+    # distances shape → (N, k_nn)
+    dists, _ = tree.query(points, k=k_nn + 1)  # +1 because k=0 is the point itself
+    kth = dists[:, -1]  # farthest of the k neighbours
+    return kth  # units: same as points
+
+
+def make_graph(points: np.ndarray,
+               radii: np.ndarray,
+               r_connect: float = 60.0,
+               lam: float = 0.5,
+               k_edge: Optional[int] = None) -> nx.Graph:
+    """
+    Build a graph with distance-transform-weighted edges.
+
+    Parameters
+    ----------
+    points : (N,3) float array
+    radii : (N,) local thickness proxy
+    r_connect: connect all neighbours within this radius (µm)
+    lam : weight of |d_i - d_j| term
+    k_edge : alternative to r_connect - connect the k_edge
+             nearest neighbours; leave None to use radius
+    """
+    N = len(points)
+    tree = cKDTree(points)
+
+    G = nx.Graph()
+    # add nodes with attributes
+    for idx, (xyz, r) in enumerate(zip(points, radii)):
+        G.add_node(idx, pos=tuple(xyz), radius=float(r))
+
+    # choose connectivity strategy
+    if k_edge is not None:
+        for idx in range(N):
+            _, inds = tree.query(points[idx], k=k_edge + 1)
+            for j in inds[1:]:
+                _add_edge(G, idx, j, radii, lam)
+    else:
+        # radius search in batches (memory safe)
+        pairs = tree.query_pairs(r_connect)
+        for i, j in pairs:
+            _add_edge(G, i, j, radii, lam)
+
+    return G
+
+
+def _add_edge(G: nx.Graph, i: int, j: int,
+              radii: np.ndarray, lam: float):
+    """Helper to compute weighted edge once and add both directions."""
+    pi, pj = G.nodes[i]["pos"], G.nodes[j]["pos"]
+    dij = np.linalg.norm(np.subtract(pi, pj))
+    dr = abs(radii[i] - radii[j]) / (radii[i] + radii[j] + 1e-9)
+    w = dij * (1.0 + lam * dr)
+    G.add_edge(i, j, weight=w)
+
+
+def find_endpoints(points: np.ndarray) -> Tuple[int, int]:
+    """
+    Pick apical+basal terminals as the points with minimum/maximum
+    projection on the first PCA axis (fast & robust).
+    """
+    # simple PCA via SVD
+    pts = points - points.mean(0, keepdims=True)
+    u, s, vh = np.linalg.svd(pts, full_matrices=False)
+    axis = vh[0]
+    proj = pts @ axis
+    return int(proj.argmin()), int(proj.argmax())
+
+
+def distance_weighted_steiner_path(centroids: Sequence[Sequence[float]],
+                                   *,
+                                   centroid_labels: Optional[Sequence[int]] = None,
+                                   k_nn_thick: int = 20,
+                                   lam: float = 0.5,
+                                   r_connect: float = 60.0,
+                                   k_edge: Optional[int] = None) -> Tuple[np.ndarray, list[int]]:
+    """
+    Main public entry - returns (Mx3 point array, list of node indices)
+    representing the centre-biased cochlear path.
+    """
+    pts = np.asarray(centroids, dtype=float)
+    radii = estimate_local_thickness(pts, k_nn=k_nn_thick)
+
+    G = make_graph(pts, radii, r_connect=r_connect, lam=lam, k_edge=k_edge)
+
+    s, t = find_endpoints(pts)
+    steiner = nx.algorithms.approximation.steinertree.steiner_tree(G, {s, t}, weight="weight")
+    # unique s–t path inside the tree (no branches because only 2 terminals):
+    path_nodes = nx.shortest_path(steiner, source=s, target=t, weight="weight")
+    path_xyz = np.array([G.nodes[i]["pos"] for i in path_nodes])
+
+    # transfer path nodes into centroid_labels
+    if centroid_labels is not None:
+        path_nodes = [centroid_labels[i] for i in path_nodes]
+
+    return path_xyz, path_nodes
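
A minimal usage sketch for the new module (not part of the commit): synthetic centroids along a noisy arc stand in for an SGN point cloud, and the parameter values simply restate the defaults documented above.

import numpy as np
from flamingo_tools.segmentation.distance_weighted_steiner import distance_weighted_steiner_path

rng = np.random.default_rng(0)
t = np.linspace(0, np.pi, 400)
# noisy half-circle of radius ~200 µm as a stand-in point cloud
pts = np.stack([200 * np.cos(t), 200 * np.sin(t), np.zeros_like(t)], axis=1)
pts += rng.normal(scale=5.0, size=pts.shape)

path_xyz, path_nodes = distance_weighted_steiner_path(
    pts,
    k_nn_thick=20,   # k-th NN distance is the thickness proxy
    lam=0.5,         # strength of the centre-seeking bias
    r_connect=60.0,  # µm; neighbours within this radius get an edge
)
print(path_xyz.shape, len(path_nodes))  # (M, 3) coordinates and the node indices of the path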

scripts/prediction/tonotopic_mapping.py

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@ def main():
     parser.add_argument("-o", "--output", required=True, help="Output path for post-processed table.")
 
     parser.add_argument("-t", "--type", type=str, default="ihc", help="Cell type of segmentation.")
-    parser.add_argument("--filter", type=float, default=None, help="Fraction of nodes to remove before mapping.")
+    parser.add_argument("--filter", type=float, default=None,
+                        help="Fraction of nodes to keep before mapping. Default: 1.")
     parser.add_argument("--edge_distance", type=float, default=30, help="Maximal edge distance between nodes.")
    parser.add_argument("--component_length", type=int, default=50, help="Minimal number of nodes in component.")
 
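
For context on the reworded --filter help text, a small sketch (not part of the commit) of the keep semantics: voxel_subsample in cochlea_mapping.py draws a local quota of round(len(bucket) * factor) nodes, but at least one, from every occupied voxel, so the kept fraction approaches factor only when voxels are well populated. The graph and voxel size below are made up for illustration.

import numpy as np
import networkx as nx
from flamingo_tools.segmentation.cochlea_mapping import voxel_subsample

rng = np.random.default_rng(0)
G = nx.Graph()
for i, p in enumerate(rng.uniform(0, 100, size=(2000, 3))):
    G.add_node(i, pos=tuple(p))  # voxel_subsample reads the "pos" attribute

G_small = voxel_subsample(G, factor=0.25, voxel_size=25.0)  # roughly 30 nodes per voxel
print(len(G_small) / len(G))  # close to 0.25 when voxels are well populated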
