11import math
22import warnings
3- from typing import List , Optional , Tuple
3+ from typing import List , Tuple
44
55import networkx as nx
66import numpy as np
77import pandas as pd
88from networkx .algorithms .approximation import steiner_tree
9+ from scipy .ndimage import distance_transform_edt , binary_dilation , binary_closing
910
11+ import flamingo_tools .segmentation .postprocessing as postprocessing
1012from flamingo_tools .segmentation .postprocessing import graph_connected_components
11- from flamingo_tools .segmentation .distance_weighted_steiner import distance_weighted_steiner_path
1213
1314
1415def find_most_distant_nodes (G : nx .classes .graph .Graph , weight : str = 'weight' ) -> Tuple [float , float ]:
16+ """Find the most distant nodes in a graph.
17+
18+ Args:
19+ G: Input graph
20+
21+ Returns:
22+ Node 1
23+ Node 2
24+ """
1525 all_lengths = dict (nx .all_pairs_dijkstra_path_length (G , weight = weight ))
1626 max_dist = 0
1727 farthest_pair = (None , None )
@@ -26,90 +36,155 @@ def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -
2636 return u , v
2737
2838
29- def voxel_subsample (G , factor = 0.25 , voxel_size = None , seed = 1234 ):
30- coords = np .asarray ([G .nodes [n ]["pos" ] for n in G .nodes ])
31- nodes = np .asarray (list (G .nodes ))
def central_path_edt_graph(
    mask: np.ndarray,
    start: Tuple[int, int, int],
    end: Tuple[int, int, int],
) -> np.ndarray:
    """Find the central path within a binary mask between a start and an end coordinate.

    A 6-connected graph is built over all foreground voxels of the mask.
    Each edge is weighted by the inverse of the smaller Euclidean distance
    transform value of its two voxels, so the weighted shortest path prefers
    voxels that lie far away from the mask boundary.

    Args:
        mask: Binary 3D mask of the volume.
        start: Starting coordinate (z, y, x). Must be a foreground voxel.
        end: End coordinate (z, y, x). Must be a foreground voxel.

    Returns:
        Integer array of shape (N, 3) holding the voxel coordinates (z, y, x)
        of the path, ordered from start to end.

    Raises:
        networkx.NodeNotFound: If start or end is not a foreground voxel.
        networkx.NetworkXNoPath: If start and end are not connected within the mask.
    """
    dt = distance_transform_edt(mask)
    shape = mask.shape
    stride_z = shape[1] * shape[2]
    stride_y = shape[2]

    def idx_to_node(z, y, x):
        # Flat C-order index of the voxel, used as node identifier.
        return z * stride_z + y * stride_y + x

    neighbor_offsets = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]

    G = nx.Graph()
    # np.argwhere yields the foreground voxels in C order, i.e. the same order
    # as nested z/y/x loops, without visiting background voxels.
    for z, y, x in np.argwhere(mask).tolist():
        u = idx_to_node(z, y, x)
        # Register the voxel even if it has no foreground neighbour.
        G.add_node(u)
        for dz, dy, dx in neighbor_offsets:
            nz, ny, nx_ = z + dz, y + dy, x + dx
            # All three axes must be checked: a negative index would silently
            # wrap around to the opposite face, an index equal to the axis
            # size would raise an IndexError.
            inside = 0 <= nz < shape[0] and 0 <= ny < shape[1] and 0 <= nx_ < shape[2]
            if not inside or not mask[nz, ny, nx_]:
                continue
            v = idx_to_node(nz, ny, nx_)
            # Small constant avoids division by zero.
            w = 1.0 / (1e-3 + min(dt[z, y, x], dt[nz, ny, nx_]))
            G.add_edge(u, v, weight=w)

    s = idx_to_node(*start)
    t = idx_to_node(*end)
    path = nx.shortest_path(G, source=s, target=t, weight="weight")

    # Convert the flat node identifiers back to (z, y, x) coordinates.
    return np.column_stack(np.unravel_index(path, shape))
