@@ -93,7 +93,11 @@ def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
9393 return smooth_path
9494
9595
96- def measure_run_length_sgns (centroids : np .ndarray , scale_factor = 10 ):
96+ def measure_run_length_sgns (
97+ centroids : np .ndarray ,
98+ scale_factor : int = 10 ,
99+ apex_higher : bool = True ,
100+ ) -> Tuple [float , np .ndarray , dict ]:
97101 """Measure the run lengths of the SGN segmentation by finding a central path through Rosenthal's canal.
98102 1) Create a binary mask based on down-scaled centroids.
99103 2) Dilate the mask and close holes to ensure a filled structure.
@@ -105,6 +109,7 @@ def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
105109 Args:
106110 centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3).
107111 scale_factor: Downscaling factor for finding the central path.
112+ apex_higher: Flag for identifying apex and base. Apex is set to node with higher y-value if True.
108113
109114 Returns:
110115 Total distance of the path.
@@ -125,8 +130,16 @@ def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
125130 start_voxel = tuple (pts [proj .argmin ()])
126131 end_voxel = tuple (pts [proj .argmax ()])
127132
133+ # compare y-value to not get into confusion with MoBIE dimensions
134+ if start_voxel [1 ] > end_voxel [1 ]:
135+ apex = start_voxel if apex_higher else end_voxel
136+ base = end_voxel if apex_higher else start_voxel
137+ else :
138+ apex = end_voxel if apex_higher else start_voxel
139+ base = start_voxel if apex_higher else end_voxel
140+
128141 # get central path and total distance
129- path = central_path_edt_graph (mask , start_voxel , end_voxel )
142+ path = central_path_edt_graph (mask , apex , base )
130143 path = path * scale_factor
131144 path = moving_average_3d (path , window = 5 )
132145 total_distance = sum ([math .dist (path [num + 1 ], path [num ]) for num in range (len (path ) - 1 )])
@@ -145,7 +158,11 @@ def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
145158 return total_distance , path , path_dict
146159
147160
148- def measure_run_length_ihcs (centroids , max_edge_distance = 50 ):
161+ def measure_run_length_ihcs (
162+ centroids : np .ndarray ,
163+ max_edge_distance : float = 50 ,
164+ apex_higher : bool = True ,
165+ ) -> Tuple [float , np .ndarray , dict ]:
149166 """Measure the run lengths of the IHC segmentation
150167 by determining the shortest path between the most distant nodes of a graph.
151168 The graph is created based on a maximal edge distance between nodes.
@@ -158,6 +175,7 @@ def measure_run_length_ihcs(centroids, max_edge_distance=50):
158175 Args:
159176 centroids: Centroids of SGN segmentation.
160177 max_edge_distance: Maximal edge distance between graph nodes to create an edge between nodes.
178+ apex_higher: Flag for identifying apex and base. Apex is set to node with higher y-value if True.
161179
162180 Returns:
163181 Total distance of the path.
@@ -184,8 +202,17 @@ def measure_run_length_ihcs(centroids, max_edge_distance=50):
184202 if dist <= max_edge_distance :
185203 graph .add_edge (num_i , num_j , weight = dist )
186204
187- u , v = find_most_distant_nodes (graph )
188- path = nx .shortest_path (graph , source = u , target = v )
205+ start_node , end_node = find_most_distant_nodes (graph )
206+
207+ # compare y-value to not get into confusion with MoBIE dimensions
208+ if graph .nodes [start_node ]["pos" ][1 ] > graph .nodes [end_node ]["pos" ][1 ]:
209+ apex_node = start_node if apex_higher else end_node
210+ base_node = end_node if apex_higher else start_node
211+ else :
212+ apex_node = end_node if apex_higher else start_node
213+ base_node = start_node if apex_higher else end_node
214+
215+ path = nx .shortest_path (graph , source = apex_node , target = base_node )
189216 total_distance = nx .path_weight (graph , path , weight = "weight" )
190217
191218 # assign relative distance to points on path
@@ -319,6 +346,15 @@ def tonotopic_mapping(
319346 """
320347 # subset of centroids for given component label(s)
321348 new_subset = table [table ["component_labels" ].isin (component_label )]
349+
350+ # option for filtering IHC instances without synapses
351+ # leaving it commented for now because it would have little effect
352+
353+ # if "syn_per_IHC" in new_subset.columns:
354+ # syn_limit = 0
355+ # print(f"Keeping IHC instances with more than {syn_limit} synapses.")
356+ # new_subset = new_subset[new_subset["syn_per_IHC"] > syn_limit]
357+
322358 centroids = list (zip (new_subset ["anchor_x" ], new_subset ["anchor_y" ], new_subset ["anchor_z" ]))
323359 label_ids = [int (i ) for i in list (new_subset ["label_id" ])]
324360
@@ -347,17 +383,18 @@ def tonotopic_mapping(
347383 }
348384
349385 offset = [- 1 for _ in range (len (table ))]
386+ offset = list (np .float64 (offset ))
387+ table .loc [:, "offset" ] = offset
350388 # 'label_id' of dataframe starting at 1
351389 for key in list (node_dict .keys ()):
352- offset [int (node_dict [key ]["label_id" ] - 1 )] = node_dict [key ]["offset" ]
353-
354- table .loc [:, "offset" ] = offset
390+ table .loc [table ["label_id" ] == key , "offset" ] = node_dict [key ]["offset" ]
355391
356392 length_fraction = [0 for _ in range (len (table ))]
357- for key in list (node_dict .keys ()):
358- length_fraction [int (node_dict [key ]["label_id" ] - 1 )] = node_dict [key ]["length_fraction" ]
359-
393+ length_fraction = list (np .float64 (length_fraction ))
360394 table .loc [:, "length_fraction" ] = length_fraction
395+ for num , key in enumerate (list (node_dict .keys ())):
396+ table .loc [table ["label_id" ] == key , "length_fraction" ] = node_dict [key ]["length_fraction" ]
397+
361398 table .loc [:, "length[µm]" ] = table ["length_fraction" ] * total_distance
362399
363400 table = map_frequency (table , cell_type = cell_type , animal = animal )
0 commit comments