Skip to content
141 changes: 141 additions & 0 deletions flamingo_tools/segmentation/cochlea_mapping.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comment, but we could create a new submodule measurement for this, and then put this into a file called tonotopic mapping. We can do this later, as I want to re-organize the code a bit anyways.

Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import math
import warnings
from typing import List, Optional, Tuple

import networkx as nx
import numpy as np
import pandas as pd
from networkx.algorithms.approximation import steiner_tree

from flamingo_tools.segmentation.postprocessing import graph_connected_components


def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
all_lengths = dict(nx.all_pairs_dijkstra_path_length(G, weight=weight))
max_dist = 0
farthest_pair = (None, None)

for u, dist_dict in all_lengths.items():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comment: it should be possible to vectorize this by mappling the result from all_length to a numpy array and then using np.argmax. (This is likely not a bottleneck, so doesn't really matter that much.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bottleneck of this function is nx.all_pairs_dijkstra_path_length(G, weight=weight), so I will try to find a faster alternative if I have time.

for v, d in dist_dict.items():
if d > max_dist:
max_dist = d
farthest_pair = (u, v)

u, v = farthest_pair
return u, v


def tonotopic_mapping(
table: pd.DataFrame,
component_label: List[int] = [1],
max_edge_distance: float = 30,
min_component_length: int = 50,
cell_type: str = "ihc",
filter_factor: Optional[float] = None
) -> pd.DataFrame:
"""Tonotopic mapping of IHCs by supplying a table with component labels.
The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.

Args:
table: Dataframe of segmentation table.
component_label: List of component labels to evaluate.
max_edge_distance: Maximal edge distance to connect nodes.
min_component_length: Minimal number of nodes in component.
cell_type: Cell type of segmentation.
Filter factor: Fraction of nodes to remove before mapping.

Returns:
Table with tonotopic label for cells.
"""
weight = "weight"
# subset of centroids for given component label(s)
new_subset = table[table["component_labels"].isin(component_label)]
comp_label_ids = list(new_subset["label_id"])
centroids_subset = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
labels_subset = [int(i) for i in list(new_subset["label_id"])]

# create graph with connected components
coords = {}
for index, element in zip(labels_subset, centroids_subset):
coords[index] = element

components, graph = graph_connected_components(coords, max_edge_distance, min_component_length)
if len(components) > 1:
warnings.warn(f"There are {len(components)} connected components, expected 1. "
"Check parameters for post-processing (max_edge_distance, min_component_length).")

unfiltered_graph = graph.copy()

if filter_factor is not None:
if 0 <= filter_factor < 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comment: in general it would be safer to use a "spatially stratified" sample here. (For the current application in SGNs the random sample is probably fine though). Here are some suggestions from ChatGPT: https://chatgpt.com/share/6873d42d-d99c-8000-a0d3-b56efbd69ec8

rng = np.random.default_rng(seed=1234)
original_array = np.array(comp_label_ids)
target_length = int(len(original_array) * filter_factor)
filtered_list = list(rng.choice(original_array, size=target_length, replace=False))
for filter_id in filtered_list:
graph.remove_node(filter_id)
else:
raise ValueError(f"Invalid filter factor {filter_factor}. Choose a filter factor between 0 and 1.")

u, v = find_most_distant_nodes(graph)

if not nx.has_path(graph, source=u, target=v) or cell_type == "ihc":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain the logic between the two branches here?

Copy link
Contributor Author

@schilling40 schilling40 Jul 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea behind the two branches were the different use cases of IHC and SGN segmentation. While the Steiner Tree made sense for the IHC, where we expect a single line of cells, the application of the shortest path, which might skip intermediate cells, if they lie below the maximal edge distance of edges between nodes, seemed logical for SGNs. However, I can see that the distinction is probably not necessary, especially with the upcoming incorporation of the centrality for SGNs.

# approximate Steiner tree and find shortest path between the two most distant nodes
terminals = set(graph.nodes()) # All nodes are required
# Approximate Steiner Tree over all nodes
T = steiner_tree(graph, terminals, weight=weight)
path = nx.shortest_path(T, source=u, target=v, weight=weight)
total_distance = nx.path_weight(T, path, weight=weight)

else:
path = nx.shortest_path(graph, source=u, target=v, weight=weight)
total_distance = nx.path_weight(graph, path, weight=weight)

# assign relative distance to nodes on path
path_list = {}
path_list[path[0]] = {"label_id": path[0], "tonotopic": 0}
accumulated = 0
for num, p in enumerate(path[1:-1]):
distance = graph.get_edge_data(path[num], p)["weight"]
accumulated += distance
rel_dist = accumulated / total_distance
path_list[p] = {"label_id": p, "tonotopic": rel_dist}
path_list[path[-1]] = {"label_id": path[-1], "tonotopic": 1}

# add missing nodes from component
pos = nx.get_node_attributes(unfiltered_graph, 'pos')
for c in comp_label_ids:
if c not in path:
min_dist = float('inf')
nearest_node = None

for p in path:
dist = math.dist(pos[c], pos[p])
if dist < min_dist:
min_dist = dist
nearest_node = p

path_list[c] = {"label_id": c, "tonotopic": path_list[nearest_node]["tonotopic"]}

# label in micrometer
tonotopic = [0 for _ in range(len(table))]
# be aware of 'label_id' of dataframe starting at 1
for key in list(path_list.keys()):
tonotopic[int(path_list[key]["label_id"] - 1)] = path_list[key]["tonotopic"] * total_distance

table.loc[:, "tonotopic_label"] = tonotopic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think everything until here should be refactored into the function that measures the run-length across the helix.


# map frequency using Greenwood function f(x) = A * (10 **(ax) - K), for humans: a=2.1, k=0.88, A = 165.4 [kHz]
tonotopic_map = [0 for _ in range(len(table))]
var_k = 0.88
# calculate values to fit (assumed) minimal (1kHz) and maximal (80kHz) hearing range of mice at x=0, x=1
fmin = 1
fmax = 80
var_A = fmin / (1 - var_k)
var_exp = ((fmax + var_A * var_k) / var_A)
for key in list(path_list.keys()):
tonotopic_map[int(path_list[key]["label_id"] - 1)] = var_A * (var_exp ** path_list[key]["tonotopic"] - var_k)

table.loc[:, "tonotopic_value[kHz]"] = tonotopic_map
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this part then becomes the other function for tonotopic mapping.


return table
53 changes: 27 additions & 26 deletions flamingo_tools/segmentation/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,27 +319,28 @@ def downscaled_centroids(
return new_array


def graph_connected_components(coords: dict, min_edge_distance: float, min_component_length: int):
def graph_connected_components(coords: dict, max_edge_distance: float, min_component_length: int):
"""Create a list of IDs for each connected component of a graph.

Args:
coords: Dictionary containing label IDs as keys and their position as value.
min_edge_distance: Minimal edge distance between graph nodes to create an edge between nodes.
max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
min_component_length: Minimal length of nodes of connected component. Filtered out if lower.

Returns:
List of dictionary keys of connected components.
Graph of connected components.
"""
graph = nx.Graph()
for num, pos in coords.items():
graph.add_node(num, pos=pos)

# create edges between points whose distance is less than threshold min_edge_distance
# create edges between points whose distance is less than threshold max_edge_distance
for num_i, pos_i in coords.items():
for num_j, pos_j in coords.items():
if num_i < num_j:
dist = math.dist(pos_i, pos_j)
if dist <= min_edge_distance:
if dist <= max_edge_distance:
graph.add_edge(num_i, num_j, weight=dist)

components = list(nx.connected_components(graph))
Expand All @@ -351,15 +352,18 @@ def graph_connected_components(coords: dict, min_edge_distance: float, min_compo
graph.remove_node(c)

components = [list(s) for s in nx.connected_components(graph)]
return components
length_components = [len(c) for c in components]
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))

return components, graph


def components_sgn(
table: pd.DataFrame,
keyword: str = "distance_nn100",
threshold_erode: Optional[float] = None,
min_component_length: int = 50,
min_edge_distance: float = 30,
max_edge_distance: float = 30,
iterations_erode: Optional[int] = None,
postprocess_threshold: Optional[float] = None,
postprocess_components: Optional[List[int]] = None,
Expand All @@ -371,7 +375,7 @@ def components_sgn(
keyword: Keyword of the dataframe column for erosion.
threshold_erode: Threshold of column value after erosion step with spatial statistics.
min_component_length: Minimal length for filtering out connected components.
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
iterations_erode: Number of iterations for erosion, normally determined automatically.
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
Expand Down Expand Up @@ -411,10 +415,7 @@ def components_sgn(
for index, element in zip(labels_subset, centroids_subset):
coords[index] = element

components = graph_connected_components(coords, min_edge_distance, min_component_length)

length_components = [len(c) for c in components]
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
components, _ = graph_connected_components(coords, max_edge_distance, min_component_length)

# add original coordinates closer to eroded component than threshold
if postprocess_threshold is not None:
Expand Down Expand Up @@ -447,7 +448,7 @@ def label_components_sgn(
min_size: int = 1000,
threshold_erode: Optional[float] = None,
min_component_length: int = 50,
min_edge_distance: float = 30,
max_edge_distance: float = 30,
iterations_erode: Optional[int] = None,
postprocess_threshold: Optional[float] = None,
postprocess_components: Optional[List[int]] = None,
Expand All @@ -459,7 +460,7 @@ def label_components_sgn(
min_size: Minimal number of pixels for filtering small instances.
threshold_erode: Threshold of column value after erosion step with spatial statistics.
min_component_length: Minimal length for filtering out connected components.
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
iterations_erode: Number of iterations for erosion, normally determined automatically.
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
Expand All @@ -473,7 +474,7 @@ def label_components_sgn(
table = table[table.n_pixels >= min_size]

components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode,
max_edge_distance=max_edge_distance, iterations_erode=iterations_erode,
postprocess_threshold=postprocess_threshold,
postprocess_components=postprocess_components)

Expand All @@ -495,7 +496,7 @@ def postprocess_sgn_seg(
min_size: int = 1000,
threshold_erode: Optional[float] = None,
min_component_length: int = 50,
min_edge_distance: float = 30,
max_edge_distance: float = 30,
iterations_erode: Optional[int] = None,
) -> pd.DataFrame:
"""Postprocessing SGN segmentation of cochlea.
Expand All @@ -505,7 +506,7 @@ def postprocess_sgn_seg(
min_size: Minimal number of pixels for filtering small instances.
threshold_erode: Threshold of column value after erosion step with spatial statistics.
min_component_length: Minimal length for filtering out connected components.
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
iterations_erode: Number of iterations for erosion, normally determined automatically.

Returns:
Expand All @@ -514,7 +515,7 @@ def postprocess_sgn_seg(

comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
max_edge_distance=max_edge_distance, iterations_erode=iterations_erode)

table.loc[:, "component_labels"] = comp_labels

Expand All @@ -524,14 +525,14 @@ def postprocess_sgn_seg(
def components_ihc(
table: pd.DataFrame,
min_component_length: int = 50,
min_edge_distance: float = 30,
max_edge_distance: float = 30,
):
"""Create connected components for IHC segmentation.

Args:
table: Dataframe of segmentation table.
min_component_length: Minimal length for filtering out connected components.
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.

Returns:
Subgraph components as lists of label_ids of dataframe.
Expand All @@ -542,23 +543,23 @@ def components_ihc(
for index, element in zip(labels, centroids):
coords[index] = element

components = graph_connected_components(coords, min_edge_distance, min_component_length)
components, _ = graph_connected_components(coords, max_edge_distance, min_component_length)
return components


def label_components_ihc(
table: pd.DataFrame,
min_size: int = 1000,
min_component_length: int = 50,
min_edge_distance: float = 30,
max_edge_distance: float = 30,
) -> List[int]:
"""Label components using graph connected components.

Args:
table: Dataframe of segmentation table.
min_size: Minimal number of pixels for filtering small instances.
min_component_length: Minimal length for filtering out connected components.
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.

Returns:
List of component label for each point in dataframe. 0 - background, then in descending order of size
Expand All @@ -569,7 +570,7 @@ def label_components_ihc(
table = table[table.n_pixels >= min_size]

components = components_ihc(table, min_component_length=min_component_length,
min_edge_distance=min_edge_distance)
max_edge_distance=max_edge_distance)

# add size-filtered objects to have same initial length
table = pd.concat([table, entries_filtered], ignore_index=True)
Expand All @@ -591,23 +592,23 @@ def postprocess_ihc_seg(
table: pd.DataFrame,
min_size: int = 1000,
min_component_length: int = 50,
min_edge_distance: float = 30,
max_edge_distance: float = 30,
) -> pd.DataFrame:
"""Postprocessing IHC segmentation of cochlea.

Args:
table: Dataframe of segmentation table.
min_size: Minimal number of pixels for filtering small instances.
min_component_length: Minimal length for filtering out connected components.
min_edge_distance: Minimal distance in micrometer between points to create edges for connected components.
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.

Returns:
Dataframe with component labels.
"""

comp_labels = label_components_ihc(table, min_size=min_size,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance)
max_edge_distance=max_edge_distance)

table.loc[:, "component_labels"] = comp_labels

Expand Down
2 changes: 1 addition & 1 deletion reproducibility/postprocess_sgn/SGN_v1_postprocess.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
{
"cochlea": "M_LR_000144_L",
"image_channel": "PV_resized",
"min_edge_distance": 70,
"max_edge_distance": 70,
"iterations_erode": 1,
"unet_version": "v1"
},
Expand Down
8 changes: 4 additions & 4 deletions reproducibility/postprocess_sgn/repro_postprocess_sgn_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def repro_postprocess_sgn_v1(
min_size = 1000
default_threshold_erode = None
default_min_length = 50
default_min_edge_distance = 30
default_max_edge_distance = 30
default_iterations_erode = None

with open(ddict, 'r') as myfile:
Expand All @@ -36,18 +36,18 @@ def repro_postprocess_sgn_v1(

threshold_erode = dic["threshold_erode"] if "threshold_erode" in dic else default_threshold_erode
min_component_length = dic["min_component_length"] if "min_component_length" in dic else default_min_length
min_edge_distance = dic["min_edge_distance"] if "min_edge_distance" in dic else default_min_edge_distance
max_edge_distance = dic["max_edge_distance"] if "max_edge_distance" in dic else default_max_edge_distance
iterations_erode = dic["iterations_erode"] if "iterations_erode" in dic else default_iterations_erode

print("threshold_erode", threshold_erode)
print("min_component_length", min_component_length)
print("min_edge", min_edge_distance)
print("max_edge", max_edge_distance)
print("iterations_erode", iterations_erode)

tsv_table = postprocess_sgn_seg(table, min_size=min_size,
threshold_erode=threshold_erode,
min_component_length=min_component_length,
min_edge_distance=min_edge_distance,
max_edge_distance=max_edge_distance,
iterations_erode=iterations_erode)

largest_comp = len(tsv_table[tsv_table["component_labels"] == 1])
Expand Down
Loading