diff --git a/flamingo_tools/segmentation/cochlea_mapping.py b/flamingo_tools/segmentation/cochlea_mapping.py
new file mode 100644
index 0000000..6a7d1ba
--- /dev/null
+++ b/flamingo_tools/segmentation/cochlea_mapping.py
@@ -0,0 +1,312 @@
+import math
+from typing import List, Tuple
+
+import networkx as nx
+import numpy as np
+import pandas as pd
+from networkx.algorithms.approximation import steiner_tree
+from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
+from scipy.interpolate import interp1d
+
+from flamingo_tools.segmentation.postprocessing import downscaled_centroids
+
+
+def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[int, int]:
+    """Find the two most distant nodes in a graph.
+
+    Args:
+        G: Input graph.
+        weight: Name of the edge attribute used as distance.
+
+    Returns:
+        The first endpoint of the longest shortest path.
+        The second endpoint of the longest shortest path.
+    """
+    all_lengths = dict(nx.all_pairs_dijkstra_path_length(G, weight=weight))
+    max_dist = 0
+    farthest_pair = (None, None)
+
+    for u, dist_dict in all_lengths.items():
+        for v, d in dist_dict.items():
+            if d > max_dist:
+                max_dist = d
+                farthest_pair = (u, v)
+
+    u, v = farthest_pair
+    return u, v
+
+
+def central_path_edt_graph(mask: np.ndarray, start: Tuple[int, int, int], end: Tuple[int, int, int]) -> np.ndarray:
+    """Find the central path within a binary mask between a start and an end coordinate.
+
+    The path is the shortest path in a 6-connected voxel graph whose edge weights are the
+    inverse of the Euclidean distance transform, so it is pulled towards the interior of the mask.
+
+    Args:
+        mask: Binary mask of volume.
+        start: Starting coordinate.
+        end: End coordinate.
+
+    Returns:
+        Coordinates of central path.
+    """
+    dt = distance_transform_edt(mask)
+    G = nx.Graph()
+    shape = mask.shape
+
+    def idx_to_node(z, y, x):
+        return z * shape[1] * shape[2] + y * shape[2] + x
+
+    neighbor_offsets = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
+    for z in range(shape[0]):
+        for y in range(shape[1]):
+            for x in range(shape[2]):
+                if not mask[z, y, x]:
+                    continue
+                u = idx_to_node(z, y, x)
+                for dz, dy, dx in neighbor_offsets:
+                    nz, ny, nx_ = z + dz, y + dy, x + dx
+                    # check all three axes to avoid negative-index wrap-around and out-of-bounds access
+                    if 0 <= nz < shape[0] and 0 <= ny < shape[1] and 0 <= nx_ < shape[2] and mask[nz, ny, nx_]:
+                        v = idx_to_node(nz, ny, nx_)
+                        w = 1.0 / (1e-3 + min(dt[z, y, x], dt[nz, ny, nx_]))
+                        G.add_edge(u, v, weight=w)
+    s = idx_to_node(*start)
+    t = idx_to_node(*end)
+    path = nx.shortest_path(G, source=s, target=t, weight="weight")
+    coords = [(p // (shape[1] * shape[2]),
+               (p // shape[2]) % shape[1],
+               p % shape[2]) for p in path]
+    return np.array(coords)
+
+
+def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
+    """Smooth a 3D path with a simple moving average filter.
+
+    Args:
+        path: ndarray of shape (N, 3).
+        window: Half-window size; the actual window is 2 * window + 1.
+
+    Returns:
+        The smoothed path as an ndarray of the same shape.
+    """
+    kernel_size = 2 * window + 1
+    kernel = np.ones(kernel_size) / kernel_size
+
+    # accumulate in float so that integer-valued input paths are not truncated
+    smooth_path = np.zeros_like(path, dtype=float)
+
+    for d in range(3):
+        pad = np.pad(path[:, d], window, mode='edge')
+        smooth_path[:, d] = np.convolve(pad, kernel, mode='valid')
+
+    return smooth_path
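A note on the weighting above: `central_path_edt_graph` makes edges cheap where the distance transform is large, so the weighted shortest path trades a few extra steps for staying in the interior of the mask. Below is a minimal, self-contained 2D sketch of the same idea (the toy mask and endpoints are made up for illustration; the function in this patch operates on 3D volumes):

```python
import networkx as nx
import numpy as np
from scipy.ndimage import distance_transform_edt

# A thick horizontal bar surrounded by background.
mask = np.zeros((7, 23), dtype=bool)
mask[1:6, 1:22] = True
dt = distance_transform_edt(mask)

# 4-connected pixel graph with inverse-EDT edge weights, as in central_path_edt_graph.
G = nx.Graph()
for y, x in zip(*np.nonzero(mask)):
    for dy, dx in ((1, 0), (0, 1)):
        ny, nx_ = y + dy, x + dx
        if ny < mask.shape[0] and nx_ < mask.shape[1] and mask[ny, nx_]:
            w = 1.0 / (1e-3 + min(dt[y, x], dt[ny, nx_]))
            G.add_edge((y, x), (ny, nx_), weight=w)

# Even between two corner pixels, the cheapest path dives to the central row (y == 3),
# where the distance transform is largest and the edges are cheapest.
path = nx.shortest_path(G, source=(1, 1), target=(5, 21), weight="weight")
print(path)
```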
+
+
+def measure_run_length_sgns(centroids: np.ndarray, scale_factor: int = 10):
+    """Measure the run length of the SGN segmentation by finding a central path through Rosenthal's canal.
+
+    The path is determined in the following steps:
+    1) Create a binary mask from the down-scaled centroids.
+    2) Dilate the mask and close holes to ensure a filled structure.
+    3) Determine the endpoints of the structure using the principal axis.
+    4) Identify a central path based on the 3D Euclidean distance transform.
+    5) Up-scale the path and smooth it with a moving average filter.
+    6) Store the points of the path in a dictionary together with their fractional length.
+
+    Args:
+        centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3).
+        scale_factor: Downscaling factor for finding the central path.
+
+    Returns:
+        Total distance of the path.
+        Path as an ndarray of positions.
+        A dictionary containing the position and the length fraction of each point in the path.
+    """
+    mask = downscaled_centroids(centroids, scale_factor=scale_factor, downsample_mode="capped")
+    mask = binary_dilation(mask, np.ones((3, 3, 3)), iterations=1)
+    mask = binary_closing(mask, np.ones((3, 3, 3)), iterations=1)
+    pts = np.argwhere(mask == 1)
+
+    # find two endpoints: min/max along principal axis
+    c_mean = pts.mean(axis=0)
+    cov = np.cov((pts - c_mean).T)
+    evals, evecs = np.linalg.eigh(cov)
+    axis = evecs[:, np.argmax(evals)]
+    proj = (pts - c_mean) @ axis
+    start_voxel = tuple(pts[proj.argmin()])
+    end_voxel = tuple(pts[proj.argmax()])
+
+    # get central path and total distance
+    path = central_path_edt_graph(mask, start_voxel, end_voxel)
+    path = path * scale_factor
+    path = moving_average_3d(path, window=5)
+    total_distance = sum([math.dist(path[num + 1], path[num]) for num in range(len(path) - 1)])
+
+    # assign relative distance to points on path
+    path_dict = {}
+    path_dict[0] = {"pos": path[0], "length_fraction": 0}
+    accumulated = 0
+    for num, p in enumerate(path[1:-1]):
+        distance = math.dist(path[num], p)
+        accumulated += distance
+        rel_dist = accumulated / total_distance
+        path_dict[num + 1] = {"pos": p, "length_fraction": rel_dist}
+    path_dict[len(path) - 1] = {"pos": path[-1], "length_fraction": 1}
+
+    return total_distance, path, path_dict
+
+
+def measure_run_length_ihcs(centroids):
+    """Measure the run length of the IHC segmentation
+    by finding the shortest path between the most distant nodes of a Steiner tree.
+
+    Args:
+        centroids: Centroids of the IHC segmentation.
+
+    Returns:
+        Total distance of the path.
+        Path as an ndarray of positions.
+        A dictionary containing the position and the length fraction of each point in the path.
+    """
+    graph = nx.Graph()
+    for num, pos in enumerate(centroids):
+        graph.add_node(num, pos=pos)
+    # connect all pairs of centroids; the edge weight is their Euclidean distance
+    for num_i in range(len(centroids)):
+        for num_j in range(num_i + 1, len(centroids)):
+            graph.add_edge(num_i, num_j, weight=math.dist(centroids[num_i], centroids[num_j]))
+    # approximate Steiner tree and find shortest path between the two most distant nodes
+    terminals = set(graph.nodes())  # all nodes are required
+    # Approximate Steiner Tree over all nodes
+    T = steiner_tree(graph, terminals)
+    u, v = find_most_distant_nodes(T)
+    path = nx.shortest_path(T, source=u, target=v)
+    total_distance = nx.path_weight(T, path, weight="weight")
+
+    # assign relative distance to points on path
+    path_dict = {}
+    path_dict[0] = {"pos": graph.nodes[path[0]]["pos"], "length_fraction": 0}
+    accumulated = 0
+    for num, p in enumerate(path[1:-1]):
+        distance = math.dist(graph.nodes[path[num]]["pos"], graph.nodes[p]["pos"])
+        accumulated += distance
+        rel_dist = accumulated / total_distance
+        path_dict[num + 1] = {"pos": graph.nodes[p]["pos"], "length_fraction": rel_dist}
+    path_dict[len(path) - 1] = {"pos": graph.nodes[path[-1]]["pos"], "length_fraction": 1}
+
+    # return the path as positions so that downstream code can interpolate along it
+    path = np.array([graph.nodes[n]["pos"] for n in path])
+
+    return total_distance, path, path_dict
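With the complete graph built above, `steiner_tree` over a terminal set containing every node reduces to a minimum spanning tree, and the run-length path is the longest shortest path (the diameter path) inside that tree. A self-contained toy example with made-up centroids:

```python
import math

import networkx as nx
from networkx.algorithms.approximation import steiner_tree

centroids = [(0, 0, 0), (1, 0, 0), (2, 0.2, 0), (3, 0.1, 0), (1, 0.8, 0)]  # last one is a side branch
G = nx.Graph()
for i, pos in enumerate(centroids):
    G.add_node(i, pos=pos)
for i in range(len(centroids)):
    for j in range(i + 1, len(centroids)):
        G.add_edge(i, j, weight=math.dist(centroids[i], centroids[j]))

T = steiner_tree(G, set(G.nodes()))  # all nodes are terminals -> effectively an MST
lengths = dict(nx.all_pairs_dijkstra_path_length(T, weight="weight"))
u, v = max(((a, b) for a in lengths for b in lengths[a]),
           key=lambda ab: lengths[ab[0]][ab[1]])
print(nx.shortest_path(T, u, v))  # e.g. [0, 1, 2, 3]; the side branch to node 4 is skipped
print(lengths[u][v])              # total run length along the tree
```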
+ """ + var_k = 0.88 + fmin = 1 + fmax = 80 + var_A = fmin / (1 - var_k) + var_exp = ((fmax + var_A * var_k) / var_A) + table.loc[table['offset'] >= 0, 'frequency[kHz]'] = var_A * (var_exp ** table["length_fraction"] - var_k) + table.loc[table['offset'] < 0, 'frequency[kHz]'] = 0 + + return table + + +def equidistant_centers( + table: pd.DataFrame, + component_label: List[int] = [1], + cell_type: str = "sgn", + n_blocks: int = 10, + offset_blocks: bool = True, +) -> np.ndarray: + """Find equidistant centers within the central path of the Rosenthal's canal. + + Args: + table: Dataframe containing centroids of SGN segmentation. + component_label: List of components for centroid subset. + cell_type: Cell type of the segmentation. + n_blocks: Number of equidistant centers for block creation. + offset_block: Centers are shifted by half a length if True. Avoid centers at the start/end of the path. + + Returns: + Equidistant centers as float values + """ + # subset of centroids for given component label(s) + new_subset = table[table["component_labels"].isin(component_label)] + centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"])) + + if cell_type == "ihc": + total_distance, path, _ = measure_run_length_ihcs(centroids) + + else: + total_distance, path, _ = measure_run_length_sgns(centroids) + + diffs = np.diff(path, axis=0) + seg_lens = np.linalg.norm(diffs, axis=1) + cum_len = np.insert(np.cumsum(seg_lens), 0, 0) + if offset_blocks: + target_s = np.linspace(0, total_distance, n_blocks * 2 + 1) + target_s = [s for num, s in enumerate(target_s) if num % 2 == 1] + else: + target_s = np.linspace(0, total_distance, n_blocks) + f = interp1d(cum_len, path, axis=0) + centers = f(target_s) + return centers + + +def tonotopic_mapping( + table: pd.DataFrame, + component_label: List[int] = [1], + cell_type: str = "ihc" +) -> pd.DataFrame: + """Tonotopic mapping of IHCs by supplying a table with component labels. + The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea. + + Args: + table: Dataframe of segmentation table. + component_label: List of component labels to evaluate. + cell_type: Cell type of segmentation. + + Returns: + Table with tonotopic label for cells. 
+ """ + # subset of centroids for given component label(s) + new_subset = table[table["component_labels"].isin(component_label)] + centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"])) + label_ids = [int(i) for i in list(new_subset["label_id"])] + + if cell_type == "ihc": + total_distance, _, path_dict = measure_run_length_ihcs(centroids) + + else: + total_distance, _, path_dict = measure_run_length_sgns(centroids) + + # add missing nodes from component and compute distance to path + node_dict = {} + for num, c in enumerate(label_ids): + min_dist = float('inf') + nearest_node = None + + for key in path_dict.keys(): + dist = math.dist(centroids[num], path_dict[key]["pos"]) + if dist < min_dist: + min_dist = dist + nearest_node = key + + node_dict[c] = { + "label_id": c, + "length_fraction": path_dict[nearest_node]["length_fraction"], + "offset": min_dist, + } + + offset = [-1 for _ in range(len(table))] + # 'label_id' of dataframe starting at 1 + for key in list(node_dict.keys()): + offset[int(node_dict[key]["label_id"] - 1)] = node_dict[key]["offset"] + + table.loc[:, "offset"] = offset + + length_fraction = [0 for _ in range(len(table))] + for key in list(node_dict.keys()): + length_fraction[int(node_dict[key]["label_id"] - 1)] = node_dict[key]["length_fraction"] + + table.loc[:, "length_fraction"] = length_fraction + table.loc[:, "length[µm]"] = table["length_fraction"] * total_distance + + table = map_frequency(table) + + return table diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py index 8d21ef8..6f98bd1 100644 --- a/flamingo_tools/segmentation/postprocessing.py +++ b/flamingo_tools/segmentation/postprocessing.py @@ -267,27 +267,30 @@ def erode_subset( def downscaled_centroids( - table: pd.DataFrame, + centroids: np.ndarray, scale_factor: int, ref_dimensions: Optional[Tuple[float, float, float]] = None, + component_labels: Optional[List[int]] = None, downsample_mode: str = "accumulated", ) -> np.typing.NDArray: """Downscale centroids in dataframe. Args: - table: Dataframe of segmentation table. + centroids: Centroids of SGN segmentation, ndarray of shape (N, 3) scale_factor: Factor for downscaling coordinates. ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied. + component_labels: List of component labels, which has to be supplied for the downsampling mode 'components' downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'. 
diff --git a/flamingo_tools/segmentation/postprocessing.py b/flamingo_tools/segmentation/postprocessing.py
index 8d21ef8..6f98bd1 100644
--- a/flamingo_tools/segmentation/postprocessing.py
+++ b/flamingo_tools/segmentation/postprocessing.py
@@ -267,27 +267,30 @@ def erode_subset(
 
 
 def downscaled_centroids(
-    table: pd.DataFrame,
+    centroids: np.ndarray,
     scale_factor: int,
     ref_dimensions: Optional[Tuple[float, float, float]] = None,
+    component_labels: Optional[List[int]] = None,
     downsample_mode: str = "accumulated",
 ) -> np.typing.NDArray:
-    """Downscale centroids in dataframe.
+    """Downscale centroids.
 
     Args:
-        table: Dataframe of segmentation table.
+        centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3).
         scale_factor: Factor for downscaling coordinates.
         ref_dimensions: Reference dimensions for downscaling. Taken from centroids if not supplied.
+        component_labels: List of component labels; must be supplied for the downsampling mode 'components'.
         downsample_mode: Flag for downsampling, either 'accumulated', 'capped', or 'components'.
 
     Returns:
         The downscaled array
     """
-    centroids = list(zip(table["anchor_x"], table["anchor_y"], table["anchor_z"]))
     centroids_scaled = [(c[0] / scale_factor, c[1] / scale_factor, c[2] / scale_factor) for c in centroids]
 
     if ref_dimensions is None:
-        bounding_dimensions = (max(table["anchor_x"]), max(table["anchor_y"]), max(table["anchor_z"]))
+        bounding_dimensions = (max([c[0] for c in centroids]),
+                               max([c[1] for c in centroids]),
+                               max([c[2] for c in centroids]))
         bounding_dimensions_scaled = tuple([round(b // scale_factor + 1) for b in bounding_dimensions])
         new_array = np.zeros(bounding_dimensions_scaled)
@@ -304,9 +307,8 @@ def downscaled_centroids(
             new_array[int(c[0]), int(c[1]), int(c[2])] = 1
 
     elif downsample_mode == "components":
-        if "component_labels" not in table.columns:
-            raise KeyError("Dataframe must continue key 'component_labels' for downsampling with mode 'components'.")
-        component_labels = list(table["component_labels"])
+        if component_labels is None:
+            raise ValueError("Component labels must be supplied for downsampling with mode 'components'.")
         for comp, centr in zip(component_labels, centroids_scaled):
             if comp != 0:
                 new_array[int(centr[0]), int(centr[1]), int(centr[2])] = comp
@@ -319,27 +321,28 @@
 
     return new_array
 
 
-def graph_connected_components(coords: dict, min_edge_distance: float, min_component_length: int):
+def graph_connected_components(coords: dict, max_edge_distance: float, min_component_length: int):
     """Create a list of IDs for each connected component of a graph.
 
     Args:
         coords: Dictionary containing label IDs as keys and their position as value.
-        min_edge_distance: Minimal edge distance between graph nodes to create an edge between nodes.
-        min_component_length: Minimal length of nodes of connected component. Filtered out if lower.
+        max_edge_distance: Maximal distance between two nodes for creating an edge between them.
+        min_component_length: Minimal number of nodes in a connected component; smaller components are filtered out.
 
     Returns:
-        List of dictionary keys of connected components.
+        List of dictionary keys of connected components, sorted by size in descending order.
+        The graph with small components removed.
     """
     graph = nx.Graph()
     for num, pos in coords.items():
         graph.add_node(num, pos=pos)
 
-    # create edges between points whose distance is less than threshold min_edge_distance
+    # create edges between points whose distance is at most the threshold max_edge_distance
     for num_i, pos_i in coords.items():
         for num_j, pos_j in coords.items():
             if num_i < num_j:
                 dist = math.dist(pos_i, pos_j)
-                if dist <= min_edge_distance:
+                if dist <= max_edge_distance:
                     graph.add_edge(num_i, num_j, weight=dist)
 
     components = list(nx.connected_components(graph))
@@ -351,7 +354,10 @@ def graph_connected_components(coords: dict, min_edge_distance: float, min_compo
                 graph.remove_node(c)
 
     components = [list(s) for s in nx.connected_components(graph)]
-    return components
+    length_components = [len(c) for c in components]
+    length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
+
+    return components, graph
 
 
 def components_sgn(
@@ -359,7 +365,7 @@
     keyword: str = "distance_nn100",
     threshold_erode: Optional[float] = None,
    min_component_length: int = 50,
-    min_edge_distance: float = 30,
+    max_edge_distance: float = 30,
     iterations_erode: Optional[int] = None,
     postprocess_threshold: Optional[float] = None,
     postprocess_components: Optional[List[int]] = None,
@@ -371,7 +377,7 @@
         keyword: Keyword of the dataframe column for erosion.
         threshold_erode: Threshold of column value after erosion step with spatial statistics.
         min_component_length: Minimal length for filtering out connected components.
- min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. iterations_erode: Number of iterations for erosion, normally determined automatically. postprocess_threshold: Post-process graph connected components by searching for points closer than threshold. postprocess_components: Post-process specific graph connected components ([0] for largest component only). @@ -411,10 +417,7 @@ def components_sgn( for index, element in zip(labels_subset, centroids_subset): coords[index] = element - components = graph_connected_components(coords, min_edge_distance, min_component_length) - - length_components = [len(c) for c in components] - length_components, components = zip(*sorted(zip(length_components, components), reverse=True)) + components, _ = graph_connected_components(coords, max_edge_distance, min_component_length) # add original coordinates closer to eroded component than threshold if postprocess_threshold is not None: @@ -447,7 +450,7 @@ def label_components_sgn( min_size: int = 1000, threshold_erode: Optional[float] = None, min_component_length: int = 50, - min_edge_distance: float = 30, + max_edge_distance: float = 30, iterations_erode: Optional[int] = None, postprocess_threshold: Optional[float] = None, postprocess_components: Optional[List[int]] = None, @@ -459,7 +462,7 @@ def label_components_sgn( min_size: Minimal number of pixels for filtering small instances. threshold_erode: Threshold of column value after erosion step with spatial statistics. min_component_length: Minimal length for filtering out connected components. - min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. iterations_erode: Number of iterations for erosion, normally determined automatically. postprocess_threshold: Post-process graph connected components by searching for points closer than threshold. postprocess_components: Post-process specific graph connected components ([0] for largest component only). @@ -473,7 +476,7 @@ def label_components_sgn( table = table[table.n_pixels >= min_size] components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length, - min_edge_distance=min_edge_distance, iterations_erode=iterations_erode, + max_edge_distance=max_edge_distance, iterations_erode=iterations_erode, postprocess_threshold=postprocess_threshold, postprocess_components=postprocess_components) @@ -495,7 +498,7 @@ def postprocess_sgn_seg( min_size: int = 1000, threshold_erode: Optional[float] = None, min_component_length: int = 50, - min_edge_distance: float = 30, + max_edge_distance: float = 30, iterations_erode: Optional[int] = None, ) -> pd.DataFrame: """Postprocessing SGN segmentation of cochlea. @@ -505,7 +508,7 @@ def postprocess_sgn_seg( min_size: Minimal number of pixels for filtering small instances. threshold_erode: Threshold of column value after erosion step with spatial statistics. min_component_length: Minimal length for filtering out connected components. - min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. 
iterations_erode: Number of iterations for erosion, normally determined automatically. Returns: @@ -514,7 +517,7 @@ def postprocess_sgn_seg( comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode, min_component_length=min_component_length, - min_edge_distance=min_edge_distance, iterations_erode=iterations_erode) + max_edge_distance=max_edge_distance, iterations_erode=iterations_erode) table.loc[:, "component_labels"] = comp_labels @@ -524,14 +527,14 @@ def postprocess_sgn_seg( def components_ihc( table: pd.DataFrame, min_component_length: int = 50, - min_edge_distance: float = 30, + max_edge_distance: float = 30, ): """Create connected components for IHC segmentation. Args: table: Dataframe of segmentation table. min_component_length: Minimal length for filtering out connected components. - min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. Returns: Subgraph components as lists of label_ids of dataframe. @@ -542,7 +545,7 @@ def components_ihc( for index, element in zip(labels, centroids): coords[index] = element - components = graph_connected_components(coords, min_edge_distance, min_component_length) + components, _ = graph_connected_components(coords, max_edge_distance, min_component_length) return components @@ -550,7 +553,7 @@ def label_components_ihc( table: pd.DataFrame, min_size: int = 1000, min_component_length: int = 50, - min_edge_distance: float = 30, + max_edge_distance: float = 30, ) -> List[int]: """Label components using graph connected components. @@ -558,7 +561,7 @@ def label_components_ihc( table: Dataframe of segmentation table. min_size: Minimal number of pixels for filtering small instances. min_component_length: Minimal length for filtering out connected components. - min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. Returns: List of component label for each point in dataframe. 0 - background, then in descending order of size @@ -569,7 +572,7 @@ def label_components_ihc( table = table[table.n_pixels >= min_size] components = components_ihc(table, min_component_length=min_component_length, - min_edge_distance=min_edge_distance) + max_edge_distance=max_edge_distance) # add size-filtered objects to have same initial length table = pd.concat([table, entries_filtered], ignore_index=True) @@ -591,7 +594,7 @@ def postprocess_ihc_seg( table: pd.DataFrame, min_size: int = 1000, min_component_length: int = 50, - min_edge_distance: float = 30, + max_edge_distance: float = 30, ) -> pd.DataFrame: """Postprocessing IHC segmentation of cochlea. @@ -599,7 +602,7 @@ def postprocess_ihc_seg( table: Dataframe of segmentation table. min_size: Minimal number of pixels for filtering small instances. min_component_length: Minimal length for filtering out connected components. - min_edge_distance: Minimal distance in micrometer between points to create edges for connected components. + max_edge_distance: Maximal distance in micrometer between points to create edges for connected components. Returns: Dataframe with component labels. 
@@ -607,7 +610,7 @@ def postprocess_ihc_seg( comp_labels = label_components_ihc(table, min_size=min_size, min_component_length=min_component_length, - min_edge_distance=min_edge_distance) + max_edge_distance=max_edge_distance) table.loc[:, "component_labels"] = comp_labels diff --git a/reproducibility/block_extraction/ChReef_MLR144R.json b/reproducibility/block_extraction/ChReef_MLR144R.json new file mode 100644 index 0000000..76c6716 --- /dev/null +++ b/reproducibility/block_extraction/ChReef_MLR144R.json @@ -0,0 +1,47 @@ +[ + { + "cochlea": "M_LR_000144_R", + "image_channel": [ + "PV", + "GFP", + "SGN_v2" + ], + "crop_centers": [ + [ + 1329, + 1080, + 602 + ], + [ + 1220, + 898, + 790 + ], + [ + 1369, + 645, + 732 + ], + [ + 1184, + 629, + 530 + ], + [ + 875, + 622, + 610 + ], + [ + 721, + 431, + 805 + ] + ], + "halo_size": [ + 256, + 256, + 50 + ] + } +] \ No newline at end of file diff --git a/reproducibility/block_extraction/ChReef_MLR145R.json b/reproducibility/block_extraction/ChReef_MLR145R.json new file mode 100644 index 0000000..779701a --- /dev/null +++ b/reproducibility/block_extraction/ChReef_MLR145R.json @@ -0,0 +1,47 @@ +[ + { + "cochlea": "M_LR_000145_R", + "image_channel": [ + "PV", + "GFP", + "SGN_v2" + ], + "crop_centers": [ + [ + 789, + 820, + 670 + ], + [ + 920, + 1025, + 908 + ], + [ + 801, + 1278, + 1090 + ], + [ + 620, + 1449, + 911 + ], + [ + 836, + 1461, + 700 + ], + [ + 959, + 1639, + 922 + ] + ], + "halo_size": [ + 256, + 256, + 50 + ] + } +] \ No newline at end of file diff --git a/reproducibility/block_extraction/ChReef_MLR155R.json b/reproducibility/block_extraction/ChReef_MLR155R.json new file mode 100644 index 0000000..38cfc9d --- /dev/null +++ b/reproducibility/block_extraction/ChReef_MLR155R.json @@ -0,0 +1,47 @@ +[ + { + "cochlea": "M_LR_000155_R", + "image_channel": [ + "PV", + "GFP", + "SGN_v2" + ], + "crop_centers": [ + [ + 1634, + 442, + 633 + ], + [ + 1339, + 548, + 790 + ], + [ + 1016, + 575, + 676 + ], + [ + 1037, + 804, + 470 + ], + [ + 1188, + 1017, + 622 + ], + [ + 921, + 1060, + 752 + ] + ], + "halo_size": [ + 256, + 256, + 50 + ] + } +] \ No newline at end of file diff --git a/reproducibility/block_extraction/repro_equidistant_centers.py b/reproducibility/block_extraction/repro_equidistant_centers.py new file mode 100644 index 0000000..f8c048a --- /dev/null +++ b/reproducibility/block_extraction/repro_equidistant_centers.py @@ -0,0 +1,90 @@ +import argparse +import json +import os +from typing import Optional + +import pandas as pd +from flamingo_tools.s3_utils import get_s3_path +from flamingo_tools.segmentation.cochlea_mapping import equidistant_centers + + +def repro_equidistant_centers( + ddict: dict, + output_path: str, + s3_credentials: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, + force_overwrite: Optional[bool] = None, +): + default_cell_type = "ihc" + default_component_list = [1] + default_halo_size = [256, 256, 50] + default_n_blocks = 6 + + with open(ddict, 'r') as myfile: + data = myfile.read() + param_dicts = json.loads(data) + + out_dict = [] + + if os.path.isfile(output_path) and not force_overwrite: + print(f"Skipping {output_path}. 
File already exists.")
+        return
+
+    for dic in param_dicts:
+        cochlea = dic["cochlea"]
+        img_channel = dic["image_channel"]
+        seg_channel = dic["segmentation_channel"]
+
+        s3_path = os.path.join(f"{cochlea}", "tables", f"{seg_channel}", "default.tsv")
+        print(f"Finding equidistant centers for {cochlea}.")
+
+        tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name,
+                                   service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
+        with fs.open(tsv_path, 'r') as f:
+            table = pd.read_csv(f, sep="\t")
+
+        cell_type = dic["type"] if "type" in dic else default_cell_type
+        component_list = dic["component_list"] if "component_list" in dic else default_component_list
+        halo_size = dic["halo_size"] if "halo_size" in dic else default_halo_size
+        n_blocks = dic["n_blocks"] if "n_blocks" in dic else default_n_blocks
+
+        centers = equidistant_centers(table, component_label=component_list, cell_type=cell_type, n_blocks=n_blocks)
+        centers = [[int(c) for c in center] for center in centers]
+        out_entry = {"cochlea": cochlea}
+        out_entry["image_channel"] = img_channel
+        out_entry["crop_centers"] = centers
+        out_entry["halo_size"] = halo_size
+        out_dict.append(out_entry)
+
+    with open(output_path, "w") as f:
+        json.dump(out_dict, f, indent='\t', separators=(',', ': '))
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Script to compute equidistant crop centers along the cochlea for ROI block extraction.")
+
+    parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.")
+    parser.add_argument('-o', "--output", type=str, required=True, help="Output JSON dictionary.")
+
+    parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.")
+    parser.add_argument("--s3_credentials", type=str, default=None,
+                        help="Input file containing S3 credentials. "
+                        "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
+    parser.add_argument("--s3_bucket_name", type=str, default=None,
+                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
+    parser.add_argument("--s3_service_endpoint", type=str, default=None,
+                        help="S3 service endpoint.
Optional if SERVICE_ENDPOINT was exported.") + + args = parser.parse_args() + + repro_equidistant_centers( + args.input, args.output, + args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint, + args.force, + ) + + +if __name__ == "__main__": + + main() diff --git a/reproducibility/postprocess_sgn/SGN_v1_postprocess.json b/reproducibility/postprocess_sgn/SGN_v1_postprocess.json index 9a21835..3383639 100644 --- a/reproducibility/postprocess_sgn/SGN_v1_postprocess.json +++ b/reproducibility/postprocess_sgn/SGN_v1_postprocess.json @@ -12,7 +12,7 @@ { "cochlea": "M_LR_000144_L", "image_channel": "PV_resized", - "min_edge_distance": 70, + "max_edge_distance": 70, "iterations_erode": 1, "unet_version": "v1" }, diff --git a/reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py b/reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py index ee44f77..852ff83 100644 --- a/reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py +++ b/reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py @@ -18,7 +18,7 @@ def repro_postprocess_sgn_v1( min_size = 1000 default_threshold_erode = None default_min_length = 50 - default_min_edge_distance = 30 + default_max_edge_distance = 30 default_iterations_erode = None with open(ddict, 'r') as myfile: @@ -36,18 +36,18 @@ def repro_postprocess_sgn_v1( threshold_erode = dic["threshold_erode"] if "threshold_erode" in dic else default_threshold_erode min_component_length = dic["min_component_length"] if "min_component_length" in dic else default_min_length - min_edge_distance = dic["min_edge_distance"] if "min_edge_distance" in dic else default_min_edge_distance + max_edge_distance = dic["max_edge_distance"] if "max_edge_distance" in dic else default_max_edge_distance iterations_erode = dic["iterations_erode"] if "iterations_erode" in dic else default_iterations_erode print("threshold_erode", threshold_erode) print("min_component_length", min_component_length) - print("min_edge", min_edge_distance) + print("max_edge", max_edge_distance) print("iterations_erode", iterations_erode) tsv_table = postprocess_sgn_seg(table, min_size=min_size, threshold_erode=threshold_erode, min_component_length=min_component_length, - min_edge_distance=min_edge_distance, + max_edge_distance=max_edge_distance, iterations_erode=iterations_erode) largest_comp = len(tsv_table[tsv_table["component_labels"] == 1]) diff --git a/reproducibility/tonotopic_mapping/2025-07-IHC_fig2.json b/reproducibility/tonotopic_mapping/2025-07-IHC_fig2.json new file mode 100644 index 0000000..ed0f13f --- /dev/null +++ b/reproducibility/tonotopic_mapping/2025-07-IHC_fig2.json @@ -0,0 +1,22 @@ +[ + { + "cochlea": "M_LR_000226_L", + "segmentation_channel": "IHC_v3", + "type": "ihc" + }, + { + "cochlea": "M_LR_000226_R", + "segmentation_channel": "IHC_v3", + "type": "ihc" + }, + { + "cochlea": "M_LR_000227_L", + "segmentation_channel": "IHC_v3", + "type": "ihc" + }, + { + "cochlea": "M_LR_000227_R", + "segmentation_channel": "IHC_v3", + "type": "ihc" + } +] diff --git a/reproducibility/tonotopic_mapping/2025-07-SGN.json b/reproducibility/tonotopic_mapping/2025-07-SGN.json new file mode 100644 index 0000000..2c24044 --- /dev/null +++ b/reproducibility/tonotopic_mapping/2025-07-SGN.json @@ -0,0 +1,57 @@ +[ + { + "cochlea": "M_AMD_000058_L", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000144_L", + "segmentation_channel": "SGN_resized_v2", + "max_edge_distance": 70, + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": 
"M_LR_000144_R", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000145_L", + "segmentation_channel": "SGN_resized_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000151_R", + "segmentation_channel": "SGN_resized_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000155_L", + "segmentation_channel": "SGN_resized_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000155_R", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000167_R", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000184_L", + "segmentation_channel": "SGN_resized_v2", + "type": "sgn", + "filter_factor": 0.75 + } +] diff --git a/reproducibility/tonotopic_mapping/2025-07-SGN_fig2.json b/reproducibility/tonotopic_mapping/2025-07-SGN_fig2.json new file mode 100644 index 0000000..fe52947 --- /dev/null +++ b/reproducibility/tonotopic_mapping/2025-07-SGN_fig2.json @@ -0,0 +1,27 @@ +[ + { + "cochlea": "M_LR_000226_L", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000226_R", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000227_L", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "max_edge_distance": 70, + "filter_factor": 0.75 + }, + { + "cochlea": "M_LR_000227_R", + "segmentation_channel": "SGN_v2", + "type": "sgn", + "filter_factor": 0.75 + } +] diff --git a/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py new file mode 100644 index 0000000..080ca58 --- /dev/null +++ b/reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py @@ -0,0 +1,92 @@ +import argparse +import json +import os +from typing import Optional + +import pandas as pd +from flamingo_tools.s3_utils import get_s3_path +from flamingo_tools.segmentation.cochlea_mapping import tonotopic_mapping + + +def repro_tonotopic_mapping( + ddict: dict, + output_dir: str, + s3_credentials: Optional[str] = None, + s3_bucket_name: Optional[str] = None, + s3_service_endpoint: Optional[str] = None, + force_overwrite: Optional[bool] = None, +): + default_cell_type = "ihc" + default_component_list = [1] + + remove_columns = ["tonotopic_label", + "tonotopic_value[kHz]", + "distance_to_path[µm]", + "length_fraction", + "run_length[µm]", + "centrality"] + + with open(ddict, 'r') as myfile: + data = myfile.read() + param_dicts = json.loads(data) + + for dic in param_dicts: + cochlea = dic["cochlea"] + seg_channel = dic["segmentation_channel"] + + cochlea_str = "-".join(cochlea.split("_")) + seg_str = "-".join(seg_channel.split("_")) + output_table_path = os.path.join(output_dir, f"{cochlea_str}_{seg_str}.tsv") + + s3_path = os.path.join(f"{cochlea}", "tables", f"{seg_channel}", "default.tsv") + print(f"Tonotopic mapping for {cochlea}.") + + tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name, + service_endpoint=s3_service_endpoint, credential_file=s3_credentials) + with fs.open(tsv_path, 'r') as f: + table = pd.read_csv(f, sep="\t") + + cell_type = dic["type"] if "type" in dic else default_cell_type + component_list = dic["component_list"] if "component_list" in dic else default_component_list + + for column in remove_columns: + if column in list(table.columns): + table = table.drop(column, axis=1) + + if not os.path.isfile(output_table_path) or 
force_overwrite:
+            table = tonotopic_mapping(table, component_label=component_list, cell_type=cell_type)
+
+            table.to_csv(output_table_path, sep="\t", index=False)
+
+        else:
+            print(f"Skipping {output_table_path}. Table already exists.")
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="Script for the tonotopic mapping of SGN and IHC segmentation tables from a JSON dictionary.")
+
+    parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.")
+    parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.")
+
+    parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.")
+    parser.add_argument("--s3_credentials", type=str, default=None,
+                        help="Input file containing S3 credentials. "
+                        "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
+    parser.add_argument("--s3_bucket_name", type=str, default=None,
+                        help="S3 bucket name. Optional if BUCKET_NAME was exported.")
+    parser.add_argument("--s3_service_endpoint", type=str, default=None,
+                        help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")
+
+    args = parser.parse_args()
+
+    repro_tonotopic_mapping(
+        args.input, args.output,
+        args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint,
+        args.force,
+    )
+
+
+if __name__ == "__main__":
+
+    main()
diff --git a/scripts/prediction/postprocess_seg.py b/scripts/prediction/postprocess_seg.py
index 66735ab..0cfa7fc 100644
--- a/scripts/prediction/postprocess_seg.py
+++ b/scripts/prediction/postprocess_seg.py
@@ -37,8 +37,8 @@ def main():
                         help="Threshold for spatial statistics.")
     parser.add_argument("--min_component_length", type=int, default=50,
                         help="Minimal length for filtering out connected components.")
-    parser.add_argument("--min_edge_dist", type=float, default=30,
-                        help="Minimal distance in micrometer between points to create edges for connected components.")
+    parser.add_argument("--max_edge_dist", type=float, default=30,
+                        help="Maximal distance in micrometer between points to create edges for connected components.")
     parser.add_argument("--iterations_erode", type=int, default=None,
                         help="Number of iterations for erosion, normally determined automatically.")
@@ -126,7 +126,7 @@ def create_spatial_statistics_dict(functions, keyword, options, threshold):
     if seg_path is None:
         post_table = postprocess_sgn_seg(
             tsv_table.copy(), min_size=args.min_size, threshold_erode=args.threshold,
-            min_component_length=args.min_component_length, min_edge_distance=args.min_edge_dist,
+            min_component_length=args.min_component_length, max_edge_distance=args.max_edge_dist,
             iterations_erode=args.iterations_erode,
         )
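An aside on the repro drivers: the repeated `dic["key"] if "key" in dic else default` lookups could also be written as a defaults dict merged under each JSON entry. A sketch of the equivalent pattern (the values below are illustrative only):

```python
import json

defaults = {"type": "ihc", "component_list": [1], "n_blocks": 6}
param_dicts = json.loads('[{"cochlea": "M_LR_000226_L", "type": "sgn"}]')
for dic in param_dicts:
    cfg = {**defaults, **dic}  # entries present in the JSON override the defaults
    print(cfg["cochlea"], cfg["type"], cfg["component_list"], cfg["n_blocks"])
```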
" + "Either locally or on an S3 bucket.") + + parser.add_argument("-i", "--input", required=True, help="Input table with IHC segmentation.") + parser.add_argument("-o", "--output", required=True, help="Output path for json file with cropping parameters.") + + parser.add_argument("-t", "--type", type=str, default="sgn", help="Cell type of segmentation.") + + parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.") + parser.add_argument("--s3_credentials", type=str, default=None, + help="Input file containing S3 credentials. " + "Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.") + parser.add_argument("--s3_bucket_name", type=str, default=None, + help="S3 bucket name. Optional if BUCKET_NAME was exported.") + parser.add_argument("--s3_service_endpoint", type=str, default=None, + help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.") + + args = parser.parse_args() + + if args.s3: + tsv_path, fs = s3_utils.get_s3_path(args.input, bucket_name=args.s3_bucket_name, + service_endpoint=args.s3_service_endpoint, + credential_file=args.s3_credentials) + with fs.open(tsv_path, 'r') as f: + tsv_table = pd.read_csv(f, sep="\t") + else: + with open(args.input, 'r') as f: + tsv_table = pd.read_csv(f, sep="\t") + + table = tonotopic_mapping( + tsv_table, cell_type=args.type, + ) + + table.to_csv(args.output, sep="\t", index=False) + + +if __name__ == "__main__": + main()