Skip to content

Commit 7942d56

Browse files
committed
Equidistant centers of path
1 parent a648de4 commit 7942d56

File tree

4 files changed

+180
-57
lines changed

4 files changed

+180
-57
lines changed

flamingo_tools/segmentation/cochlea_mapping.py

Lines changed: 86 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
import math
2-
import warnings
32
from typing import List, Tuple
43

54
import networkx as nx
65
import numpy as np
76
import pandas as pd
87
from networkx.algorithms.approximation import steiner_tree
98
from scipy.ndimage import distance_transform_edt, binary_dilation, binary_closing
9+
from scipy.interpolate import interp1d
1010

11-
import flamingo_tools.segmentation.postprocessing as postprocessing
12-
from flamingo_tools.segmentation.postprocessing import graph_connected_components
11+
from flamingo_tools.segmentation.postprocessing import downscaled_centroids
1312

1413

1514
def find_most_distant_nodes(G: nx.classes.graph.Graph, weight: str = 'weight') -> Tuple[float, float]:
1615
"""Find the most distant nodes in a graph.
1716
1817
Args:
19-
G: Input graph
18+
G: Input graph.
2019
2120
Returns:
22-
Node 1
23-
Node 2
21+
Node 1.
22+
Node 2.
2423
"""
2524
all_lengths = dict(nx.all_pairs_dijkstra_path_length(G, weight=weight))
2625
max_dist = 0
@@ -40,9 +39,12 @@ def central_path_edt_graph(mask: np.ndarray, start: Tuple[int], end: Tuple[int])
4039
"""Find the central path within a binary mask between a start and an end coordinate.
4140
4241
Args:
43-
mask: Binary mask of volume
44-
start: Starting coordinate
45-
end: End coordinate
42+
mask: Binary mask of volume.
43+
start: Starting coordinate.
44+
end: End coordinate.
45+
46+
Returns:
47+
Coordinates of central path.
4648
"""
4749
dt = distance_transform_edt(mask)
4850
G = nx.Graph()
@@ -74,11 +76,11 @@ def moving_average_3d(path: np.ndarray, window: int = 5) -> np.ndarray:
7476
"""Smooth a 3D path with a simple moving average filter.
7577
7678
Args:
77-
path: ndarray of shape (N, 3)
78-
window: half-window size; actual window = 2*window + 1
79+
path: ndarray of shape (N, 3).
80+
window: half-window size; actual window = 2*window + 1.
7981
8082
Returns:
81-
smoothed path: ndarray of same shape
83+
smoothed path: ndarray of same shape.
8284
"""
8385
kernel_size = 2 * window + 1
8486
kernel = np.ones(kernel_size) / kernel_size
@@ -102,11 +104,15 @@ def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
102104
6) The points of the path are fed into a dictionary along with the fractional length.
103105
104106
Args:
105-
centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3)
107+
centroids: Centroids of the SGN segmentation, ndarray of shape (N, 3).
106108
scale_factor: Downscaling factor for finding the central path.
107109
110+
Returns:
111+
Total distance of the path.
112+
Path as an nd.array of positions.
113+
A dictionary containing the position and the length fraction of each point in the path.
108114
"""
109-
mask = postprocessing.downscaled_centroids(centroids, scale_factor=scale_factor, downsample_mode="capped")
115+
mask = downscaled_centroids(centroids, scale_factor=scale_factor, downsample_mode="capped")
110116
mask = binary_dilation(mask, np.ones((3, 3, 3)), iterations=1)
111117
mask = binary_closing(mask, np.ones((3, 3, 3)), iterations=1)
112118
pts = np.argwhere(mask == 1)
@@ -137,23 +143,31 @@ def measure_run_length_sgns(centroids: np.ndarray, scale_factor=10):
137143
path_dict[num + 1] = {"pos": p, "length_fraction": rel_dist}
138144
path_dict[len(path)] = {"pos": path[-1], "length_fraction": 1}
139145

140-
return total_distance, path_dict
146+
return total_distance, path, path_dict
141147

142148

143-
def measure_run_length_ihcs(graph, weight="weight"):
149+
def measure_run_length_ihcs(centroids):
144150
"""Measure the run lengths of the IHC segmentation
145151
by finding the shortest path between the most distant nodes in a Steiner Tree.
146152
147153
Args:
148-
graph: Input graph.
154+
centroids: Centroids of SGN segmentation.
155+
156+
Returns:
157+
Total distance of the path.
158+
Path as an nd.array of positions.
159+
A dictionary containing the position and the length fraction of each point in the path.
149160
"""
150-
u, v = find_most_distant_nodes(graph)
161+
graph = nx.Graph()
162+
for num, pos in enumerate(centroids):
163+
graph.add_node(num, pos=pos)
151164
# approximate Steiner tree and find shortest path between the two most distant nodes
152165
terminals = set(graph.nodes()) # All nodes are required
153166
# Approximate Steiner Tree over all nodes
154-
T = steiner_tree(graph, terminals, weight=weight)
155-
path = nx.shortest_path(T, source=u, target=v, weight=weight)
156-
total_distance = nx.path_weight(T, path, weight=weight)
167+
T = steiner_tree(graph, terminals)
168+
u, v = find_most_distant_nodes(T)
169+
path = nx.shortest_path(T, source=u, target=v)
170+
total_distance = nx.path_weight(T, path, weight="weight")
157171

158172
# assign relative distance to points on path
159173
path_dict = {}
@@ -166,7 +180,7 @@ def measure_run_length_ihcs(graph, weight="weight"):
166180
path_dict[num + 1] = {"pos": graph.nodes[p]["pos"], "length_fraction": rel_dist}
167181
path_dict[len(path)] = {"pos": graph.nodes[path[-1]]["pos"], "length_fraction": 1}
168182

169-
return total_distance, path_dict
183+
return total_distance, path, path_dict
170184

171185

172186
def map_frequency(table: pd.DataFrame):
@@ -176,7 +190,10 @@ def map_frequency(table: pd.DataFrame):
176190
For mice: fit values between minimal (1kHz) and maximal (80kHz) values
177191
178192
Args:
179-
table:
193+
table: Dataframe containing the segmentation.
194+
195+
Returns:
196+
Dataframe containing frequency in an additional column 'frequency[kHz]'.
180197
"""
181198
var_k = 0.88
182199
fmin = 1
@@ -189,11 +206,51 @@ def map_frequency(table: pd.DataFrame):
189206
return table
190207

