import math
from typing import List, Tuple

import networkx as nx
import numpy as np
import pandas as pd
from networkx.algorithms.approximation import steiner_tree
from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
from scipy.interpolate import interp1d

from flamingo_tools.segmentation.postprocessing import downscaled_centroids


def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[int, int]:
    """Find the two most distant nodes in a graph, i.e. the pair with the largest shortest-path distance.

    Args:
        G: Input graph.
        weight: Name of the edge attribute used as distance.

    Returns:
        Node 1.
        Node 2.
    """
    all_lengths = dict(nx.all_pairs_dijkstra_path_length(G, weight=weight))
    max_dist = 0
    farthest_pair = (None, None)

    for u, dist_dict in all_lengths.items():
        for v, d in dist_dict.items():
            if d > max_dist:
                max_dist = d
                farthest_pair = (u, v)

    u, v = farthest_pair
    return u, v


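# Minimal usage sketch for `find_most_distant_nodes` (illustrative only, not part of the
# processing pipeline): on a weighted path graph 0-1-2-3 the most distant pair is (0, 3).
def _example_find_most_distant_nodes() -> None:
    G = nx.Graph()
    G.add_edge(0, 1, weight=1.0)
    G.add_edge(1, 2, weight=2.0)
    G.add_edge(2, 3, weight=1.5)
    u, v = find_most_distant_nodes(G)
    assert {u, v} == {0, 3}

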
def central_path_edt_graph(mask: np.ndarray, start: Tuple[int, int, int], end: Tuple[int, int, int]) -> np.ndarray:
    """Find the central path within a binary mask between a start and an end coordinate.

    Args:
        mask: Binary mask of the volume.
        start: Starting voxel coordinate, given in the index order of the mask.
        end: End voxel coordinate, given in the index order of the mask.

    Returns:
        Coordinates of the central path as an array of shape (N, 3), in the index order of the mask.
    """
    dt = distance_transform_edt(mask)
    G = nx.Graph()
    shape = mask.shape

    def idx_to_node(z, y, x):
        return z * shape[1] * shape[2] + y * shape[2] + x

    # 6-connected neighborhood offsets
    border_coords = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
    for z in range(shape[0]):
        for y in range(shape[1]):
            for x in range(shape[2]):
                if not mask[z, y, x]:
                    continue
                u = idx_to_node(z, y, x)
                for dz, dy, dx in border_coords:
                    nz, ny, nx_ = z + dz, y + dy, x + dx
                    # skip neighbors outside the volume; negative indices would otherwise wrap around
                    if not (0 <= nz < shape[0] and 0 <= ny < shape[1] and 0 <= nx_ < shape[2]):
                        continue
                    if mask[nz, ny, nx_]:
                        v = idx_to_node(nz, ny, nx_)
                        # edges deep inside the mask (large EDT) are cheap, so the path prefers the center
                        w = 1.0 / (1e-3 + min(dt[z, y, x], dt[nz, ny, nx_]))
                        G.add_edge(u, v, weight=w)
    s = idx_to_node(*start)
    t = idx_to_node(*end)
    path = nx.shortest_path(G, source=s, target=t, weight="weight")
    coords = [(p // (shape[1] * shape[2]),
               (p // shape[2]) % shape[1],
               p % shape[2]) for p in path]
    return np.array(coords)


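# Minimal sanity sketch for `central_path_edt_graph` (illustrative only): in a straight tube
# with a 3x3 voxel cross-section the EDT-weighted shortest path runs along the tube's center line.
def _example_central_path_edt_graph() -> np.ndarray:
    mask = np.zeros((5, 5, 9), dtype=bool)
    mask[1:4, 1:4, :] = True
    path = central_path_edt_graph(mask, start=(2, 2, 0), end=(2, 2, 8))
    # the path visits every slice along the last axis and stays on the center line
    assert len(path) == 9
    assert np.all(path[:, 0] == 2) and np.all(path[:, 1] == 2)
    return path

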
def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
    """Smooth a 3D path with a simple moving average filter.

    Args:
        path: ndarray of shape (N, 3).
        window: Half-window size; the actual window is 2 * window + 1.

    Returns:
        Smoothed path as an ndarray of the same shape.
    """
    kernel_size = 2 * window + 1
    kernel = np.ones(kernel_size) / kernel_size

    # accumulate in float so that integer voxel paths are not truncated
    smooth_path = np.zeros_like(path, dtype=float)

    for d in range(3):
        pad = np.pad(path[:, d], window, mode='edge')
        smooth_path[:, d] = np.convolve(pad, kernel, mode='valid')

    return smooth_path


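# Small illustration of the smoothing behavior (synthetic data, illustrative only):
# an alternating +-1 component is damped away while a linear component is preserved.
def _example_moving_average_3d() -> None:
    t = np.arange(50, dtype=float)
    zigzag = np.where(t % 2 == 0, 1.0, -1.0)
    noisy_path = np.stack([t, zigzag, np.zeros_like(t)], axis=1)
    smooth = moving_average_3d(noisy_path, window=5)
    assert smooth.shape == noisy_path.shape
    assert np.all(np.abs(smooth[10:-10, 1]) < 0.2)
    assert np.allclose(smooth[10:-10, 0], t[10:-10])

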
def measure_run_length_sgns(centroids: np.ndarray, scale_factor: int = 10) -> Tuple[float, np.ndarray, dict]:
    """Measure the run length of the SGN segmentation by finding a central path through Rosenthal's canal.
    1) Create a binary mask based on down-scaled centroids.
    2) Dilate the mask and close holes to ensure a filled structure.
    3) Determine the endpoints of the structure using the principal axis.
    4) Identify a central path based on the 3D Euclidean distance transform.
    5) Up-scale and smooth the path using a moving average filter.
    6) Store the points of the path in a dictionary together with their fractional length along the path.

    Args:
        centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3).
        scale_factor: Downscaling factor for finding the central path.

    Returns:
        Total distance of the path.
        Path as an ndarray of positions.
        A dictionary containing the position and the length fraction of each point in the path.
    """
    mask = downscaled_centroids(centroids, scale_factor=scale_factor, downsample_mode="capped")
    mask = binary_dilation(mask, np.ones((3, 3, 3)), iterations=1)
    mask = binary_closing(mask, np.ones((3, 3, 3)), iterations=1)
    pts = np.argwhere(mask == 1)

    # find two endpoints: min/max along principal axis
    c_mean = pts.mean(axis=0)
    cov = np.cov((pts - c_mean).T)
    evals, evecs = np.linalg.eigh(cov)
    axis = evecs[:, np.argmax(evals)]
    proj = (pts - c_mean) @ axis
    start_voxel = tuple(pts[proj.argmin()])
    end_voxel = tuple(pts[proj.argmax()])

    # get central path, up-scale it, smooth it, and compute the total distance
    path = central_path_edt_graph(mask, start_voxel, end_voxel)
    path = path * scale_factor
    path = moving_average_3d(path, window=5)
    total_distance = sum(math.dist(path[num + 1], path[num]) for num in range(len(path) - 1))

    # assign relative distance to points on path
    path_dict = {}
    path_dict[0] = {"pos": path[0], "length_fraction": 0}
    accumulated = 0
    for num, p in enumerate(path[1:-1]):
        distance = math.dist(path[num], p)
        accumulated += distance
        rel_dist = accumulated / total_distance
        path_dict[num + 1] = {"pos": p, "length_fraction": rel_dist}
    path_dict[len(path) - 1] = {"pos": path[-1], "length_fraction": 1}

    return total_distance, path, path_dict


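# Sketch of the endpoint selection used in step 3 above (synthetic data, illustrative only):
# project the points onto their principal axis and take the two extremes as endpoints.
def _example_principal_axis_endpoints() -> Tuple[int, int]:
    rng = np.random.default_rng(seed=0)
    pts = np.stack([np.linspace(0, 100, 50), rng.normal(0, 1, 50), rng.normal(0, 1, 50)], axis=1)
    c_mean = pts.mean(axis=0)
    evals, evecs = np.linalg.eigh(np.cov((pts - c_mean).T))
    axis = evecs[:, np.argmax(evals)]
    proj = (pts - c_mean) @ axis
    # for this elongated point cloud the extremes are the first and last point
    assert {int(proj.argmin()), int(proj.argmax())} == {0, 49}
    return int(proj.argmin()), int(proj.argmax())

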
def measure_run_length_ihcs(centroids) -> Tuple[float, np.ndarray, dict]:
    """Measure the run length of the IHC segmentation
    by finding the shortest path between the most distant nodes of an approximate Steiner tree.

    Args:
        centroids: Centroids of the IHC segmentation.

    Returns:
        Total distance of the path.
        Path as an ndarray of positions.
        A dictionary containing the position and the length fraction of each point in the path.
    """
    graph = nx.Graph()
    for num, pos in enumerate(centroids):
        graph.add_node(num, pos=pos)
    # Connect the centroids with edges weighted by their Euclidean distance.
    # Note: a fully connected graph is assumed here; with this choice the Steiner tree over
    # all nodes reduces to a minimum spanning tree of the centroids.
    for i in range(len(centroids)):
        for j in range(i + 1, len(centroids)):
            graph.add_edge(i, j, weight=math.dist(centroids[i], centroids[j]))

    # approximate Steiner tree over all nodes and shortest path between the two most distant nodes
    terminals = set(graph.nodes())  # all nodes are required
    T = steiner_tree(graph, terminals)
    u, v = find_most_distant_nodes(T)
    path = nx.shortest_path(T, source=u, target=v)
    total_distance = nx.path_weight(T, path, weight="weight")

    # assign relative distance to points on path
    path_dict = {}
    path_dict[0] = {"pos": graph.nodes[path[0]]["pos"], "length_fraction": 0}
    accumulated = 0
    for num, p in enumerate(path[1:-1]):
        distance = math.dist(graph.nodes[path[num]]["pos"], graph.nodes[p]["pos"])
        accumulated += distance
        rel_dist = accumulated / total_distance
        path_dict[num + 1] = {"pos": graph.nodes[p]["pos"], "length_fraction": rel_dist}
    path_dict[len(path) - 1] = {"pos": graph.nodes[path[-1]]["pos"], "length_fraction": 1}

    # return the path as an array of positions so that downstream code can interpolate along it
    path_positions = np.array([graph.nodes[n]["pos"] for n in path])
    return total_distance, path_positions, path_dict


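# Minimal usage sketch for `measure_run_length_ihcs` (illustrative only, assuming the fully
# connected centroid graph built above): for IHCs placed on a straight line 10 units apart,
# the measured run length is the distance between the two outermost cells.
def _example_run_length_ihcs() -> None:
    centroids = [(10.0 * i, 0.0, 0.0) for i in range(20)]
    total_distance, path, path_dict = measure_run_length_ihcs(centroids)
    assert abs(total_distance - 190.0) < 1e-6
    assert path.shape == (20, 3)
    assert path_dict[0]["length_fraction"] == 0
    assert path_dict[len(path) - 1]["length_fraction"] == 1

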
def map_frequency(table: pd.DataFrame) -> pd.DataFrame:
    """Map SGNs in the cochlea to their frequency range
    using the Greenwood function f(x) = A * (10 ** (a * x) - K).
    Values for humans: a = 2.1, K = 0.88, A = 165.4 [Hz].
    For mice, the parameters are fitted so that the frequencies run
    from a minimum of 1 kHz to a maximum of 80 kHz.

    Args:
        table: Dataframe containing the segmentation.

    Returns:
        Dataframe with the frequency in an additional column 'frequency[kHz]'.
    """
    var_k = 0.88
    fmin = 1
    fmax = 80
    # fit A and the base of the exponential such that f(0) = fmin and f(1) = fmax
    var_A = fmin / (1 - var_k)
    var_exp = (fmax + var_A * var_k) / var_A
    table.loc[table['offset'] >= 0, 'frequency[kHz]'] = var_A * (var_exp ** table["length_fraction"] - var_k)
    table.loc[table['offset'] < 0, 'frequency[kHz]'] = 0

    return table


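# Worked example for the fitted Greenwood mapping (illustrative only): with K = 0.88,
# fmin = 1 kHz and fmax = 80 kHz, a length fraction of 0 maps to 1 kHz and 1 maps to 80 kHz;
# rows with a negative offset (cells outside the mapped component) are set to 0.
def _example_map_frequency() -> pd.DataFrame:
    df = pd.DataFrame({
        "offset": [0.0, 0.0, 0.0, -1.0],
        "length_fraction": [0.0, 0.5, 1.0, 0.2],
    })
    df = map_frequency(df)
    assert abs(df["frequency[kHz]"].iloc[0] - 1.0) < 1e-6
    assert abs(df["frequency[kHz]"].iloc[2] - 80.0) < 1e-6
    assert df["frequency[kHz]"].iloc[3] == 0
    return df

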
def equidistant_centers(
    table: pd.DataFrame,
    component_label: List[int] = [1],
    cell_type: str = "sgn",
    n_blocks: int = 10,
    offset_blocks: bool = True,
) -> np.ndarray:
    """Find equidistant centers along the central path through Rosenthal's canal.

    Args:
        table: Dataframe containing centroids of the SGN segmentation.
        component_label: List of component labels for the centroid subset.
        cell_type: Cell type of the segmentation.
        n_blocks: Number of equidistant centers for block creation.
        offset_blocks: If True, the centers are shifted by half a block length to avoid centers at the start and end of the path.

    Returns:
        Equidistant centers as an array of 3D coordinates.
    """
    # subset of centroids for given component label(s)
    new_subset = table[table["component_labels"].isin(component_label)]
    centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))

    if cell_type == "ihc":
        _, path, _ = measure_run_length_ihcs(centroids)
    else:
        _, path, _ = measure_run_length_sgns(centroids)

    # cumulative arc length along the path, used as the interpolation coordinate
    diffs = np.diff(path, axis=0)
    seg_lens = np.linalg.norm(diffs, axis=1)
    cum_len = np.insert(np.cumsum(seg_lens), 0, 0)
    # use cum_len[-1] (the arc length of the path) so that all targets lie inside the interpolation range
    if offset_blocks:
        target_s = np.linspace(0, cum_len[-1], n_blocks * 2 + 1)
        target_s = [s for num, s in enumerate(target_s) if num % 2 == 1]
    else:
        target_s = np.linspace(0, cum_len[-1], n_blocks)
    f = interp1d(cum_len, path, axis=0)
    centers = f(target_s)
    return centers


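# Minimal usage sketch for `equidistant_centers` (illustrative only; assumes a segmentation
# table with the 'component_labels' and 'anchor_x/y/z' columns used above): 30 IHC centroids
# on a straight line 10 units apart give 5 block centers shifted by half a block length.
def _example_equidistant_centers() -> np.ndarray:
    df = pd.DataFrame({
        "component_labels": [1] * 30,
        "anchor_x": np.arange(30) * 10.0,
        "anchor_y": np.zeros(30),
        "anchor_z": np.zeros(30),
    })
    centers = equidistant_centers(df, component_label=[1], cell_type="ihc", n_blocks=5)
    assert centers.shape == (5, 3)
    assert np.allclose(np.sort(centers[:, 0]), [29.0, 87.0, 145.0, 203.0, 261.0])
    return centers

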
def tonotopic_mapping(
    table: pd.DataFrame,
    component_label: List[int] = [1],
    cell_type: str = "ihc"
) -> pd.DataFrame:
    """Tonotopic mapping of a segmentation by supplying a table with component labels.
    The mapping assigns a tonotopic label to each cell according to its position along the length of the cochlea.

    Args:
        table: Dataframe of the segmentation table.
        component_label: List of component labels to evaluate.
        cell_type: Cell type of the segmentation.

    Returns:
        Table with tonotopic labels for the cells.
    """
    # subset of centroids for given component label(s)
    new_subset = table[table["component_labels"].isin(component_label)]
    centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
    label_ids = [int(i) for i in list(new_subset["label_id"])]

    if cell_type == "ihc":
        total_distance, _, path_dict = measure_run_length_ihcs(centroids)
    else:
        total_distance, _, path_dict = measure_run_length_sgns(centroids)

    # for each cell, find the nearest point on the central path and take over its length fraction
    node_dict = {}
    for num, c in enumerate(label_ids):
        min_dist = float('inf')
        nearest_node = None

        for key in path_dict.keys():
            dist = math.dist(centroids[num], path_dict[key]["pos"])
            if dist < min_dist:
                min_dist = dist
                nearest_node = key

        node_dict[c] = {
            "label_id": c,
            "length_fraction": path_dict[nearest_node]["length_fraction"],
            "offset": min_dist,
        }

    # 'label_id' of the dataframe starts at 1; cells outside the selected components keep an offset of -1
    offset = [-1 for _ in range(len(table))]
    for key in list(node_dict.keys()):
        offset[int(node_dict[key]["label_id"] - 1)] = node_dict[key]["offset"]

    table.loc[:, "offset"] = offset

    length_fraction = [0 for _ in range(len(table))]
    for key in list(node_dict.keys()):
        length_fraction[int(node_dict[key]["label_id"] - 1)] = node_dict[key]["length_fraction"]

    table.loc[:, "length_fraction"] = length_fraction
    table.loc[:, "length[µm]"] = table["length_fraction"] * total_distance

    table = map_frequency(table)

    return table
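

# End-to-end usage sketch for `tonotopic_mapping` (illustrative only; assumes a segmentation
# table with the 'label_id', 'component_labels' and 'anchor_x/y/z' columns used above).
def _example_tonotopic_mapping() -> pd.DataFrame:
    df = pd.DataFrame({
        "label_id": np.arange(1, 31),
        "component_labels": [1] * 30,
        "anchor_x": np.arange(30) * 10.0,
        "anchor_y": np.zeros(30),
        "anchor_z": np.zeros(30),
    })
    mapped = tonotopic_mapping(df, component_label=[1], cell_type="ihc")
    # every cell lies on the measured path, so all offsets are zero and the
    # frequencies span the fitted mouse range from 1 kHz to 80 kHz
    assert np.allclose(mapped["offset"], 0.0)
    assert abs(mapped["frequency[kHz]"].min() - 1.0) < 1e-6
    assert abs(mapped["frequency[kHz]"].max() - 80.0) < 1e-6
    return mapped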