88from networkx .algorithms .approximation import steiner_tree
99
1010from flamingo_tools .segmentation .postprocessing import graph_connected_components
11+ from flamingo_tools .segmentation .distance_weighted_steiner import distance_weighted_steiner_path
1112
1213
1314def find_most_distant_nodes (G : nx .classes .graph .Graph , weight : str = 'weight' ) -> Tuple [float , float ]:
@@ -25,6 +26,94 @@ def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -
2526 return u , v
2627
2728
def voxel_subsample(G, factor=0.25, voxel_size=None, seed=1234):
    """Subsample the nodes of a graph on a regular voxel grid.

    The nodes are binned into cubic voxels according to their "pos" attribute.
    In every occupied voxel a random selection of about `factor` times the
    number of contained nodes is kept, but always at least one node.

    Args:
        G: Graph whose nodes carry a "pos" attribute with their coordinates.
        factor: Fraction of the nodes of a voxel that is kept.
        voxel_size: Edge length of a voxel. If None, it is derived from the
            bounding box of the node positions, the number of nodes and `factor`.
        seed: Seed of the random number generator used for the selection.

    Returns:
        A copy of the subgraph of G that is induced by the kept nodes.

    Raises:
        ValueError: If `voxel_size` is given but not positive, or if it has to be
            derived and `factor` is not positive.
    """
    from collections import defaultdict

    # Validate the arguments first, so that bad input fails with a clear message
    # instead of a ZeroDivisionError or NaN voxel indices further down.
    if voxel_size is None:
        if factor <= 0:
            raise ValueError(
                f"Invalid factor {factor}. A positive factor is required to derive the voxel size."
            )
    elif voxel_size <= 0:
        raise ValueError(f"Invalid voxel size {voxel_size}. The voxel size must be positive.")

    # Keep the node ids as Python objects. Packing them into an array would turn
    # tuple ids into a 2D array and would coerce ids of mixed type.
    nodes = list(G.nodes)
    if not nodes:
        return G.copy()

    coords = np.asarray([G.nodes[n]["pos"] for n in nodes])

    # choose a voxel edge length if the caller has not fixed one
    if voxel_size is None:
        bbox = np.ptp(coords, axis=0)  # edge lengths of the bounding box
        # Ignore axes without extent (e.g. planar data); they would zero the volume.
        extents = bbox[bbox > 0]
        if extents.size == 0:
            # All nodes share one position, any positive size gives a single voxel.
            voxel_size = 1.0
        else:
            # NOTE(review): this spreads about len(G) / factor voxels over the
            # bounding box, i.e. more voxels than nodes for factor < 1 - confirm
            # that this density is intended.
            voxel_size = (extents.prod() / (len(nodes) / factor)) ** (1 / extents.size)

    # integer voxel indices
    mins = coords.min(axis=0)
    vox = np.floor((coords - mins) / voxel_size).astype(np.int32)

    # bucket node indices per voxel (insertion order fixes the order of the random draws)
    buckets = defaultdict(list)
    for idx, voxel in enumerate(map(tuple, vox)):
        buckets[voxel].append(idx)

    rng = np.random.default_rng(seed)
    keep = []
    for bucket in buckets.values():
        # local quota: at least one node, never more than the voxel contains
        quota = min(len(bucket), max(1, int(round(len(bucket) * factor))))
        keep.extend(rng.choice(bucket, quota, replace=False))

    sampled_nodes = [nodes[idx] for idx in keep]
    return G.subgraph(sampled_nodes).copy()
