Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions flamingo_tools/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def fetch_data_for_evaluation(
seg_name: str = "SGN_v2",
z_extent: int = 0,
components_for_postprocessing: Optional[List[int]] = None,
cochlea: Optional[str] = None,
extra_data: Optional[str] = None,
) -> Tuple[np.ndarray, pd.DataFrame]:
"""Fetch segmentation from S3 matching the annotation path for evaluation.

Expand All @@ -53,28 +55,31 @@ def fetch_data_for_evaluation(
z_extent: Additional z-slices to load from the segmentation.
components_for_postprocessing: The component ids for restricting the segmentation to.
Choose [1] for the default componentn containing the helix.
cochlea: Optional name of the cochlea.
extra_data: Extra data to fetch.

Returns:
The segmentation downloaded from the S3 bucket.
The annotations loaded from pandas and matching the segmentation.
"""
# Load the annotations and normalize them for the given z-extent.
annotations = pd.read_csv(annotation_path)
annotations = annotations.drop(columns="index")
if "index" in annotations.columns:
annotations = annotations.drop(columns="index")
if z_extent == 0: # If we don't have a z-extent then we just drop the first axis and rename the other two.
annotations = annotations.drop(columns="axis-0")
annotations = annotations.rename(columns={"axis-1": "axis-0", "axis-2": "axis-1"})
else: # Otherwise we have to center the first axis.
# TODO
raise NotImplementedError

# Load the segmentaiton from cache path if it is given and if it is already cached.
if cache_path is not None and os.path.exists(cache_path):
segmentation = imageio.imread(cache_path)
return segmentation, annotations

# Parse which ID and which cochlea from the name.
cochlea, slice_id = _parse_annotation_path(annotation_path)
if cochlea is None:
cochlea, slice_id = _parse_annotation_path(annotation_path)
else:
_, slice_id = _parse_annotation_path(annotation_path)