191208

209+
def equidistant_centers(
210+
table: pd.DataFrame,
211+
component_label: List[int] = [1],
212+
cell_type: str = "sgn",
213+
n_blocks: int = 10,
214+
offset_blocks: bool = True,
215+
) -> np.ndarray:
216+
"""Find equidistant centers within the central path of the Rosenthal's canal.
217+
218+
Args:
219+
table: Dataframe containing centroids of SGN segmentation.
220+
component_label: List of components for centroid subset.
221+
cell_type: Cell type of the segmentation.
222+
n_blocks: Number of equidistant centers for block creation.
223+
offset_block: Centers are shifted by half a length if True. Avoid centers at the start/end of the path.
224+
225+
Returns:
226+
Equidistant centers as float values
227+
"""
228+
# subset of centroids for given component label(s)
229+
new_subset = table[table["component_labels"].isin(component_label)]
230+
centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
231+
232+
if cell_type == "ihc":
233+
total_distance, path, _ = measure_run_length_ihcs(centroids)
234+
235+
else:
236+
total_distance, path, _ = measure_run_length_sgns(centroids)
237+
238+
diffs = np.diff(path, axis=0)
239+
seg_lens = np.linalg.norm(diffs, axis=1)
240+
cum_len = np.insert(np.cumsum(seg_lens), 0, 0)
241+
if offset_blocks:
242+
target_s = np.linspace(0, total_distance, n_blocks * 2 + 1)
243+
target_s = [s for num, s in enumerate(target_s) if num % 2 == 1]
244+
else:
245+
target_s = np.linspace(0, total_distance, n_blocks)
246+
f = interp1d(cum_len, path, axis=0)
247+
centers = f(target_s)
248+
return centers
249+
250+
192251
def tonotopic_mapping(
193252
table: pd.DataFrame,
194253
component_label: List[int] = [1],
195-
max_edge_distance: float = 30,
196-
min_component_length: int = 50,
197254
cell_type: str = "ihc"
198255
) -> pd.DataFrame:
199256
"""Tonotopic mapping of IHCs by supplying a table with component labels.
@@ -202,10 +259,7 @@ def tonotopic_mapping(
202259
Args:
203260
table: Dataframe of segmentation table.
204261
component_label: List of component labels to evaluate.
205-
max_edge_distance: Maximal edge distance to connect nodes.
206-
min_component_length: Minimal number of nodes in component.
207262
cell_type: Cell type of segmentation.
208-
Filter factor: Fraction of nodes to remove before mapping.
209263
210264
Returns:
211265
Table with tonotopic label for cells.
@@ -215,33 +269,20 @@ def tonotopic_mapping(
215269
centroids = list(zip(new_subset["anchor_x"], new_subset["anchor_y"], new_subset["anchor_z"]))
216270
label_ids = [int(i) for i in list(new_subset["label_id"])]
217271

218-
# create graph with connected components
219-
coords = {}
220-
for index, element in zip(label_ids, centroids):
221-
coords[index] = element
222-
223-
components, graph = graph_connected_components(coords, max_edge_distance, min_component_length)
224-
if len(components) > 1:
225-
warnings.warn(f"There are {len(components)} connected components, expected 1. "
226-
"Check parameters for post-processing (max_edge_distance, min_component_length).")
227-
228-
unfiltered_graph = graph.copy()
229-
230272
if cell_type == "ihc":
231-
total_distance, path_dict = measure_run_length_ihcs(graph)
273+
total_distance, _, path_dict = measure_run_length_ihcs(centroids)
232274

233275
else:
234-
total_distance, path_dict = measure_run_length_sgns(centroids)
276+
total_distance, _, path_dict = measure_run_length_sgns(centroids)
235277

236278
# add missing nodes from component and compute distance to path
237-
pos = nx.get_node_attributes(unfiltered_graph, 'pos')
238279
node_dict = {}
239-
for c in label_ids:
280+
for num, c in enumerate(label_ids):
240281
min_dist = float('inf')
241282
nearest_node = None
242283

243284
for key in path_dict.keys():
244-
dist = math.dist(pos[c], path_dict[key]["pos"])
285+
dist = math.dist(centroids[num], path_dict[key]["pos"])
245286
if dist < min_dist:
246287
min_dist = dist
247288
nearest_node = key
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import argparse
2+
import json
3+
import os
4+
from typing import Optional
5+
6+
import pandas as pd
7+
from flamingo_tools.s3_utils import get_s3_path
8+
from flamingo_tools.segmentation.cochlea_mapping import equidistant_centers
9+
10+
11+
def repro_equidistant_centers(
12+
ddict: dict,
13+
output_path: str,
14+
s3_credentials: Optional[str] = None,
15+
s3_bucket_name: Optional[str] = None,
16+
s3_service_endpoint: Optional[str] = None,
17+
force_overwrite: Optional[bool] = None,
18+
):
19+
default_cell_type = "ihc"
20+
default_component_list = [1]
21+
default_halo_size = [256, 256, 100]
22+
default_n_blocks = 6
23+
24+
with open(ddict, 'r') as myfile:
25+
data = myfile.read()
26+
param_dicts = json.loads(data)
27+
28+
out_dict = []
29+
30+
if os.path.isfile(output_path) and not force_overwrite:
31+
print(f"Skipping {output_path}. File already exists.")
32+
33+
for dic in param_dicts:
34+
cochlea = dic["cochlea"]
35+
img_channel = dic["image_channel"]
36+
seg_channel = dic["segmentation_channel"]
37+
38+
s3_path = os.path.join(f"{cochlea}", "tables", f"{seg_channel}", "default.tsv")
39+
print(f"Finding equidistant centers for {cochlea}.")
40+
41+
tsv_path, fs = get_s3_path(s3_path, bucket_name=s3_bucket_name,
42+
service_endpoint=s3_service_endpoint, credential_file=s3_credentials)
43+
with fs.open(tsv_path, 'r') as f:
44+
table = pd.read_csv(f, sep="\t")
45+
46+
cell_type = dic["type"] if "type" in dic else default_cell_type
47+
component_list = dic["component_list"] if "component_list" in dic else default_component_list
48+
halo_size = dic["halo_size"] if "halo_size" in dic else default_halo_size
49+
n_blocks = dic["n_blocks"] if "n_blocks" in dic else default_n_blocks
50+
51+
centers = equidistant_centers(table, component_label=component_list, cell_type=cell_type, n_blocks=n_blocks)
52+
centers = [[int(c) for c in center] for center in centers]
53+
ddict = {"cochlea": cochlea}
54+
ddict["image_channel"] = img_channel
55+
ddict["crop_centers"] = centers
56+
ddict["halo_size"] = halo_size
57+
out_dict.append(ddict)
58+
59+
with open(output_path, "w") as f:
60+
json.dump(out_dict, f, indent='\t', separators=(',', ': '))
61+
62+
63+
def main():
64+
parser = argparse.ArgumentParser(
65+
description="Script to extract region of interest (ROI) block around center coordinate.")
66+
67+
parser.add_argument('-i', '--input', type=str, required=True, help="Input JSON dictionary.")
68+
parser.add_argument('-o', "--output", type=str, required=True, help="Output JSON dictionary.")
69+
70+
parser.add_argument("--force", action="store_true", help="Forcefully overwrite output.")
71+
parser.add_argument("--s3_credentials", type=str, default=None,
72+
help="Input file containing S3 credentials. "
73+
"Optional if AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY were exported.")
74+
parser.add_argument("--s3_bucket_name", type=str, default=None,
75+
help="S3 bucket name. Optional if BUCKET_NAME was exported.")
76+
parser.add_argument("--s3_service_endpoint", type=str, default=None,
77+
help="S3 service endpoint. Optional if SERVICE_ENDPOINT was exported.")
78+
79+
args = parser.parse_args()
80+
81+
repro_equidistant_centers(
82+
args.input, args.output,
83+
args.s3_credentials, args.s3_bucket_name, args.s3_service_endpoint,
84+
args.force,
85+
)
86+
87+
88+
if __name__ == "__main__":
89+
90+
main()

reproducibility/tonotopic_mapping/repro_tonotopic_mapping.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ def repro_tonotopic_mapping(
1717
force_overwrite: Optional[bool] = None,
1818
):
1919
default_cell_type = "ihc"
20-
default_max_edge_distance = 30
21-
default_min_length = 50
2220
default_component_list = [1]
2321

2422
remove_columns = ["tonotopic_label",
@@ -49,17 +47,14 @@ def repro_tonotopic_mapping(
4947
table = pd.read_csv(f, sep="\t")
5048

5149
cell_type = dic["type"] if "type" in dic else default_cell_type
52-
max_edge_distance = dic["max_edge_distance"] if "max_edge_distance" in dic else default_max_edge_distance
53-
min_component_length = dic["min_component_length"] if "min_component_length" in dic else default_min_length
5450
component_list = dic["component_list"] if "component_list" in dic else default_component_list
5551

5652
for column in remove_columns:
5753
if column in list(table.columns):
5854
table = table.drop(column, axis=1)
5955

6056
if not os.path.isfile(output_table_path) or force_overwrite:
61-
table = tonotopic_mapping(table, component_label=component_list, max_edge_distance=max_edge_distance,
62-
min_component_length=min_component_length, cell_type=cell_type)
57+
table = tonotopic_mapping(table, component_label=component_list, cell_type=cell_type)
6358

6459
table.to_csv(output_table_path, sep="\t", index=False)
6560

scripts/prediction/tonotopic_mapping.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,9 @@ def main():
1313
"Either locally or on an S3 bucket.")
1414

1515
parser.add_argument("-i", "--input", required=True, help="Input table with IHC segmentation.")
16-
parser.add_argument("-o", "--output", required=True, help="Output path for post-processed table.")
16+
parser.add_argument("-o", "--output", required=True, help="Output path for json file with cropping parameters.")
1717

18-
parser.add_argument("-t", "--type", type=str, default="ihc", help="Cell type of segmentation.")
19-
parser.add_argument("--edge_distance", type=float, default=30, help="Maximal edge distance between nodes.")
20-
parser.add_argument("--component_length", type=int, default=50, help="Minimal number of nodes in component.")
18+
parser.add_argument("-t", "--type", type=str, default="sgn", help="Cell type of segmentation.")
2119

2220
parser.add_argument("--s3", action="store_true", help="Flag for using S3 bucket.")
2321
parser.add_argument("--s3_credentials", type=str, default=None,
@@ -41,8 +39,7 @@ def main():
4139
tsv_table = pd.read_csv(f, sep="\t")
4240

4341
table = tonotopic_mapping(
44-
tsv_table, max_edge_distance=args.edge_distance, min_component_length=args.component_length,
45-
cell_type=args.type,
42+
tsv_table, cell_type=args.type,
4643
)
4744

4845
table.to_csv(args.output, sep="\t", index=False)

0 commit comments

Comments
 (0)