56+
57+
def measure_run_length_sgns(graph, centroids, label_ids, filter_factor, weight="weight"):
    """Measure the run length of SGNs along a distance weighted Steiner path.

    The path is computed with `distance_weighted_steiner_path`. Consecutive path
    nodes are then connected in the graph by edges that store their Euclidean
    distance, and the path length is the sum of these edge weights.

    Args:
        graph: Graph of the cells. Nodes need a "pos" attribute if `filter_factor`
            is given. Without subsampling the path edges are added to this graph in place.
        centroids: Positions of the cells, in the same order as `label_ids`.
            Replaced by the positions of the subsampled nodes if `filter_factor` is given.
        label_ids: Labels of the cells, in the same order as `centroids`.
        filter_factor: Passed as `factor` to `voxel_subsample` to thin out the
            graph before the path is computed. None disables the subsampling.
        weight: Name of the edge attribute that holds the edge length.

    Returns:
        The length of the path, the path as sequence of node labels and the graph
        (subsampled if requested) including the edges along the path.

    Raises:
        ValueError: If `filter_factor` is neither None nor strictly between 0 and 1.
    """
    if filter_factor is None:
        k_nn_thick = 40
        centroid_labels = label_ids
    elif 0 < filter_factor < 1:
        graph = voxel_subsample(graph, factor=filter_factor)
        centroid_labels = list(graph.nodes)
        centroids = [graph.nodes[n]["pos"] for n in centroid_labels]
        # Scale the neighbourhood with the kept fraction, but never down to zero neighbours.
        k_nn_thick = max(1, int(40 * filter_factor))
    else:
        # A factor of 0 is rejected as well: it cannot be used to derive a voxel size.
        raise ValueError(
            f"Invalid filter factor {filter_factor}. "
            "Choose a filter factor between 0 and 1 (both exclusive)."
        )

    path_coords, path = distance_weighted_steiner_path(
        centroids,
        centroid_labels=centroid_labels,
        k_nn_thick=k_nn_thick,
        lam=0.5,  # 0.3-1.0: larger -> stronger centripetal bias
        r_connect=50.0,  # connect neighbours within 50 µm
    )

    # Look-up table label -> position; the first entry wins for repeated labels.
    position_of = {}
    for label, position in zip(centroid_labels, centroids):
        position_of.setdefault(label, position)

    # Store the distances under the requested attribute name, because
    # nx.path_weight below reads exactly this attribute.
    for source, target in zip(path[:-1], path[1:]):
        dist = math.dist(position_of[source], position_of[target])
        graph.add_edge(source, target, **{weight: dist})

    total_distance = nx.path_weight(graph, path, weight=weight)

    return total_distance, path, graph
90+
91+
def measure_run_length_ihcs(graph, weight="weight"):
    """Measure the run length of IHCs along an approximate Steiner tree.

    The end points are the pair of nodes returned by `find_most_distant_nodes`.
    The run length is the length of the shortest path between them within an
    approximate Steiner tree that spans all nodes of the graph.

    Args:
        graph: Graph of the cells.
        weight: Name of the edge attribute that holds the edge length.

    Returns:
        The length of the path and the path as list of nodes.
    """
    # Use the same edge attribute for the end points as for the tree and the path.
    u, v = find_most_distant_nodes(graph, weight=weight)

    # Approximate Steiner tree over all nodes, i.e. every node is a terminal.
    terminals = set(graph.nodes())
    tree = steiner_tree(graph, terminals, weight=weight)

    path = nx.shortest_path(tree, source=u, target=v, weight=weight)
    total_distance = nx.path_weight(tree, path, weight=weight)
    return total_distance, path
101+
102+
def map_frequency(table, *, fmin=1, fmax=80, k=0.88):
    """Map the length fraction of cells to a frequency with a Greenwood-type function.

    The frequency is f(x) = A * (B ** x - k), with x taken from the column
    "length_fraction". A and B are chosen such that f(0) = fmin and f(1) = fmax.
    For reference, the Greenwood function f(x) = A * (10 ** (a * x) - K) uses
    a=2.1, K=0.88 and A=165.4 [kHz] for humans.

    The result is written to the column "tonotopic_value[kHz]" of the given table:
    rows with a non-negative "distance_to_path[µm]" get the mapped frequency,
    rows with a negative one get 0.

    Args:
        table: Table with the columns "distance_to_path[µm]" and "length_fraction".
            It is modified in place.
        fmin: Frequency at a length fraction of 0. The default is the (assumed)
            minimal hearing range of mice of 1 kHz.
        fmax: Frequency at a length fraction of 1. The default is the (assumed)
            maximal hearing range of mice of 80 kHz.
        k: Constant k of the Greenwood function.

    Returns:
        The input table with the column "tonotopic_value[kHz]".

    Raises:
        ValueError: If not 0 < fmin < fmax or if k is not smaller than 1.
    """
    if not 0 < fmin < fmax:
        raise ValueError(f"Invalid frequency range [{fmin}, {fmax}]. Expected 0 < fmin < fmax.")
    if not k < 1:
        raise ValueError(f"Invalid constant k={k}. Expected a value smaller than 1.")

    # Fit the scaling and the base of the exponential to the frequency range.
    var_A = fmin / (1 - k)
    var_exp = (fmax + var_A * k) / var_A

    distance = table["distance_to_path[µm]"]
    frequency = var_A * (var_exp ** table["length_fraction"] - k)

    table.loc[distance >= 0, "tonotopic_value[kHz]"] = frequency
    # A negative distance is mapped to a frequency of 0.
    table.loc[distance < 0, "tonotopic_value[kHz]"] = 0

    return table