71+
72+
def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
    """Smooth a 3D path with a simple moving average filter.

    Each coordinate column is padded at both ends by repeating its edge value
    and then convolved with a uniform kernel, so the output has the same
    number of points as the input.

    Args:
        path: ndarray of shape (N, 3)
        window: half-window size; actual window = 2*window + 1

    Returns:
        smoothed path: floating point ndarray of same shape
    """
    path = np.asarray(path)
    kernel_size = 2 * window + 1
    kernel = np.ones(kernel_size) / kernel_size

    # The averages are fractional. Allocating the output with the dtype of an
    # integer input (e.g. voxel coordinates) would truncate them.
    out_dtype = path.dtype if np.issubdtype(path.dtype, np.floating) else float
    smooth_path = np.zeros(path.shape, dtype=out_dtype)

    # Edge padding is undefined for an empty path.
    if len(path) == 0:
        return smooth_path

    for d in range(path.shape[1]):
        padded = np.pad(path[:, d], window, mode="edge")
        smooth_path[:, d] = np.convolve(padded, kernel, mode="valid")

    return smooth_path
6693
67- else :
68- raise ValueError (f"Invalid filter factor { filter_factor } . Choose a filter factor between 0 and 1." )
69- else :
70- k_nn_thick = 40
71- centroid_labels = label_ids
7294
73- path_coords , path = distance_weighted_steiner_path (
74- centroids , # (N,3) ndarray
75- centroid_labels = centroid_labels , # (N,) ndarray
76- k_nn_thick = k_nn_thick , # 20‒30 is robust for SGN clouds int(40 * (1 - filter_factor))
77- lam = 0.5 , # 0.3‒1.0 : larger → stronger centripetal bias
78- r_connect = 50.0 # connect neighbours within 50 µm
79- )
def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10) -> Tuple[float, dict]:
    """Measure the run lengths of the SGN segmentation by finding a central path through Rosenthal's canal.
    1) Create a binary mask based on down-scaled centroids.
    2) Dilate the mask and close holes to ensure a filled structure.
    3) Determine the endpoints of the structure using the principal axis.
    4) Identify a central path based on the 3D Euclidean distance transform.
    5) The path is up-scaled and smoothed using a moving average filter.
    6) The points of the path are fed into a dictionary along with the fractional length.

    Args:
        centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3)
        scale_factor: Downscaling factor for finding the central path.

    Returns:
        Total length of the smoothed path.
        Dictionary mapping an integer key to {"pos": point on the path,
        "length_fraction": accumulated length up to that point divided by the total length}.
        The keys run from 0 to len(path) - 2; the last point is stored under the key len(path).

    Raises:
        ValueError: If the mask derived from the centroids contains no foreground voxel.
    """
    mask = postprocessing.downscaled_centroids(centroids, scale_factor=scale_factor, downsample_mode="capped")
    structure = np.ones((3, 3, 3))
    mask = binary_dilation(mask, structure, iterations=1)
    mask = binary_closing(mask, structure, iterations=1)
    pts = np.argwhere(mask == 1)
    if len(pts) == 0:
        raise ValueError("The mask of down-scaled centroids is empty, cannot determine a central path.")

    # find two endpoints: min/max along principal axis
    c_mean = pts.mean(axis=0)
    centered = pts - c_mean
    evals, evecs = np.linalg.eigh(np.cov(centered.T))
    axis = evecs[:, np.argmax(evals)]
    proj = centered @ axis
    start_voxel = tuple(pts[proj.argmin()])
    end_voxel = tuple(pts[proj.argmax()])

    # get central path and total distance
    path = central_path_edt_graph(mask, start_voxel, end_voxel)
    # Cast to float before smoothing: the voxel path is integer valued and an
    # integer array would truncate the fractional coordinates of the average.
    path = path.astype(float) * scale_factor
    path = moving_average_3d(path, window=5)
    segment_lengths = [math.dist(a, b) for a, b in zip(path[:-1], path[1:])]
    total_distance = sum(segment_lengths)

    # assign relative distance to points on path
    path_dict = {0: {"pos": path[0], "length_fraction": 0}}
    accumulated = 0
    # The last segment is skipped here, the end point is added explicitly below.
    for idx, segment in enumerate(segment_lengths[:-1], start=1):
        accumulated += segment
        path_dict[idx] = {"pos": path[idx], "length_fraction": accumulated / total_distance}
    # NOTE: the end point is keyed with len(path), not len(path) - 1.
    # The key is kept as is so that the returned dictionary stays unchanged for callers.
    path_dict[len(path)] = {"pos": path[-1], "length_fraction": 1}

    return total_distance, path_dict
