Skip to content

Commit 48d5df3

Browse files
committed
Find centers for multiple IHC components
1 parent 902290d commit 48d5df3

File tree

2 files changed

+165
-5
lines changed

2 files changed

+165
-5
lines changed

flamingo_tools/segmentation/cochlea_mapping.py

Lines changed: 158 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,141 @@ def measure_run_length_sgns(
275275
return total_distance, path, path_dict
276276

277277

278+
def measure_run_length_ihcs_multi_component(
279+
centroids_components: List[np.ndarray],
280+
max_edge_distance: float = 30,
281+
apex_higher: bool = True,
282+
component_label: List[int] = [1],
283+
) -> Tuple[float, np.ndarray, dict]:
284+
"""Adaptation of measure_run_length_sgns_multi_component to IHCs.
285+
286+
"""
287+
total_path = []
288+
print(f"Evaluating {len(centroids_components)} components.")
289+
# 1) Process centroids for each component
290+
for centroids in centroids_components:
291+
graph = nx.Graph()
292+
coords = {}
293+
labels = [int(i) for i in range(len(centroids))]
294+
for index, element in zip(labels, centroids):
295+
coords[index] = element
296+
297+
for num, pos in coords.items():
298+
graph.add_node(num, pos=pos)
299+
300+
# create edges between points whose distance is less than threshold max_edge_distance
301+
for num_i, pos_i in coords.items():
302+
for num_j, pos_j in coords.items():
303+
if num_i < num_j:
304+
dist = math.dist(pos_i, pos_j)
305+
if dist <= max_edge_distance:
306+
graph.add_edge(num_i, num_j, weight=dist)
307+
308+
components = [list(c) for c in nx.connected_components(graph)]
309+
len_c = [len(c) for c in components]
310+
len_c, components = zip(*sorted(zip(len_c, components), reverse=True))
311+
312+
# combine separate connected components by adding edges between nodes which are closest together
313+
if len(components) > 1:
314+
print(f"Graph consists of {len(components)} connected components.")
315+
if len(component_label) != len(components):
316+
raise ValueError(f"Length of graph components {len(components)} "
317+
f"does not match number of component labels {len(component_label)}. "
318+
"Check max_edge_distance and post-processing.")
319+
320+
# Order connected components in order of component labels
321+
# e.g. component_labels = [7, 4, 1, 11] and len_c = [600, 400, 300, 55]
322+
# get re-ordered to [300, 400, 600, 55]
323+
components_sorted = [
324+
c[1] for _, c in sorted(zip(sorted(range(len(component_label)), key=lambda i: component_label[i]),
325+
sorted(zip(len_c, components), key=lambda x: x[0], reverse=True)))]
326+
327+
# Connect nodes of neighboring components that are closest together
328+
for num in range(0, len(components_sorted) - 1):
329+
min_dist = float("inf")
330+
closest_pair = None
331+
332+
# Compare only nodes between two neighboring components
333+
for node_a in components_sorted[num]:
334+
for node_b in components_sorted[num + 1]:
335+
dist = math.dist(graph.nodes[node_a]["pos"], graph.nodes[node_b]["pos"])
336+
if dist < min_dist:
337+
min_dist = dist
338+
closest_pair = (node_a, node_b)
339+
graph.add_edge(closest_pair[0], closest_pair[1], weight=min_dist)
340+
341+
print("Connect components in order of component labels.")
342+
343+
start_node, end_node = find_most_distant_nodes(graph)
344+
345+
# compare y-value to not get into confusion with MoBIE dimensions
346+
if graph.nodes[start_node]["pos"][1] > graph.nodes[end_node]["pos"][1]:
347+
apex_node = start_node if apex_higher else end_node
348+
base_node = end_node if apex_higher else start_node
349+
else:
350+
apex_node = end_node if apex_higher else start_node
351+
base_node = start_node if apex_higher else end_node
352+
353+
path = nx.shortest_path(graph, source=apex_node, target=base_node)
354+
path_pos = np.array([graph.nodes[p]["pos"] for p in path])
355+
path = moving_average_3d(path_pos, window=5)
356+
total_path.append(path)
357+
358+
# 2) Order paths to have consistent start/end points
359+
# Find starting order of first two components
360+
c1a = total_path[0][0, :]
361+
c1b = total_path[0][-1, :]
362+
363+
c2a = total_path[1][0, :]
364+
c2b = total_path[1][-1, :]
365+
366+
distances = [math.dist(c1a, c2a), math.dist(c1a, c2b), math.dist(c1b, c2a), math.dist(c1b, c2b)]
367+
min_index = distances.index(min(distances))
368+
if min_index in [0, 1]:
369+
total_path[0] = np.flip(total_path[0], axis=0)
370+
371+
# Order other components from start to end
372+
for num in range(0, len(total_path) - 1):
373+
dist_connecting_nodes_1 = math.dist(total_path[num][-1, :], total_path[num+1][0, :])
374+
dist_connecting_nodes_2 = math.dist(total_path[num][-1, :], total_path[num+1][-1, :])
375+
if dist_connecting_nodes_2 < dist_connecting_nodes_1:
376+
total_path[num+1] = np.flip(total_path[num+1], axis=0)
377+
378+
# 3) Assign base/apex position to path
379+
# compare y-value to not get into confusion with MoBIE dimensions
380+
if total_path[0][0, 1] > total_path[-1][-1, 1]:
381+
if not apex_higher:
382+
total_path.reverse()
383+
total_path = [np.flip(t) for t in total_path]
384+
elif apex_higher:
385+
total_path.reverse()
386+
total_path = [np.flip(t) for t in total_path]
387+
388+
# 4) Assign distance of nodes by skipping intermediate space between separate components
389+
total_distance = sum([math.dist(p[num + 1], p[num]) for p in total_path for num in range(len(p) - 1)])
390+
path_dict = {}
391+
accumulated = 0
392+
index = 0
393+
for num, pa in enumerate(total_path):
394+
if num == 0:
395+
path_dict[0] = {"pos": total_path[0][0], "length_fraction": 0}
396+
else:
397+
path_dict[index] = {"pos": total_path[num][0], "length_fraction": path_dict[index-1]["length_fraction"]}
398+
399+
index += 1
400+
for enum, p in enumerate(pa[1:]):
401+
distance = math.dist(total_path[num][enum], p)
402+
accumulated += distance
403+
rel_dist = accumulated / total_distance
404+
path_dict[index] = {"pos": p, "length_fraction": rel_dist}
405+
index += 1
406+
path_dict[index-1] = {"pos": total_path[-1][-1, :], "length_fraction": 1}
407+
408+
# 5) Concatenate individual paths to form total path
409+
path = np.concatenate(total_path, axis=0)
410+
411+
return total_distance, path, path_dict
412+
278413
def measure_run_length_ihcs(
279414
centroids: np.ndarray,
280415
max_edge_distance: float = 30,
@@ -445,8 +580,13 @@ def get_centers_from_path(
445580
target_s = [s for num, s in enumerate(target_s) if num % 2 == 1]
446581
else:
447582
target_s = np.linspace(0, total_distance, n_blocks)
448-
f = interp1d(cum_len, path, axis=0) # fill_value="extrapolate"
449-
centers = f(target_s)
583+
try:
584+
f = interp1d(cum_len, path, axis=0) # fill_value="extrapolate"
585+
centers = f(target_s)
586+
except ValueError as ve:
587+
print("Using extrapolation to fill values.")
588+
f = interp1d(cum_len, path, axis=0, fill_value="extrapolate")
589+
centers = f(target_s)
450590
return centers
451591

452592

@@ -527,6 +667,7 @@ def equidistant_centers(
527667
component_label: List[int] = [1],
528668
cell_type: str = "sgn",
529669
n_blocks: int = 10,
670+
max_edge_distance: float = 30,
530671
offset_blocks: bool = True,
531672
) -> np.ndarray:
532673
"""Find equidistant centers within the central path of the Rosenthal's canal.
@@ -546,8 +687,21 @@ def equidistant_centers(
546687
centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
547688

548689
if cell_type == "ihc":
549-
total_distance, path, _ = measure_run_length_ihcs(centroids, component_label=component_label)
550-
return get_centers_from_path(path, total_distance, n_blocks=n_blocks, offset_blocks=offset_blocks)
690+
if len(component_label) == 1:
691+
total_distance, path, _ = measure_run_length_ihcs(
692+
centroids, component_label=component_label, max_edge_distance=max_edge_distance
693+
)
694+
return get_centers_from_path(path, total_distance, n_blocks=n_blocks, offset_blocks=offset_blocks)
695+
else:
696+
centroids_components = []
697+
for label in component_label:
698+
subset = table[table["component_labels"] == label]
699+
subset_centroids = list(zip(subset["anchor_x"], subset["anchor_y"], subset["anchor_z"]))
700+
centroids_components.append(subset_centroids)
701+
total_distance, path, path_dict = measure_run_length_ihcs_multi_component(
702+
centroids_components, max_edge_distance=max_edge_distance
703+
)
704+
return get_centers_from_path_dict(path_dict, n_blocks=n_blocks, offset_blocks=offset_blocks)
551705

552706
else:
553707
if len(component_label) == 1:

reproducibility/block_extraction/repro_equidistant_centers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def repro_equidistant_centers(
2020
default_component_list = [1]
2121
default_halo_size = [256, 256, 128]
2222
default_n_blocks = 6
23+
default_max_edge_distance = 30
2324

2425
with open(input_path, 'r') as myfile:
2526
data = myfile.read()
@@ -57,8 +58,13 @@ def update_dic(dic, keyword, default):
5758
component_list = update_dic(dic, "component_list", default_component_list)
5859
_ = update_dic(dic, "halo_size", default_halo_size)
5960
n_blocks = update_dic(dic, "n_blocks", default_n_blocks)
61+
max_edge_distance = update_dic(dic, "max_edge_distance", default_max_edge_distance)
62+
63+
centers = equidistant_centers(
64+
table, component_label=component_list, cell_type=cell_type,
65+
n_blocks=n_blocks, max_edge_distance=max_edge_distance
66+
)
6067

61-
centers = equidistant_centers(table, component_label=component_list, cell_type=cell_type, n_blocks=n_blocks)
6268
centers = [[round(c) for c in center] for center in centers]
6369

6470
dic["crop_centers"] = centers

0 commit comments

Comments
 (0)