115+
116+
28117def tonotopic_mapping (
29118 table : pd .DataFrame ,
30119 component_label : List [int ] = [1 ],
@@ -47,16 +136,14 @@ def tonotopic_mapping(
47136 Returns:
48137 Table with tonotopic label for cells.
49138 """
50- weight = "weight"
51139 # subset of centroids for given component label(s)
52140 new_subset = table [table ["component_labels" ].isin (component_label )]
53- comp_label_ids = list (new_subset ["label_id" ])
54- centroids_subset = list (zip (new_subset ["anchor_x" ], new_subset ["anchor_y" ], new_subset ["anchor_z" ]))
55- labels_subset = [int (i ) for i in list (new_subset ["label_id" ])]
141+ centroids = list (zip (new_subset ["anchor_x" ], new_subset ["anchor_y" ], new_subset ["anchor_z" ]))
142+ label_ids = [int (i ) for i in list (new_subset ["label_id" ])]
56143
57144 # create graph with connected components
58145 coords = {}
59- for index , element in zip (labels_subset , centroids_subset ):
146+ for index , element in zip (label_ids , centroids ):
60147 coords [index ] = element
61148
62149 components , graph = graph_connected_components (coords , max_edge_distance , min_component_length )
@@ -66,45 +153,33 @@ def tonotopic_mapping(
66153
67154 unfiltered_graph = graph .copy ()
68155
69- if filter_factor is not None :
70- if 0 <= filter_factor < 1 :
71- rng = np .random .default_rng (seed = 1234 )
72- original_array = np .array (comp_label_ids )
73- target_length = int (len (original_array ) * filter_factor )
74- filtered_list = list (rng .choice (original_array , size = target_length , replace = False ))
75- for filter_id in filtered_list :
76- graph .remove_node (filter_id )
77- else :
78- raise ValueError (f"Invalid filter factor { filter_factor } . Choose a filter factor between 0 and 1." )
79-
80- u , v = find_most_distant_nodes (graph )
81-
82- if not nx .has_path (graph , source = u , target = v ) or cell_type == "ihc" :
83- # approximate Steiner tree and find shortest path between the two most distant nodes
84- terminals = set (graph .nodes ()) # All nodes are required
85- # Approximate Steiner Tree over all nodes
86- T = steiner_tree (graph , terminals , weight = weight )
87- path = nx .shortest_path (T , source = u , target = v , weight = weight )
88- total_distance = nx .path_weight (T , path , weight = weight )
156+ if cell_type == "ihc" :
157+ total_distance , path = measure_run_length_ihcs (graph )
89158
90159 else :
91- path = nx .shortest_path (graph , source = u , target = v , weight = weight )
92- total_distance = nx .path_weight (graph , path , weight = weight )
160+ total_distance , path , graph = measure_run_length_sgns (graph , centroids , label_ids ,
161+ filter_factor , weight = "weight" )
162+
163+ # measure_betweenness
164+ centrality = nx .betweenness_centrality (graph , k = 100 , normalized = True , weight = 'weight' , seed = 1234 )
165+ score = sum (centrality [n ] for n in path ) / len (path )
166+ print (f"path distance: { total_distance } " )
167+ print (f"centrality score: { score } " )
93168
94169 # assign relative distance to nodes on path
95- path_list = {}
96- path_list [path [0 ]] = {"label_id" : path [0 ], "tonotopic " : 0 }
170+ path_dict = {}
171+ path_dict [path [0 ]] = {"label_id" : path [0 ], "length_fraction " : 0 }
97172 accumulated = 0
98173 for num , p in enumerate (path [1 :- 1 ]):
99174 distance = graph .get_edge_data (path [num ], p )["weight" ]
100175 accumulated += distance
101176 rel_dist = accumulated / total_distance
102- path_list [p ] = {"label_id" : p , "tonotopic " : rel_dist }
103- path_list [path [- 1 ]] = {"label_id" : path [- 1 ], "tonotopic " : 1 }
177+ path_dict [p ] = {"label_id" : p , "length_fraction " : rel_dist }
178+ path_dict [path [- 1 ]] = {"label_id" : path [- 1 ], "length_fraction " : 1 }
104179
105- # add missing nodes from component
180+ # add missing nodes from component and compute distance to path
106181 pos = nx .get_node_attributes (unfiltered_graph , 'pos' )
107- for c in comp_label_ids :
182+ for c in label_ids :
108183 if c not in path :
109184 min_dist = float ('inf' )
110185 nearest_node = None
@@ -115,27 +190,28 @@ def tonotopic_mapping(
115190 min_dist = dist
116191 nearest_node = p
117192
118- path_list [c ] = {"label_id" : c , "tonotopic" : path_list [nearest_node ]["tonotopic" ]}
193+ path_dict [c ] = {
194+ "label_id" : c ,
195+ "length_fraction" : path_dict [nearest_node ]["length_fraction" ],
196+ "distance_to_path" : min_dist ,
197+ }
198+ else :
199+ path_dict [c ]["distance_to_path" ] = 0
119200
120- # label in micrometer
121- tonotopic = [0 for _ in range (len (table ))]
122- # be aware of 'label_id' of dataframe starting at 1
123- for key in list (path_list .keys ()):
124- tonotopic [int (path_list [key ]["label_id" ] - 1 )] = path_list [key ]["tonotopic" ] * total_distance
201+ distance_to_path = [- 1 for _ in range (len (table ))]
202+ # 'label_id' of dataframe starting at 1
203+ for key in list (path_dict .keys ()):
204+ distance_to_path [int (path_dict [key ]["label_id" ] - 1 )] = path_dict [key ]["distance_to_path" ]
125205
126- table .loc [:, "tonotopic_label " ] = tonotopic
206+ table .loc [:, "distance_to_path[µm] " ] = distance_to_path
127207
128- # map frequency using Greenwood function f(x) = A * (10 **(ax) - K), for humans: a=2.1, k=0.88, A = 165.4 [kHz]
129- tonotopic_map = [0 for _ in range (len (table ))]
130- var_k = 0.88
131- # calculate values to fit (assumed) minimal (1kHz) and maximal (80kHz) hearing range of mice at x=0, x=1
132- fmin = 1
133- fmax = 80
134- var_A = fmin / (1 - var_k )
135- var_exp = ((fmax + var_A * var_k ) / var_A )
136- for key in list (path_list .keys ()):
137- tonotopic_map [int (path_list [key ]["label_id" ] - 1 )] = var_A * (var_exp ** path_list [key ]["tonotopic" ] - var_k )
208+ length_fraction = [0 for _ in range (len (table ))]
209+ for key in list (path_dict .keys ()):
210+ length_fraction [int (path_dict [key ]["label_id" ] - 1 )] = path_dict [key ]["length_fraction" ]
211+
212+ table .loc [:, "length_fraction" ] = length_fraction
213+ table .loc [:, "run_length[µm]" ] = table ["length_fraction" ] * total_distance
138214
139- table . loc [:, "tonotopic_value[kHz]" ] = tonotopic_map
215+ table = map_frequency ( table )
140216
141217 return table
0 commit comments