90141
91142
def measure_run_length_ihcs(graph, weight="weight"):
    """Measure the run lengths of the IHC segmentation
    by finding the shortest path between the most distant nodes in a Steiner Tree.

    Args:
        graph: Input graph. Every node needs a "pos" attribute with its coordinate.
        weight: Name of the edge attribute used as edge length.

    Returns:
        Total length of the path, i.e. the sum of its edge weights in the Steiner tree.
        Dictionary mapping an integer key to {"pos": position of the node on the path,
        "length_fraction": accumulated length up to that node divided by the total length}.
        The keys run from 0 to len(path) - 2; the last node is stored under the key len(path).
    """
    # Forward the weight attribute so the end points are determined with the
    # same edge lengths as the Steiner tree and the path below.
    u, v = find_most_distant_nodes(graph, weight=weight)
    # approximate Steiner tree and find shortest path between the two most distant nodes
    terminals = set(graph.nodes())  # All nodes are required
    T = steiner_tree(graph, terminals, weight=weight)
    path = nx.shortest_path(T, source=u, target=v, weight=weight)
    total_distance = nx.path_weight(T, path, weight=weight)

    # assign relative distance to points on path
    positions = [graph.nodes[node]["pos"] for node in path]
    path_dict = {0: {"pos": positions[0], "length_fraction": 0}}
    accumulated = 0
    # The accumulated length is the Euclidean distance between node positions,
    # while total_distance is the sum of edge weights.
    # NOTE(review): this assumes the edge weights are Euclidean distances - confirm.
    for idx in range(1, len(path) - 1):
        accumulated += math.dist(positions[idx - 1], positions[idx])
        path_dict[idx] = {"pos": positions[idx], "length_fraction": accumulated / total_distance}
    # NOTE: the end point is keyed with len(path), not len(path) - 1.
    # The key is kept as is so that the returned dictionary stays unchanged for callers.
    path_dict[len(path)] = {"pos": positions[-1], "length_fraction": 1}

    return total_distance, path_dict
170+
171+
def map_frequency(
    table: pd.DataFrame,
    min_frequency: float = 1,
    max_frequency: float = 80,
    var_k: float = 0.88,
) -> pd.DataFrame:
    """Map the frequency range of SGNs in the cochlea
    using Greenwood function f(x) = A * (10 **(ax) - K).
    Values for humans: a=2.1, k=0.88, A = 165.4 [kHz].
    For mice: fit values between minimal (1kHz) and maximal (80kHz) values

    A and 10**a are fitted such that f(0) = min_frequency and f(1) = max_frequency.
    Rows with a negative 'offset ' value get the frequency 0.

    Args:
        table: Dataframe with the columns 'offset ' and "length_fraction".
            The dataframe is modified in place.
        min_frequency: Frequency in kHz assigned to the length fraction 0.
        max_frequency: Frequency in kHz assigned to the length fraction 1.
        var_k: Constant K of the Greenwood function.

    Returns:
        The input dataframe with the added column 'frequency [kHz]'.
    """
    # Solve f(0) = min_frequency for A and f(1) = max_frequency for the base 10**a.
    var_A = min_frequency / (1 - var_k)
    var_exp = (max_frequency + var_A * var_k) / var_A

    on_path = table['offset '] >= 0
    table.loc[on_path, 'frequency [kHz]'] = var_A * (var_exp ** table["length_fraction"] - var_k)
    table.loc[table['offset '] < 0, 'frequency [kHz]'] = 0

    return table