# Open the S3 connection, get the path to the SGN segmentation in S3.
internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{seg_name}.ome.zarr")
Expand Down Expand Up @@ -111,6 +116,14 @@ def fetch_data_for_evaluation(
if cache_path is not None:
imageio.imwrite(cache_path, segmentation, compression="zlib")

if extra_data is not None:
internal_path = os.path.join(cochlea, "images", "ome-zarr", f"{extra_data}.ome.zarr")
s3_store, fs = get_s3_path(internal_path, bucket_name=BUCKET_NAME, service_endpoint=SERVICE_ENDPOINT)
input_key = "s0"
with zarr.open(s3_store, mode="r") as f:
extra_im_data = f[input_key][roi]
return segmentation, annotations, extra_im_data

return segmentation, annotations


Expand Down Expand Up @@ -347,6 +360,62 @@ def union(a, b):
return consensus_df, unmatched_df


def match_detections(
detections: np.ndarray,
annotations: np.ndarray,
max_dist: float
):
"""One-to-one matching between 3-D detections and ground-truth points.

Args:
detections: N x 3 candidate detections.
annotations: M x 3 ground-truth annotations for the reference points.
max_dist: Maximum Euclidean distance allowed for a match.

Returns:
Indices in `detections` that were matched (true positives).
Indices in `annotations` that were matched (true positives).
Unmatched detection indices (false positives).
Unmatched annotation indices (false negatives).
"""
det = np.asarray(detections, dtype=float)
ann = np.asarray(annotations, dtype=float)
N, M = len(det), len(ann)

# trivial corner cases --------------------------------------------------------
if N == 0:
return np.empty(0, int), np.empty(0, int), np.empty(0, int), np.arange(M)
if M == 0:
return np.empty(0, int), np.empty(0, int), np.arange(N), np.empty(0, int)

# 1. build sparse radius-filtered distance matrix -----------------------------
tree_det = cKDTree(det)
tree_ann = cKDTree(ann)
coo = tree_det.sparse_distance_matrix(tree_ann, max_dist, output_type="coo_matrix")

if coo.nnz == 0: # nothing is close enough
return np.empty(0, int), np.empty(0, int), np.arange(N), np.arange(M)

cost = np.full((N, M), 5 * max_dist, dtype=float)
cost[coo.row, coo.col] = coo.data # fill only existing edges

# 2. optimal one-to-one assignment (Hungarian) --------------------------------
row_ind, col_ind = linear_sum_assignment(cost)

# Filter assignments that were padded with +∞ cost for non-existent edges
# (linear_sum_assignment automatically does that padding internally).
valid_mask = cost[row_ind, col_ind] <= max_dist
tp_det_ids = row_ind[valid_mask]
tp_ann_ids = col_ind[valid_mask]
assert len(tp_det_ids) == len(tp_ann_ids)

# 3. derive FP / FN -----------------------------------------------------------
fp_det_ids = np.setdiff1d(np.arange(N), tp_det_ids, assume_unique=True)
fn_ann_ids = np.setdiff1d(np.arange(M), tp_ann_ids, assume_unique=True)

return tp_det_ids, tp_ann_ids, fp_det_ids, fn_ann_ids


def for_visualization(segmentation, annotations, matches):
green_red = ["#00FF00", "#FF0000"]

Expand Down
17 changes: 10 additions & 7 deletions scripts/export_lower_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@
import zarr

from flamingo_tools.s3_utils import get_s3_path, BUCKET_NAME, SERVICE_ENDPOINT
from skimage.segmentation import relabel_sequential
# from skimage.segmentation import relabel_sequential


def filter_component(fs, segmentation, cochlea, seg_name):
def filter_component(fs, segmentation, cochlea, seg_name, components):
# First, we download the MoBIE table for this segmentation.
internal_path = os.path.join(BUCKET_NAME, cochlea, "tables", seg_name, "default.tsv")
with fs.open(internal_path, "r") as f:
table = pd.read_csv(f, sep="\t")

# Then we get the ids for the components and us them to filter the segmentation.
component_mask = np.isin(table.component_labels.values, [1])
component_mask = np.isin(table.component_labels.values, components)
keep_label_ids = table.label_id.values[component_mask].astype("int64")
filter_mask = ~np.isin(segmentation, keep_label_ids)
segmentation[filter_mask] = 0

segmentation, _, _ = relabel_sequential(segmentation)
# segmentation, _, _ = relabel_sequential(segmentation)
return segmentation


Expand All @@ -42,8 +42,10 @@ def export_lower_resolution(args):
with zarr.open(s3_store, mode="r") as f:
data = f[input_key][:]
print(data.shape)
if args.filter_by_component:
data = filter_component(fs, data, args.cochlea, channel)
if args.filter_by_components is not None:
data = filter_component(fs, data, args.cochlea, channel, args.filter_by_components)
if args.binarize:
data = (data > 0).astype("uint8")
tifffile.imwrite(out_path, data, bigtiff=True, compression="zlib")


Expand All @@ -53,7 +55,8 @@ def main():
parser.add_argument("--scale", "-s", type=int, required=True)
parser.add_argument("--output_folder", "-o", required=True)
parser.add_argument("--channels", nargs="+", default=["PV", "VGlut3", "CTBP2"])
parser.add_argument("--filter_by_component", action="store_true")
parser.add_argument("--filter_by_components", nargs="+", type=int, default=None)
parser.add_argument("--binarize", action="store_true")
args = parser.parse_args()

export_lower_resolution(args)
Expand Down
80 changes: 80 additions & 0 deletions scripts/measurements/evaluate_otof_therapy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
import json

import pandas as pd

from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target

OUTPUT_FOLDER = "./results/otof-measurements"


def check_project(save=False):
s3 = create_s3_target()

# content = s3.open(f"{BUCKET_NAME}/project.json", mode="r", encoding="utf-8")
# x = json.loads(content.read())
# print(x)
# return

cochleae = ["M_AMD_OTOF1_L", "M_AMD_OTOF2_L"]
ihc_name = "IHC_v2"

for cochlea in cochleae:
content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
info = json.loads(content.read())
sources = info["sources"]

source_names = list(sources.keys())
assert ihc_name in source_names

# Get the ihc table folder.
ihc = sources[ihc_name]["segmentation"]
table_folder = os.path.join(BUCKET_NAME, cochlea, ihc["tableData"]["tsv"]["relativePath"])

# For debugging.
# print(s3.ls(table_folder))

default_table = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
default_table = pd.read_csv(default_table, sep="\t")

measurement_table = s3.open(os.path.join(table_folder, "Apha_IHC-v2_object-measures.tsv"), mode="rb")
measurement_table = pd.read_csv(measurement_table, sep="\t")
if save:
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
measurement_table.to_csv(os.path.join(OUTPUT_FOLDER, f"{cochlea}.csv"), index=False)

print("Cochlea:", cochlea)
print("AlphaTag measurements for:", len(measurement_table), "IHCs:")
print(measurement_table.columns)
print()


def plot_distribution():
import seaborn as sns
import matplotlib.pyplot as plt

table1 = "./results/otof-measurements/M_AMD_OTOF1_L.csv"
table2 = "./results/otof-measurements/M_AMD_OTOF2_L.csv"

table1 = pd.read_csv(table1)
table2 = pd.read_csv(table2)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
sns.histplot(data=table1, x="mean", bins=32, ax=axes[0])
axes[0].set_title("Dual AAV")
sns.histplot(data=table2, x="mean", bins=32, ax=axes[1])
axes[1].set_title("Overloaded AAV")

fig.suptitle("OTOF Gene Therapy - Mean AlphaTag Intensity of IHCs")
plt.tight_layout()

plt.show()


def main():
# check_project(save=True)
plot_distribution()


if __name__ == "__main__":
main()
100 changes: 100 additions & 0 deletions scripts/measurements/evaluate_sgn_therapy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import os
import json

import pandas as pd

from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target

OUTPUT_FOLDER = "./results/sgn-measurements"


def check_project(save=False):
s3 = create_s3_target()

content = s3.open(f"{BUCKET_NAME}/project.json", mode="r", encoding="utf-8")
project_info = json.loads(content.read())

cochleae = [
"M_LR_000144_L", "M_LR_000145_L", "M_LR_000151_R", "M_LR_000155_L",
]

sgn_name = "SGN_resized_v2"
for cochlea in cochleae:
assert cochlea in project_info["datasets"]

content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
info = json.loads(content.read())
sources = info["sources"]

source_names = list(sources.keys())
if sgn_name not in source_names:
continue

# Get the ihc table folder.
sgn = sources[sgn_name]["segmentation"]
table_folder = os.path.join(BUCKET_NAME, cochlea, sgn["tableData"]["tsv"]["relativePath"])

# For debugging.
x = s3.ls(table_folder)
if len(x) == 1:
continue

default_table = s3.open(os.path.join(table_folder, "default.tsv"), mode="rb")
default_table = pd.read_csv(default_table, sep="\t")
main_ids = default_table[default_table.component_labels == 1].label_id

measurement_table = s3.open(
os.path.join(table_folder, "GFP-resized_SGN-resized-v2_object-measures.tsv"), mode="rb"
)
measurement_table = pd.read_csv(measurement_table, sep="\t")
measurement_table = measurement_table[measurement_table.label_id.isin(main_ids)]
assert len(measurement_table) == len(main_ids)

if save:
os.makedirs(OUTPUT_FOLDER, exist_ok=True)
measurement_table.to_csv(os.path.join(OUTPUT_FOLDER, f"{cochlea}.csv"), index=False)

print("Cochlea:", cochlea)
print("GFP measurements for:", len(measurement_table), "SGNs:")
print(measurement_table.columns)
print()


def plot_distribution():
import seaborn as sns
import matplotlib.pyplot as plt

table1 = "./results/sgn-measurements/M_LR_000145_L.csv"
table2 = "./results/sgn-measurements/M_LR_000151_R.csv"
table3 = "./results/sgn-measurements/M_LR_000155_L.csv"

table1 = pd.read_csv(table1)
table2 = pd.read_csv(table2)
table3 = pd.read_csv(table3)

print(len(table1))
print(len(table3))

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

sns.histplot(data=table1, x="mean", bins=32, ax=axes[0])
axes[0].set_title("M145_L")
# Something is wrong here, the values are normalized.
# sns.histplot(data=table2, x="mean", bins=32, ax=axes[1])
# axes[1].set_title("M151_R")
sns.histplot(data=table3, x="mean", bins=32, ax=axes[1])
axes[1].set_title("M155_L")

fig.suptitle("SGN Gene Therapy - Mean GFP Intensity of SGNs")
plt.tight_layout()

plt.show()


def main():
# check_project(save=True)
plot_distribution()


if __name__ == "__main__":
main()
Loading
Loading