Skip to content

Commit 7923490

Browse files
committed
Tonotopic mapping for SGN (filtered) and IHC
1 parent c273c85 commit 7923490

File tree

7 files changed

+266
-53
lines changed

7 files changed

+266
-53
lines changed

flamingo_tools/segmentation/cochlea_mapping.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import math
2+
from typing import List, Optional, Tuple
23

34
import networkx as nx
5+
import numpy as np
6+
import pandas as pd
47
from networkx.algorithms.approximation import steiner_tree
58

69
from flamingo_tools.segmentation.postprocessing import graph_connected_components
710

811

9-
def find_most_distant_nodes(G, weight='weight'):
12+
def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
1013
all_lengths = dict(nx.all_pairs_dijkstra_path_length(G, weight=weight))
1114
max_dist = 0
1215
farthest_pair = (None, None)
@@ -21,32 +24,29 @@ def find_most_distant_nodes(G, weight='weight'):
2124
return u, v
2225

2326

24-
def steiner_path_between_distant_nodes(G, weight='weight'):
25-
# Step 1: Find the most distant pair of nodes
26-
u, v = find_most_distant_nodes(G, weight=weight)
27-
terminals = set(G.nodes()) # All nodes are required
28-
29-
# Step 2: Approximate Steiner Tree over all nodes
30-
T = steiner_tree(G, terminals, weight=weight)
31-
32-
# Step 3: Find the shortest path between u and v in the Steiner Tree
33-
path = nx.shortest_path(T, source=u, target=v, weight=weight)
34-
total_weight = nx.path_weight(T, path, weight=weight)
35-
36-
return {
37-
"start": u,
38-
"end": v,
39-
"path": path,
40-
"total_weight": total_weight,
41-
"steiner_tree": T
42-
}
43-
44-
45-
def tonotopic_mapping(table, component_label=[1], min_edge_distance=30, min_component_length=50,
46-
cell_type="ihc", weight='weight'):
27+
def tonotopic_mapping(
28+
table: pd.DataFrame,
29+
component_label: List[int] = [1],
30+
max_edge_distance: float = 30,
31+
min_component_length: int = 50,
32+
cell_type: str = "ihc",
33+
filter_factor: Optional[float] = None
34+
) -> pd.DataFrame:
4735
"""Tonotopic mapping of IHCs by supplying a table with component labels.
4836
The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.
37+
38+
Args:
39+
table: Dataframe of segmentation table.
40+
component_label: List of component labels to evaluate.
41+
max_edge_distance: Maximal edge distance to connect nodes.
42+
min_component_length: Minimal number of nodes in component.
43+
cell_type: Cell type of segmentation.
44+
Filter factor: Fraction of nodes to remove before mapping.
45+
46+
Returns:
47+
Table with tonotopic label for cells.
4948
"""
49+
weight = "weight"
5050
# subset of centroids for given component label(s)
5151
new_subset = table[table["component_labels"].isin(component_label)]
5252
comp_label_ids = list(new_subset["label_id"])
@@ -58,12 +58,25 @@ def tonotopic_mapping(table, component_label=[1], min_edge_distance=30, min_comp
5858
for index, element in zip(labels_subset, centroids_subset):
5959
coords[index] = element
6060

61-
components, graph = graph_connected_components(coords, min_edge_distance, min_component_length)
61+
_, graph = graph_connected_components(coords, max_edge_distance, min_component_length)
6262

63-
# approximate Steiner tree and find shortest path between the two most distant nodes
63+
unfiltered_graph = graph.copy()
64+
65+
if filter_factor is not None:
66+
if 0 < filter_factor < 1:
67+
rng = np.random.default_rng(seed=1234)
68+
original_array = np.array(comp_label_ids)
69+
target_length = int(len(original_array) * filter_factor)
70+
filtered_list = list(rng.choice(original_array, size=target_length, replace=False))
71+
for filter_id in filtered_list:
72+
graph.remove_node(filter_id)
73+
else:
74+
raise ValueError(f"Invalid filter factor {filter_factor}. Choose a filter factor between 0 and 1.")
6475

6576
u, v = find_most_distant_nodes(graph)
66-
if cell_type == "ihc":
77+
78+
if not nx.has_path(graph, source=u, target=v) or cell_type == "ihc":
79+
# approximate Steiner tree and find shortest path between the two most distant nodes
6780
terminals = set(graph.nodes()) # All nodes are required
6881
# Approximate Steiner Tree over all nodes
6982
T = steiner_tree(graph, terminals, weight=weight)
@@ -86,7 +99,7 @@ def tonotopic_mapping(table, component_label=[1], min_edge_distance=30, min_comp
8699
path_list[path[-1]] = {"label_id": path[-1], "tonotopic": 1}
87100

88101
# add missing nodes from component
89-
pos = nx.get_node_attributes(graph, 'pos')
102+
pos = nx.get_node_attributes(unfiltered_graph, 'pos')
90103
for c in comp_label_ids:
91104
if c not in path:
92105
min_dist = float('inf')
@@ -102,8 +115,8 @@ def tonotopic_mapping(table, component_label=[1], min_edge_distance=30, min_comp
102115

103116
tonotopic = [0 for _ in range(len(table))]
104117
# be aware of 'label_id' of dataframe starting at 1
105-
for d in path_list:
106-
tonotopic[d["label_id"] - 1] = d["value"] * total_distance
118+
for key in list(path_list.keys()):
119+
tonotopic[int(path_list[key]["label_id"] - 1)] = path_list[key]["tonotopic"] * total_distance
107120

108121
table.loc[:, "tonotopic_label"] = tonotopic
109122

flamingo_tools/segmentation/postprocessing.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -319,12 +319,12 @@ def downscaled_centroids(
319319
return new_array
320320

321321

322-
def graph_connected_components(coords: dict, min_edge_distance: float, min_component_length: int):
322+
def graph_connected_components(coords: dict, max_edge_distance: float, min_component_length: int):
323323
"""Create a list of IDs for each connected component of a graph.
324324
325325
Args:
326326
coords: Dictionary containing label IDs as keys and their position as value.
327-
min_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
327+
max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
328328
min_component_length: Minimal length of nodes of connected component. Filtered out if lower.
329329
330330
Returns:
@@ -335,12 +335,12 @@ def graph_connected_components(coords: dict, min_edge_distance: float, min_compo
335335
for num, pos in coords.items():
336336
graph.add_node(num, pos=pos)
337337

338-
# create edges between points whose distance is less than threshold min_edge_distance
338+
# create edges between points whose distance is less than threshold max_edge_distance
339339
for num_i, pos_i in coords.items():
340340
for num_j, pos_j in coords.items():
341341
if num_i < num_j:
342342
dist = math.dist(pos_i, pos_j)
343-
if dist <= min_edge_distance:
343+
if dist <= max_edge_distance:
344344
graph.add_edge(num_i, num_j, weight=dist)
345345

346346
components = list(nx.connected_components(graph))
@@ -360,7 +360,7 @@ def components_sgn(
360360
keyword: str = "distance_nn100",
361361
threshold_erode: Optional[float] = None,
362362
min_component_length: int = 50,
363-
min_edge_distance: float = 30,
363+
max_edge_distance: float = 30,
364364
iterations_erode: Optional[int] = None,
365365
postprocess_threshold: Optional[float] = None,
366366
postprocess_components: Optional[List[int]] = None,
@@ -372,7 +372,7 @@ def components_sgn(
372372
keyword: Keyword of the dataframe column for erosion.
373373
threshold_erode: Threshold of column value after erosion step with spatial statistics.
374374
min_component_length: Minimal length for filtering out connected components.
375-
min_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
375+
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
376376
iterations_erode: Number of iterations for erosion, normally determined automatically.
377377
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
378378
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
@@ -412,7 +412,7 @@ def components_sgn(
412412
for index, element in zip(labels_subset, centroids_subset):
413413
coords[index] = element
414414

415-
components, _ = graph_connected_components(coords, min_edge_distance, min_component_length)
415+
components, _ = graph_connected_components(coords, max_edge_distance, min_component_length)
416416

417417
length_components = [len(c) for c in components]
418418
length_components, components = zip(*sorted(zip(length_components, components), reverse=True))
@@ -448,7 +448,7 @@ def label_components_sgn(
448448
min_size: int = 1000,
449449
threshold_erode: Optional[float] = None,
450450
min_component_length: int = 50,
451-
min_edge_distance: float = 30,
451+
max_edge_distance: float = 30,
452452
iterations_erode: Optional[int] = None,
453453
postprocess_threshold: Optional[float] = None,
454454
postprocess_components: Optional[List[int]] = None,
@@ -460,7 +460,7 @@ def label_components_sgn(
460460
min_size: Minimal number of pixels for filtering small instances.
461461
threshold_erode: Threshold of column value after erosion step with spatial statistics.
462462
min_component_length: Minimal length for filtering out connected components.
463-
min_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
463+
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
464464
iterations_erode: Number of iterations for erosion, normally determined automatically.
465465
postprocess_threshold: Post-process graph connected components by searching for points closer than threshold.
466466
postprocess_components: Post-process specific graph connected components ([0] for largest component only).
@@ -474,7 +474,7 @@ def label_components_sgn(
474474
table = table[table.n_pixels >= min_size]
475475

476476
components = components_sgn(table, threshold_erode=threshold_erode, min_component_length=min_component_length,
477-
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode,
477+
max_edge_distance=max_edge_distance, iterations_erode=iterations_erode,
478478
postprocess_threshold=postprocess_threshold,
479479
postprocess_components=postprocess_components)
480480

@@ -496,7 +496,7 @@ def postprocess_sgn_seg(
496496
min_size: int = 1000,
497497
threshold_erode: Optional[float] = None,
498498
min_component_length: int = 50,
499-
min_edge_distance: float = 30,
499+
max_edge_distance: float = 30,
500500
iterations_erode: Optional[int] = None,
501501
) -> pd.DataFrame:
502502
"""Postprocessing SGN segmentation of cochlea.
@@ -506,7 +506,7 @@ def postprocess_sgn_seg(
506506
min_size: Minimal number of pixels for filtering small instances.
507507
threshold_erode: Threshold of column value after erosion step with spatial statistics.
508508
min_component_length: Minimal length for filtering out connected components.
509-
min_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
509+
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
510510
iterations_erode: Number of iterations for erosion, normally determined automatically.
511511
512512
Returns:
@@ -515,7 +515,7 @@ def postprocess_sgn_seg(
515515

516516
comp_labels = label_components_sgn(table, min_size=min_size, threshold_erode=threshold_erode,
517517
min_component_length=min_component_length,
518-
min_edge_distance=min_edge_distance, iterations_erode=iterations_erode)
518+
max_edge_distance=max_edge_distance, iterations_erode=iterations_erode)
519519

520520
table.loc[:, "component_labels"] = comp_labels
521521

@@ -525,14 +525,14 @@ def postprocess_sgn_seg(
525525
def components_ihc(
526526
table: pd.DataFrame,
527527
min_component_length: int = 50,
528-
min_edge_distance: float = 30,
528+
max_edge_distance: float = 30,
529529
):
530530
"""Create connected components for IHC segmentation.
531531
532532
Args:
533533
table: Dataframe of segmentation table.
534534
min_component_length: Minimal length for filtering out connected components.
535-
min_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
535+
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
536536
537537
Returns:
538538
Subgraph components as lists of label_ids of dataframe.
@@ -543,23 +543,23 @@ def components_ihc(
543543
for index, element in zip(labels, centroids):
544544
coords[index] = element
545545

546-
components, _ = graph_connected_components(coords, min_edge_distance, min_component_length)
546+
components, _ = graph_connected_components(coords, max_edge_distance, min_component_length)
547547
return components
548548

549549

550550
def label_components_ihc(
551551
table: pd.DataFrame,
552552
min_size: int = 1000,
553553
min_component_length: int = 50,
554-
min_edge_distance: float = 30,
554+
max_edge_distance: float = 30,
555555
) -> List[int]:
556556
"""Label components using graph connected components.
557557
558558
Args:
559559
table: Dataframe of segmentation table.
560560
min_size: Minimal number of pixels for filtering small instances.
561561
min_component_length: Minimal length for filtering out connected components.
562-
min_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
562+
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
563563
564564
Returns:
565565
List of component label for each point in dataframe. 0 - background, then in descending order of size
@@ -570,7 +570,7 @@ def label_components_ihc(
570570
table = table[table.n_pixels >= min_size]
571571

572572
components = components_ihc(table, min_component_length=min_component_length,
573-
min_edge_distance=min_edge_distance)
573+
max_edge_distance=max_edge_distance)
574574

575575
# add size-filtered objects to have same initial length
576576
table = pd.concat([table, entries_filtered], ignore_index=True)
@@ -592,23 +592,23 @@ def postprocess_ihc_seg(
592592
table: pd.DataFrame,
593593
min_size: int = 1000,
594594
min_component_length: int = 50,
595-
min_edge_distance: float = 30,
595+
max_edge_distance: float = 30,
596596
) -> pd.DataFrame:
597597
"""Postprocessing IHC segmentation of cochlea.
598598
599599
Args:
600600
table: Dataframe of segmentation table.
601601
min_size: Minimal number of pixels for filtering small instances.
602602
min_component_length: Minimal length for filtering out connected components.
603-
min_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
603+
max_edge_distance: Maximal distance in micrometer between points to create edges for connected components.
604604
605605
Returns:
606606
Dataframe with component labels.
607607
"""
608608

609609
comp_labels = label_components_ihc(table, min_size=min_size,
610610
min_component_length=min_component_length,
611-
min_edge_distance=min_edge_distance)
611+
max_edge_distance=max_edge_distance)
612612

613613
table.loc[:, "component_labels"] = comp_labels
614614

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
[
2+
{
3+
"cochlea": "M_LR_000226_L",
4+
"segmentation_channel": "IHC_v3",
5+
"type": "ihc"
6+
},
7+
{
8+
"cochlea": "M_LR_000226_R",
9+
"segmentation_channel": "IHC_v3",
10+
"type": "ihc"
11+
},
12+
{
13+
"cochlea": "M_LR_000227_L",
14+
"segmentation_channel": "IHC_v3",
15+
"type": "ihc"
16+
},
17+
{
18+
"cochlea": "M_LR_000227_R",
19+
"segmentation_channel": "IHC_v3",
20+
"type": "ihc"
21+
}
22+
]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
[
2+
{
3+
"cochlea": "M_AMD_000058_L",
4+
"segmentation_channel": "SGN_v2",
5+
"type": "sgn",
6+
"filter_factor": 0.75
7+
},
8+
{
9+
"cochlea": "M_LR_000144_L",
10+
"segmentation_channel": "SGN_resized_v2",
11+
"max_edge_distance": 70,
12+
"type": "sgn",
13+
"filter_factor": 0.75
14+
},
15+
{
16+
"cochlea": "M_LR_000144_R",
17+
"segmentation_channel": "SGN_v2",
18+
"type": "sgn",
19+
"filter_factor": 0.75
20+
},
21+
{
22+
"cochlea": "M_LR_000145_L",
23+
"segmentation_channel": "SGN_resized_v2",
24+
"type": "sgn",
25+
"filter_factor": 0.75
26+
},
27+
{
28+
"cochlea": "M_LR_000151_R",
29+
"segmentation_channel": "SGN_resized_v2",
30+
"type": "sgn",
31+
"filter_factor": 0.75
32+
},
33+
{
34+
"cochlea": "M_LR_000155_L",
35+
"segmentation_channel": "SGN_resized_v2",
36+
"type": "sgn",
37+
"filter_factor": 0.75
38+
},
39+
{
40+
"cochlea": "M_LR_000155_R",
41+
"segmentation_channel": "SGN_v2",
42+
"type": "sgn",
43+
"filter_factor": 0.75
44+
},
45+
{
46+
"cochlea": "M_LR_000167_R",
47+
"segmentation_channel": "SGN_v2",
48+
"type": "sgn",
49+
"filter_factor": 0.75
50+
},
51+
{
52+
"cochlea": "M_LR_000184_L",
53+
"segmentation_channel": "SGN_resized_v2",
54+
"type": "sgn",
55+
"filter_factor": 0.75
56+
}
57+
]

0 commit comments

Comments
 (0)