115190
@@ -119,8 +194,7 @@ def tonotopic_mapping(
119194 component_label : List [int ] = [1 ],
120195 max_edge_distance : float = 30 ,
121196 min_component_length : int = 50 ,
122- cell_type : str = "ihc" ,
123- filter_factor : Optional [float ] = None
197+ cell_type : str = "ihc"
124198) -> pd .DataFrame :
125199 """Tonotopic mapping of IHCs by supplying a table with component labels.
126200 The mapping assigns a tonotopic label to each IHC according to the position along the length of the cochlea.
@@ -154,63 +228,43 @@ def tonotopic_mapping(
154228 unfiltered_graph = graph .copy ()
155229
156230 if cell_type == "ihc" :
157- total_distance , path = measure_run_length_ihcs (graph )
231+ total_distance , path_dict = measure_run_length_ihcs (graph )
158232
159233 else :
160- total_distance , path , graph = measure_run_length_sgns (graph , centroids , label_ids ,
161- filter_factor , weight = "weight" )
162-
163- # measure_betweenness
164- centrality = nx .betweenness_centrality (graph , k = 100 , normalized = True , weight = 'weight' , seed = 1234 )
165- score = sum (centrality [n ] for n in path ) / len (path )
166- print (f"path distance: { total_distance } " )
167- print (f"centrality score: { score } " )
168-
169- # assign relative distance to nodes on path
170- path_dict = {}
171- path_dict [path [0 ]] = {"label_id" : path [0 ], "length_fraction" : 0 }
172- accumulated = 0
173- for num , p in enumerate (path [1 :- 1 ]):
174- distance = graph .get_edge_data (path [num ], p )["weight" ]
175- accumulated += distance
176- rel_dist = accumulated / total_distance
177- path_dict [p ] = {"label_id" : p , "length_fraction" : rel_dist }
178- path_dict [path [- 1 ]] = {"label_id" : path [- 1 ], "length_fraction" : 1 }
234+ total_distance , path_dict = measure_run_length_sgns (centroids )
179235
180236 # add missing nodes from component and compute distance to path
181237 pos = nx .get_node_attributes (unfiltered_graph , 'pos' )
238+ node_dict = {}
182239 for c in label_ids :
183- if c not in path :
184- min_dist = float ('inf' )
185- nearest_node = None
186-
187- for p in path :
188- dist = math .dist (pos [c ], pos [p ])
189- if dist < min_dist :
190- min_dist = dist
191- nearest_node = p
192-
193- path_dict [c ] = {
194- "label_id" : c ,
195- "length_fraction" : path_dict [nearest_node ]["length_fraction" ],
196- "distance_to_path" : min_dist ,
197- }
198- else :
199- path_dict [c ]["distance_to_path" ] = 0
200-
201- distance_to_path = [- 1 for _ in range (len (table ))]
240+ min_dist = float ('inf' )
241+ nearest_node = None
242+
243+ for key in path_dict .keys ():
244+ dist = math .dist (pos [c ], path_dict [key ]["pos" ])
245+ if dist < min_dist :
246+ min_dist = dist
247+ nearest_node = key
248+
249+ node_dict [c ] = {
250+ "label_id" : c ,
251+ "length_fraction" : path_dict [nearest_node ]["length_fraction" ],
252+ "offset" : min_dist ,
253+ }
254+
255+ offset = [- 1 for _ in range (len (table ))]
202256 # 'label_id' of dataframe starting at 1
203- for key in list (path_dict .keys ()):
204- distance_to_path [int (path_dict [key ]["label_id" ] - 1 )] = path_dict [key ]["distance_to_path " ]
257+ for key in list (node_dict .keys ()):
258+ offset [int (node_dict [key ]["label_id" ] - 1 )] = node_dict [key ]["offset " ]
205259
206- table .loc [:, "distance_to_path[µm] " ] = distance_to_path
260+ table .loc [:, "offset " ] = offset
207261
208262 length_fraction = [0 for _ in range (len (table ))]
209- for key in list (path_dict .keys ()):
210- length_fraction [int (path_dict [key ]["label_id" ] - 1 )] = path_dict [key ]["length_fraction" ]
263+ for key in list (node_dict .keys ()):
264+ length_fraction [int (node_dict [key ]["label_id" ] - 1 )] = node_dict [key ]["length_fraction" ]
211265
212266 table .loc [:, "length_fraction" ] = length_fraction
213- table .loc [:, "run_length [µm]" ] = table ["length_fraction" ] * total_distance
267+ table .loc [:, "length [µm]" ] = table ["length_fraction" ] * total_distance
214268
215269 table = map_frequency (table )
216270
0 commit comments