
Commit b03b42e

Merge pull request #37 from computational-cell-analytics/synapse-analysis
Add code for synapse model export and for initial synapse analysis
2 parents: 16a6a94 + 9d6344f

19 files changed: +1084 -154 lines

flamingo_tools/validation.py

Lines changed: 74 additions & 5 deletions
@@ -43,6 +43,8 @@ def fetch_data_for_evaluation(
     seg_name: str = "SGN_v2",
     z_extent: int = 0,
     components_for_postprocessing: Optional[List[int]] = None,
+    cochlea: Optional[str] = None,
+    extra_data: Optional[str] = None,
 ) -> Tuple[np.ndarray, pd.DataFrame]:
     """Fetch segmentation from S3 matching the annotation path for evaluation.
@@ -53,28 +55,31 @@ def fetch_data_for_evaluation(
         z_extent: Additional z-slices to load from the segmentation.
         components_for_postprocessing: The component ids for restricting the segmentation to.
             Choose [1] for the default component containing the helix.
+        cochlea: Optional name of the cochlea.
+        extra_data: Extra data to fetch.
 
     Returns:
         The segmentation downloaded from the S3 bucket.
         The annotations loaded from pandas and matching the segmentation.
     """
     # Load the annotations and normalize them for the given z-extent.
     annotations = pd.read_csv(annotation_path)
-    annotations = annotations.drop(columns="index")
+    if "index" in annotations.columns:
+        annotations = annotations.drop(columns="index")
     if z_extent == 0:  # If we don't have a z-extent then we just drop the first axis and rename the other two.
         annotations = annotations.drop(columns="axis-0")
         annotations = annotations.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
-    else:  # Otherwise we have to center the first axis.
-        # TODO
-        raise NotImplementedError
 
     # Load the segmentation from the cache path if it is given and if it is already cached.
     if cache_path is not None and os.path.exists(cache_path):
         segmentation = imageio.imread(cache_path)
         return segmentation, annotations
 
     # Parse the slice ID and the cochlea from the annotation file name.
-    cochlea, slice_id = _parse_annotation_path(annotation_path)
+    if cochlea is None:
+        cochlea, slice_id = _parse_annotation_path(annotation_path)
+    else:
+        _, slice_id = _parse_annotation_path(annotation_path)
 
     # Open the S3 connection, get the path to the SGN segmentation in S3.
     internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{seg_name}.ome.zarr")
@@ -111,6 +116,14 @@ def fetch_data_for_evaluation(
     if cache_path is not None:
         imageio.imwrite(cache_path, segmentation, compression="zlib")
 
+    if extra_data is not None:
+        internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{extra_data}.ome.zarr")
+        s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
+        input_key = "s0"
+        with zarr.open(s3_store, mode="r") as f:
+            extra_im_data = f[input_key][roi]
+        return segmentation, annotations, extra_im_data
+
     return segmentation, annotations
 
 
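For orientation, a minimal usage sketch of the extended function; the annotation path and the cochlea and channel names below are hypothetical, and the third return value only exists when extra_data is passed:

    from flamingo_tools.validation import fetch_data_for_evaluation

    # Hypothetical inputs for illustration.
    segmentation, annotations, extra_im_data = fetch_data_for_evaluation(
        "annotations/some_cochlea_slice42.csv",  # hypothetical annotation path
        cochlea="M_LR_000042_L",  # hypothetical name; overrides parsing the cochlea from the path
        extra_data="PV",          # fetches PV.ome.zarr over the same ROI as the segmentation
    )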
@@ -347,6 +360,62 @@ def union(a, b):
     return consensus_df, unmatched_df
 
 
+def match_detections(
+    detections: np.ndarray,
+    annotations: np.ndarray,
+    max_dist: float
+):
+    """One-to-one matching between 3-D detections and ground-truth points.
+
+    Args:
+        detections: N x 3 candidate detections.
+        annotations: M x 3 ground-truth annotations for the reference points.
+        max_dist: Maximum Euclidean distance allowed for a match.
+
+    Returns:
+        Indices in `detections` that were matched (true positives).
+        Indices in `annotations` that were matched (true positives).
+        Unmatched detection indices (false positives).
+        Unmatched annotation indices (false negatives).
+    """
+    det = np.asarray(detections, dtype=float)
+    ann = np.asarray(annotations, dtype=float)
+    N, M = len(det), len(ann)
+
+    # Trivial corner cases.
+    if N == 0:
+        return np.empty(0, int), np.empty(0, int), np.empty(0, int), np.arange(M)
+    if M == 0:
+        return np.empty(0, int), np.empty(0, int), np.arange(N), np.empty(0, int)
+
+    # 1. Build the sparse radius-filtered distance matrix.
+    tree_det = cKDTree(det)
+    tree_ann = cKDTree(ann)
+    coo = tree_det.sparse_distance_matrix(tree_ann, max_dist, output_type="coo_matrix")
+
+    if coo.nnz == 0:  # Nothing is close enough.
+        return np.empty(0, int), np.empty(0, int), np.arange(N), np.arange(M)
+
+    cost = np.full((N, M), 5 * max_dist, dtype=float)
+    cost[coo.row, coo.col] = coo.data  # Fill only the existing edges.
+
+    # 2. Optimal one-to-one assignment (Hungarian algorithm).
+    row_ind, col_ind = linear_sum_assignment(cost)
+
+    # Discard assignments that only exist because missing edges were padded
+    # with the large 5 * max_dist cost above.
+    valid_mask = cost[row_ind, col_ind] <= max_dist
+    tp_det_ids = row_ind[valid_mask]
+    tp_ann_ids = col_ind[valid_mask]
+    assert len(tp_det_ids) == len(tp_ann_ids)
+
+    # 3. Derive FP / FN.
+    fp_det_ids = np.setdiff1d(np.arange(N), tp_det_ids, assume_unique=True)
+    fn_ann_ids = np.setdiff1d(np.arange(M), tp_ann_ids, assume_unique=True)
+
+    return tp_det_ids, tp_ann_ids, fp_det_ids, fn_ann_ids
+
+
 def for_visualization(segmentation, annotations, matches):
     green_red = ["#00FF00", "#FF0000"]
 
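The returned index arrays translate directly into detection metrics. A self-contained sketch (the points are synthetic, chosen so that two of the three detections match within max_dist):

    import numpy as np
    from flamingo_tools.validation import match_detections

    detections = np.array([[10.0, 10.0, 10.0], [50.0, 50.0, 50.0], [90.0, 90.0, 90.0]])
    annotations = np.array([[11.0, 10.0, 10.0], [52.0, 50.0, 49.0]])

    tp_det, tp_ann, fp_det, fn_ann = match_detections(detections, annotations, max_dist=5.0)

    precision = len(tp_det) / (len(tp_det) + len(fp_det))  # 2 / 3
    recall = len(tp_ann) / (len(tp_ann) + len(fn_ann))     # 2 / 2
    print(f"precision={precision:.2f}, recall={recall:.2f}")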
scripts/export_lower_resolution.py

Lines changed: 10 additions & 7 deletions
@@ -7,22 +7,22 @@
 import zarr
 
 from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT
-from skimage.segmentation import relabel_sequential
+# from skimage.segmentation import relabel_sequential
 
 
-def filter_component(fs, segmentation, cochlea, seg_name):
+def filter_component(fs, segmentation, cochlea, seg_name, components):
     # First, we download the MoBIE table for this segmentation.
     internal_path = os.path.join(BUCKET_NAME, cochlea, "tables", seg_name, "default.tsv")
     with fs.open(internal_path, "r") as f:
         table = pd.read_csv(f, sep="\t")
 
     # Then we get the ids for the components and use them to filter the segmentation.
-    component_mask = np.isin(table.component_labels.values, [1])
+    component_mask = np.isin(table.component_labels.values, components)
     keep_label_ids = table.label_id.values[component_mask].astype("int64")
     filter_mask = ~np.isin(segmentation, keep_label_ids)
     segmentation[filter_mask] = 0
 
-    segmentation, _, _ = relabel_sequential(segmentation)
+    # segmentation, _, _ = relabel_sequential(segmentation)
     return segmentation
 
 
@@ -42,8 +42,10 @@ def export_lower_resolution(args):
     with zarr.open(s3_store, mode="r") as f:
         data = f[input_key][:]
     print(data.shape)
-    if args.filter_by_component:
-        data = filter_component(fs, data, args.cochlea, channel)
+    if args.filter_by_components is not None:
+        data = filter_component(fs, data, args.cochlea, channel, args.filter_by_components)
+    if args.binarize:
+        data = (data > 0).astype("uint8")
     tifffile.imwrite(out_path, data, bigtiff=True, compression="zlib")
 
 
@@ -53,7 +55,8 @@ def main():
     parser.add_argument("--scale", "-s", type=int, required=True)
     parser.add_argument("--output_folder", "-o", required=True)
     parser.add_argument("--channels", nargs="+", default=["PV", "VGlut3", "CTBP2"])
-    parser.add_argument("--filter_by_component", action="store_true")
+    parser.add_argument("--filter_by_components", nargs="+", type=int, default=None)
+    parser.add_argument("--binarize", action="store_true")
     args = parser.parse_args()
 
     export_lower_resolution(args)
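With the new flags, an invocation might look like this; the cochlea name, scale, and channel are illustrative, and the --cochlea argument is assumed from the args.cochlea usage above. --filter_by_components keeps only the listed component ids, and --binarize writes a binary uint8 mask:

    python scripts/export_lower_resolution.py \
        --cochlea M_LR_000144_L --scale 3 --output_folder ./exports \
        --channels SGN_v2 --filter_by_components 1 2 --binarize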
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
+import os
+import json
+
+import pandas as pd
+
+from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
+
+OUTPUT_FOLDER = "./results/otof-measurements"
+
+
+def check_project(save=False):
+    s3 = create_s3_target()
+
+    # content = s3.open(f"{BUCKET_NAME}/project.json", mode="r", encoding="utf-8")
+    # x = json.loads(content.read())
+    # print(x)
+    # return
+
+    cochleae = ["M_AMD_OTOF1_L", "M_AMD_OTOF2_L"]
+    ihc_name = "IHC_v2"
+
+    for cochlea in cochleae:
+        content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        source_names = list(sources.keys())
+        assert ihc_name in source_names
+
+        # Get the IHC table folder.
+        ihc = sources[ihc_name]["segmentation"]
+        table_folder = os.path.join(BUCKET_NAME, cochlea, ihc["tableData"]["tsv"]["relativePath"])
+
+        # For debugging.
+        # print(s3.ls(table_folder))
+
+        default_table = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
+        default_table = pd.read_csv(default_table, sep="\t")
+
+        measurement_table = s3.open(os.path.join(table_folder, "Apha_IHC-v2_object-measures.tsv"), mode="rb")
+        measurement_table = pd.read_csv(measurement_table, sep="\t")
+        if save:
+            os.makedirs(OUTPUT_FOLDER, exist_ok=True)
+            measurement_table.to_csv(os.path.join(OUTPUT_FOLDER, f"{cochlea}.csv"), index=False)
+
+        print("Cochlea:", cochlea)
+        print("AlphaTag measurements for:", len(measurement_table), "IHCs:")
+        print(measurement_table.columns)
+        print()
+
+
+def plot_distribution():
+    import seaborn as sns
+    import matplotlib.pyplot as plt
+
+    table1 = "./results/otof-measurements/M_AMD_OTOF1_L.csv"
+    table2 = "./results/otof-measurements/M_AMD_OTOF2_L.csv"
+
+    table1 = pd.read_csv(table1)
+    table2 = pd.read_csv(table2)
+
+    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
+    sns.histplot(data=table1, x="mean", bins=32, ax=axes[0])
+    axes[0].set_title("Dual AAV")
+    sns.histplot(data=table2, x="mean", bins=32, ax=axes[1])
+    axes[1].set_title("Overloaded AAV")
+
+    fig.suptitle("OTOF Gene Therapy - Mean AlphaTag Intensity of IHCs")
+    plt.tight_layout()
+
+    plt.show()
+
+
+def main():
+    # check_project(save=True)
+    plot_distribution()
+
+
+if __name__ == "__main__":
+    main()
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+import os
+import json
+
+import pandas as pd
+
+from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
+
+OUTPUT_FOLDER = "./results/sgn-measurements"
+
+
+def check_project(save=False):
+    s3 = create_s3_target()
+
+    content = s3.open(f"{BUCKET_NAME}/project.json", mode="r", encoding="utf-8")
+    project_info = json.loads(content.read())
+
+    cochleae = [
+        "M_LR_000144_L", "M_LR_000145_L", "M_LR_000151_R", "M_LR_000155_L",
+    ]
+
+    sgn_name = "SGN_resized_v2"
+    for cochlea in cochleae:
+        assert cochlea in project_info["datasets"]
+
+        content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        source_names = list(sources.keys())
+        if sgn_name not in source_names:
+            continue
+
+        # Get the SGN table folder.
+        sgn = sources[sgn_name]["segmentation"]
+        table_folder = os.path.join(BUCKET_NAME, cochlea, sgn["tableData"]["tsv"]["relativePath"])
+
+        # Skip cochleae for which only the default table exists.
+        x = s3.ls(table_folder)
+        if len(x) == 1:
+            continue
+
+        default_table = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
+        default_table = pd.read_csv(default_table, sep="\t")
+        main_ids = default_table[default_table.component_labels == 1].label_id
+
+        measurement_table = s3.open(
+            os.path.join(table_folder, "GFP-resized_SGN-resized-v2_object-measures.tsv"), mode="rb"
+        )
+        measurement_table = pd.read_csv(measurement_table, sep="\t")
+        measurement_table = measurement_table[measurement_table.label_id.isin(main_ids)]
+        assert len(measurement_table) == len(main_ids)
+
+        if save:
+            os.makedirs(OUTPUT_FOLDER, exist_ok=True)
+            measurement_table.to_csv(os.path.join(OUTPUT_FOLDER, f"{cochlea}.csv"), index=False)
+
+        print("Cochlea:", cochlea)
+        print("GFP measurements for:", len(measurement_table), "SGNs:")
+        print(measurement_table.columns)
+        print()
+
+
+def plot_distribution():
+    import seaborn as sns
+    import matplotlib.pyplot as plt
+
+    table1 = "./results/sgn-measurements/M_LR_000145_L.csv"
+    table2 = "./results/sgn-measurements/M_LR_000151_R.csv"
+    table3 = "./results/sgn-measurements/M_LR_000155_L.csv"
+
+    table1 = pd.read_csv(table1)
+    table2 = pd.read_csv(table2)
+    table3 = pd.read_csv(table3)
+
+    print(len(table1))
+    print(len(table3))
+
+    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
+
+    sns.histplot(data=table1, x="mean", bins=32, ax=axes[0])
+    axes[0].set_title("M145_L")
+    # Something is wrong here, the values are normalized.
+    # sns.histplot(data=table2, x="mean", bins=32, ax=axes[1])
+    # axes[1].set_title("M151_R")
+    sns.histplot(data=table3, x="mean", bins=32, ax=axes[1])
+    axes[1].set_title("M155_L")
+
+    fig.suptitle("SGN Gene Therapy - Mean GFP Intensity of SGNs")
+    plt.tight_layout()
+
+    plt.show()
+
+
+def main():
+    # check_project(save=True)
+    plot_distribution()
+
+
+if __name__ == "__main__":
+    main()
