@@ -275,6 +275,141 @@ def measure_run_length_sgns(
275275 return total_distance , path , path_dict
276276
277277
278+ def measure_run_length_ihcs_multi_component (
279+ centroids_components : List [np .ndarray ],
280+ max_edge_distance : float = 30 ,
281+ apex_higher : bool = True ,
282+ component_label : List [int ] = [1 ],
283+ ) -> Tuple [float , np .ndarray , dict ]:
284+ """Adaptation of measure_run_length_sgns_multi_component to IHCs.
285+
286+ """
287+ total_path = []
288+ print (f"Evaluating { len (centroids_components )} components." )
289+ # 1) Process centroids for each component
290+ for centroids in centroids_components :
291+ graph = nx .Graph ()
292+ coords = {}
293+ labels = [int (i ) for i in range (len (centroids ))]
294+ for index , element in zip (labels , centroids ):
295+ coords [index ] = element
296+
297+ for num , pos in coords .items ():
298+ graph .add_node (num , pos = pos )
299+
300+ # create edges between points whose distance is less than threshold max_edge_distance
301+ for num_i , pos_i in coords .items ():
302+ for num_j , pos_j in coords .items ():
303+ if num_i < num_j :
304+ dist = math .dist (pos_i , pos_j )
305+ if dist <= max_edge_distance :
306+ graph .add_edge (num_i , num_j , weight = dist )
307+
308+ components = [list (c ) for c in nx .connected_components (graph )]
309+ len_c = [len (c ) for c in components ]
310+ len_c , components = zip (* sorted (zip (len_c , components ), reverse = True ))
311+
312+ # combine separate connected components by adding edges between nodes which are closest together
313+ if len (components ) > 1 :
314+ print (f"Graph consists of { len (components )} connected components." )
315+ if len (component_label ) != len (components ):
316+ raise ValueError (f"Length of graph components { len (components )} "
317+ f"does not match number of component labels { len (component_label )} . "
318+ "Check max_edge_distance and post-processing." )
319+
320+ # Order connected components in order of component labels
321+ # e.g. component_labels = [7, 4, 1, 11] and len_c = [600, 400, 300, 55]
322+ # get re-ordered to [300, 400, 600, 55]
323+ components_sorted = [
324+ c [1 ] for _ , c in sorted (zip (sorted (range (len (component_label )), key = lambda i : component_label [i ]),
325+ sorted (zip (len_c , components ), key = lambda x : x [0 ], reverse = True )))]
326+
327+ # Connect nodes of neighboring components that are closest together
328+ for num in range (0 , len (components_sorted ) - 1 ):
329+ min_dist = float ("inf" )
330+ closest_pair = None
331+
332+ # Compare only nodes between two neighboring components
333+ for node_a in components_sorted [num ]:
334+ for node_b in components_sorted [num + 1 ]:
335+ dist = math .dist (graph .nodes [node_a ]["pos" ], graph .nodes [node_b ]["pos" ])
336+ if dist < min_dist :
337+ min_dist = dist
338+ closest_pair = (node_a , node_b )
339+ graph .add_edge (closest_pair [0 ], closest_pair [1 ], weight = min_dist )
340+
341+ print ("Connect components in order of component labels." )
342+
343+ start_node , end_node = find_most_distant_nodes (graph )
344+
345+ # compare y-value to not get into confusion with MoBIE dimensions
346+ if graph .nodes [start_node ]["pos" ][1 ] > graph .nodes [end_node ]["pos" ][1 ]:
347+ apex_node = start_node if apex_higher else end_node
348+ base_node = end_node if apex_higher else start_node
349+ else :
350+ apex_node = end_node if apex_higher else start_node
351+ base_node = start_node if apex_higher else end_node
352+
353+ path = nx .shortest_path (graph , source = apex_node , target = base_node )
354+ path_pos = np .array ([graph .nodes [p ]["pos" ] for p in path ])
355+ path = moving_average_3d (path_pos , window = 5 )
356+ total_path .append (path )
357+
358+ # 2) Order paths to have consistent start/end points
359+ # Find starting order of first two components
360+ c1a = total_path [0 ][0 , :]
361+ c1b = total_path [0 ][- 1 , :]
362+
363+ c2a = total_path [1 ][0 , :]
364+ c2b = total_path [1 ][- 1 , :]
365+
366+ distances = [math .dist (c1a , c2a ), math .dist (c1a , c2b ), math .dist (c1b , c2a ), math .dist (c1b , c2b )]
367+ min_index = distances .index (min (distances ))
368+ if min_index in [0 , 1 ]:
369+ total_path [0 ] = np .flip (total_path [0 ], axis = 0 )
370+
371+ # Order other components from start to end
372+ for num in range (0 , len (total_path ) - 1 ):
373+ dist_connecting_nodes_1 = math .dist (total_path [num ][- 1 , :], total_path [num + 1 ][0 , :])
374+ dist_connecting_nodes_2 = math .dist (total_path [num ][- 1 , :], total_path [num + 1 ][- 1 , :])
375+ if dist_connecting_nodes_2 < dist_connecting_nodes_1 :
376+ total_path [num + 1 ] = np .flip (total_path [num + 1 ], axis = 0 )
377+
378+ # 3) Assign base/apex position to path
379+ # compare y-value to not get into confusion with MoBIE dimensions
380+ if total_path [0 ][0 , 1 ] > total_path [- 1 ][- 1 , 1 ]:
381+ if not apex_higher :
382+ total_path .reverse ()
383+ total_path = [np .flip (t ) for t in total_path ]
384+ elif apex_higher :
385+ total_path .reverse ()
386+ total_path = [np .flip (t ) for t in total_path ]
387+
388+ # 4) Assign distance of nodes by skipping intermediate space between separate components
389+ total_distance = sum ([math .dist (p [num + 1 ], p [num ]) for p in total_path for num in range (len (p ) - 1 )])
390+ path_dict = {}
391+ accumulated = 0
392+ index = 0
393+ for num , pa in enumerate (total_path ):
394+ if num == 0 :
395+ path_dict [0 ] = {"pos" : total_path [0 ][0 ], "length_fraction" : 0 }
396+ else :
397+ path_dict [index ] = {"pos" : total_path [num ][0 ], "length_fraction" : path_dict [index - 1 ]["length_fraction" ]}
398+
399+ index += 1
400+ for enum , p in enumerate (pa [1 :]):
401+ distance = math .dist (total_path [num ][enum ], p )
402+ accumulated += distance
403+ rel_dist = accumulated / total_distance
404+ path_dict [index ] = {"pos" : p , "length_fraction" : rel_dist }
405+ index += 1
406+ path_dict [index - 1 ] = {"pos" : total_path [- 1 ][- 1 , :], "length_fraction" : 1 }
407+
408+ # 5) Concatenate individual paths to form total path
409+ path = np .concatenate (total_path , axis = 0 )
410+
411+ return total_distance , path , path_dict
412+
278413def measure_run_length_ihcs (
279414 centroids : np .ndarray ,
280415 max_edge_distance : float = 30 ,
@@ -445,8 +580,13 @@ def get_centers_from_path(
445580 target_s = [s for num , s in enumerate (target_s ) if num % 2 == 1 ]
446581 else :
447582 target_s = np .linspace (0 , total_distance , n_blocks )
448- f = interp1d (cum_len , path , axis = 0 ) # fill_value="extrapolate"
449- centers = f (target_s )
583+ try :
584+ f = interp1d (cum_len , path , axis = 0 ) # fill_value="extrapolate"
585+ centers = f (target_s )
586+ except ValueError as ve :
587+ print ("Using extrapolation to fill values." )
588+ f = interp1d (cum_len , path , axis = 0 , fill_value = "extrapolate" )
589+ centers = f (target_s )
450590 return centers
451591
452592
@@ -527,6 +667,7 @@ def equidistant_centers(
527667 component_label : List [int ] = [1 ],
528668 cell_type : str = "sgn" ,
529669 n_blocks : int = 10 ,
670+ max_edge_distance : float = 30 ,
530671 offset_blocks : bool = True ,
531672) -> np .ndarray :
532673 """Find equidistant centers within the central path of the Rosenthal's canal.
@@ -546,8 +687,21 @@ def equidistant_centers(
546687 centroids = list (zip (new_subset ["anchor_x" ], new_subset ["anchor_y" ], new_subset ["anchor_z" ]))
547688
548689 if cell_type == "ihc" :
549- total_distance , path , _ = measure_run_length_ihcs (centroids , component_label = component_label )
550- return get_centers_from_path (path , total_distance , n_blocks = n_blocks , offset_blocks = offset_blocks )
690+ if len (component_label ) == 1 :
691+ total_distance , path , _ = measure_run_length_ihcs (
692+ centroids , component_label = component_label , max_edge_distance = max_edge_distance
693+ )
694+ return get_centers_from_path (path , total_distance , n_blocks = n_blocks , offset_blocks = offset_blocks )
695+ else :
696+ centroids_components = []
697+ for label in component_label :
698+ subset = table [table ["component_labels" ] == label ]
699+ subset_centroids = list (zip (subset ["anchor_x" ], subset ["anchor_y" ], subset ["anchor_z" ]))
700+ centroids_components .append (subset_centroids )
701+ total_distance , path , path_dict = measure_run_length_ihcs_multi_component (
702+ centroids_components , max_edge_distance = max_edge_distance
703+ )
704+ return get_centers_from_path_dict (path_dict , n_blocks = n_blocks , offset_blocks = offset_blocks )
551705
552706 else :
553707 if len (component_label ) == 1 :
0 commit comments