11import math
2- import warnings
32from typing import List , Tuple
43
54import networkx as nx
65import numpy as np
76import pandas as pd
87from networkx .algorithms .approximation import steiner_tree
98from scipy .ndimage import distance_transform_edt , binary_dilation , binary_closing
9+ from scipy .interpolate import interp1d
1010
11- import flamingo_tools .segmentation .postprocessing as postprocessing
12- from flamingo_tools .segmentation .postprocessing import graph_connected_components
11+ from flamingo_tools .segmentation .postprocessing import downscaled_centroids
1312
1413
1514def find_most_distant_nodes (G : nx .classes .graph .Graph , weight : str = 'weight' ) -> Tuple [float , float ]:
1615 """Find the most distant nodes in a graph.
1716
1817 Args:
19- G: Input graph
18+ G: Input graph.
2019
2120 Returns:
22- Node 1
23- Node 2
21+ Node 1.
22+ Node 2.
2423 """
2524 all_lengths = dict (nx .all_pairs_dijkstra_path_length (G , weight = weight ))
2625 max_dist = 0
@@ -40,9 +39,12 @@ def central_path_edt_graph(mask: np.ndarray, start: Tuple[int], end: Tuple[int])
4039 """Find the central path within a binary mask between a start and an end coordinate.
4140
4241 Args:
43- mask: Binary mask of volume
44- start: Starting coordinate
45- end: End coordinate
42+ mask: Binary mask of volume.
43+ start: Starting coordinate.
44+ end: End coordinate.
45+
46+ Returns:
47+ Coordinates of central path.
4648 """
4749 dt = distance_transform_edt (mask )
4850 G = nx .Graph ()
@@ -74,11 +76,11 @@ def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
7476 """Smooth a 3D path with a simple moving average filter.
7577
7678 Args:
77- path: ndarray of shape (N, 3)
78- window: half-window size; actual window = 2*window + 1
79+ path: ndarray of shape (N, 3).
80+ window: half-window size; actual window = 2*window + 1.
7981
8082 Returns:
81- smoothed path: ndarray of same shape
83+ smoothed path: ndarray of same shape.
8284 """
8385 kernel_size = 2 * window + 1
8486 kernel = np .ones (kernel_size ) / kernel_size
@@ -102,11 +104,15 @@ def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
102104 6) The points of the path are fed into a dictionary along with the fractional length.
103105
104106 Args:
105- centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3)
107+ centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3).
106108 scale_factor: Downscaling factor for finding the central path.
107109
110+ Returns:
111+ Total distance of the path.
112+ Path as an nd.array of positions.
113+ A dictionary containing the position and the length fraction of each point in the path.
108114 """
109- mask = postprocessing . downscaled_centroids (centroids , scale_factor = scale_factor , downsample_mode = "capped" )
115+ mask = downscaled_centroids (centroids , scale_factor = scale_factor , downsample_mode = "capped" )
110116 mask = binary_dilation (mask , np .ones ((3 , 3 , 3 )), iterations = 1 )
111117 mask = binary_closing (mask , np .ones ((3 , 3 , 3 )), iterations = 1 )
112118 pts = np .argwhere (mask == 1 )
@@ -137,23 +143,31 @@ def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
137143 path_dict [num + 1 ] = {"pos" : p , "length_fraction" : rel_dist }
138144 path_dict [len (path )] = {"pos" : path [- 1 ], "length_fraction" : 1 }
139145
140- return total_distance , path_dict
146+ return total_distance , path , path_dict
141147
142148
143- def measure_run_length_ihcs (graph , weight = "weight" ):
149+ def measure_run_length_ihcs (centroids ):
144150 """Measure the run lengths of the IHC segmentation
145151 by finding the shortest path between the most distant nodes in a Steiner Tree.
146152
147153 Args:
148- graph: Input graph.
154+ centroids: Centroids of SGN segmentation.
155+
156+ Returns:
157+ Total distance of the path.
158+ Path as an nd.array of positions.
159+ A dictionary containing the position and the length fraction of each point in the path.
149160 """
150- u , v = find_most_distant_nodes (graph )
161+ graph = nx .Graph ()
162+ for num , pos in enumerate (centroids ):
163+ graph .add_node (num , pos = pos )
151164 # approximate Steiner tree and find shortest path between the two most distant nodes
152165 terminals = set (graph .nodes ()) # All nodes are required
153166 # Approximate Steiner Tree over all nodes
154- T = steiner_tree (graph , terminals , weight = weight )
155- path = nx .shortest_path (T , source = u , target = v , weight = weight )
156- total_distance = nx .path_weight (T , path , weight = weight )
167+ T = steiner_tree (graph , terminals )
168+ u , v = find_most_distant_nodes (T )
169+ path = nx .shortest_path (T , source = u , target = v )
170+ total_distance = nx .path_weight (T , path , weight = "weight" )
157171
158172 # assign relative distance to points on path
159173 path_dict = {}
@@ -166,7 +180,7 @@ def measure_run_length_ihcs(graph, weight="weight"):
166180 path_dict [num + 1 ] = {"pos" : graph .nodes [p ]["pos" ], "length_fraction" : rel_dist }
167181 path_dict [len (path )] = {"pos" : graph .nodes [path [- 1 ]]["pos" ], "length_fraction" : 1 }
168182
169- return total_distance , path_dict
183+ return total_distance , path , path_dict
170184
171185
172186def map_frequency (table : pd .DataFrame ):
@@ -176,7 +190,10 @@ def map_frequency(table: pd.DataFrame):
176190 For mice: fit values between minimal (1kHz) and maximal (80kHz) values
177191
178192 Args:
179- table:
193+ table: Dataframe containing the segmentation.
194+
195+ Returns:
196+ Dataframe containing frequency in an additional column 'frequency[kHz]'.
180197 """
181198 var_k = 0.88
182199 fmin = 1
@@ -189,11 +206,51 @@ def map_frequency(table: pd.DataFrame):
189206 return table
190207
191208
209+ def equidistant_centers (
210+ table : pd .DataFrame ,
211+ component_label : List [int ] = [1 ],
212+ cell_type : str = "sgn" ,
213+ n_blocks : int = 10 ,
214+ offset_blocks : bool = True ,
215+ ) -> np .ndarray :
216+ """Find equidistant centers within the central path of the Rosenthal's canal.
217+
218+ Args:
219+ table: Dataframe containing centroids of SGN segmentation.
220+ component_label: List of components for centroid subset.
221+ cell_type: Cell type of the segmentation.
222+ n_blocks: Number of equidistant centers for block creation.
223+ offset_block: Centers are shifted by half a length if True. Avoid centers at the start/end of the path.
224+
225+ Returns:
226+ Equidistant centers as float values
227+ """
228+ # subset of centroids for given component label(s)
229+ new_subset = table [table ["component_labels" ].isin (component_label )]
230+ centroids = list (zip (new_subset ["anchor_x" ], new_subset ["anchor_y" ], new_subset ["anchor_z" ]))
231+
232+ if cell_type == "ihc" :
233+ total_distance , path , _ = measure_run_length_ihcs (centroids )
234+
235+ else :
236+ total_distance , path , _ = measure_run_length_sgns (centroids )
237+
238+ diffs = np .diff (path , axis = 0 )
239+ seg_lens = np .linalg .norm (diffs , axis = 1 )
240+ cum_len = np .insert (np .cumsum (seg_lens ), 0 , 0 )
241+ if offset_blocks :
242+ target_s = np .linspace (0 , total_distance , n_blocks * 2 + 1 )
243+ target_s = [s for num , s in enumerate (target_s ) if num % 2 == 1 ]
244+ else :
245+ target_s = np .linspace (0 , total_distance , n_blocks )
246+ f = interp1d (cum_len , path , axis = 0 )
247+ centers = f (target_s )
248+ return centers
249+
250+
192251def tonotopic_mapping (
193252 table : pd .DataFrame ,
194253 component_label : List [int ] = [1 ],
195- max_edge_distance : float = 30 ,
196- min_component_length : int = 50 ,
197254 cell_type : str = "ihc"
198255) -> pd .DataFrame :
199256 """Tonotopic mapping of IHCs by supplying a table with component labels.
@@ -202,10 +259,7 @@ def tonotopic_mapping(
202259 Args:
203260 table: Dataframe of segmentation table.
204261 component_label: List of component labels to evaluate.
205- max_edge_distance: Maximal edge distance to connect nodes.
206- min_component_length: Minimal number of nodes in component.
207262 cell_type: Cell type of segmentation.
208- Filter factor: Fraction of nodes to remove before mapping.
209263
210264 Returns:
211265 Table with tonotopic label for cells.
@@ -215,33 +269,20 @@ def tonotopic_mapping(
215269 centroids = list (zip (new_subset ["anchor_x" ], new_subset ["anchor_y" ], new_subset ["anchor_z" ]))
216270 label_ids = [int (i ) for i in list (new_subset ["label_id" ])]
217271
218- # create graph with connected components
219- coords = {}
220- for index , element in zip (label_ids , centroids ):
221- coords [index ] = element
222-
223- components , graph = graph_connected_components (coords , max_edge_distance , min_component_length )
224- if len (components ) > 1 :
225- warnings .warn (f"There are { len (components )} connected components, expected 1. "
226- "Check parameters for post-processing (max_edge_distance, min_component_length)." )
227-
228- unfiltered_graph = graph .copy ()
229-
230272 if cell_type == "ihc" :
231- total_distance , path_dict = measure_run_length_ihcs (graph )
273+ total_distance , _ , path_dict = measure_run_length_ihcs (centroids )
232274
233275 else :
234- total_distance , path_dict = measure_run_length_sgns (centroids )
276+ total_distance , _ , path_dict = measure_run_length_sgns (centroids )
235277
236278 # add missing nodes from component and compute distance to path
237- pos = nx .get_node_attributes (unfiltered_graph , 'pos' )
238279 node_dict = {}
239- for c in label_ids :
280+ for num , c in enumerate ( label_ids ) :
240281 min_dist = float ('inf' )
241282 nearest_node = None
242283
243284 for key in path_dict .keys ():
244- dist = math .dist (pos [ c ], path_dict [key ]["pos" ])
285+ dist = math .dist (centroids [ num ], path_dict [key ]["pos" ])
245286 if dist < min_dist :
246287 min_dist = dist
247288 nearest_node = key
0 commit comments