diff --git a/flamingo_tools/segmentation/chreef_utils.py b/flamingo_tools/segmentation/chreef_utils.py index 26fb8e4..427cde9 100644 --- a/flamingo_tools/segmentation/chreef_utils.py +++ b/flamingo_tools/segmentation/chreef_utils.py @@ -12,12 +12,17 @@ def coord_from_string(center_str): return tuple([int(c) for c in center_str.split("-")]) -def find_annotations(annotation_dir, cochlea) -> dict: - """Create dictionary for analysis of ChReef annotations. +def find_annotations(annotation_dir: str, cochlea: str) -> dict: + """Create a dictionary for the analysis of ChReef annotations. + Annotations should have format positive-negative__crop__allNegativeExcluded_thr.tif Args: annotation_dir: Directory containing annotations. + cochlea: The name of the cochlea to analyze. + + Returns: + Dictionary with information about the intensity annotations. """ def extract_center_string(cochlea, name): @@ -58,7 +63,7 @@ def get_roi(coord: tuple, roi_halo: tuple, resolution: float = 0.38) -> Tuple[in resolution: Resolution of array in µm. Returns: - region of interest + The region of interest. """ coords = list(coord) # reverse dimensions for correct extraction @@ -123,7 +128,10 @@ def find_inbetween_ids( Args: arr_negexc: Array with all negatives excluded. arr_allweak: Array with all weak positives. - roi_sgn: Region of interest of segmentation. + roi_seg: Region of interest of segmentation. + + Returns: + A list of the ids that are in between the respective thresholds. """ # negative annotation == 1, positive annotation == 2 negexc_negatives = find_overlapping_masks(arr_negexc, roi_seg, label_id_base=1) @@ -141,8 +149,12 @@ def get_median_intensity(file_negexc, file_allweak, center, data_seg, table): roi_seg = data_seg[roi] inbetween_ids = find_inbetween_ids(arr_negexc, arr_allweak, roi_seg) + if len(inbetween_ids) == 0: + return None + subset = table[table["label_id"].isin(inbetween_ids)] intensities = list(subset["median"]) + return np.median(list(intensities)) @@ -154,11 +166,14 @@ def localize_median_intensities(annotation_dir, cochlea, data_seg, table_measure for center_str in annotation_dic["center_strings"]: center_coord = coord_from_string(center_str) - print(f"Getting mean intensities for {center_coord}.") + print(f"Getting median intensities for {center_coord}.") file_pos = annotation_dic[center_str]["file_pos"] file_neg = annotation_dic[center_str]["file_neg"] median_intensity = get_median_intensity(file_neg, file_pos, center_coord, data_seg, table_measure) + if median_intensity is None: + print(f"No inbetween IDs found for {center_str}.") + annotation_dic[center_str]["median_intensity"] = median_intensity return annotation_dic diff --git a/scripts/figures/chreef_analysis.py b/scripts/figures/chreef_analysis.py new file mode 100644 index 0000000..530d300 --- /dev/null +++ b/scripts/figures/chreef_analysis.py @@ -0,0 +1,87 @@ +import json +import os +import pickle + +# import matplotlib.pyplot as plt +# import numpy as np +import pandas as pd +# import tifffile +# import zarr +# from matplotlib import cm, colors + +from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target + +INTENSITY_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/measurements" # noqa +# The cochlea for the CHReef analysis. 
+COCHLEAE = [
+    "M_LR_000143_L",
+    "M_LR_000144_L",
+    "M_LR_000145_L",
+    "M_LR_000153_L",
+    "M_LR_000155_L",
+    "M_LR_000189_L",
+    "M_LR_000143_R",
+    "M_LR_000144_R",
+    "M_LR_000145_R",
+    "M_LR_000153_R",
+    "M_LR_000155_R",
+    "M_LR_000189_R",
+]
+
+
+def download_data():
+    s3 = create_s3_target()
+    source_name = "SGN_v2"
+
+    cache_path = "./chreef_data.pkl"
+    if os.path.exists(cache_path):
+        with open(cache_path, "rb") as f:
+            return pickle.load(f)
+
+    chreef_data = {}
+    for cochlea in COCHLEAE:
+        print("Processing cochlea:", cochlea)
+        content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        # Load the seg table and filter the compartments.
+        source = sources[source_name]["segmentation"]
+        rel_path = source["tableData"]["tsv"]["relativePath"]
+        table_content = s3.open(os.path.join(BUCKET_NAME, cochlea, rel_path, "default.tsv"), mode="rb")
+        table = pd.read_csv(table_content, sep="\t")
+
+        # May need to be adjusted for some cochleae.
+        table = table[table.component_labels == 1]
+        # The relevant values for analysis.
+        try:
+            values = table[["label_id", "length[µm]", "frequency[kHz]", "marker_labels"]]
+        except KeyError:
+            print("Could not find the values for", cochlea, "it will be skipped.")
+            continue
+
+        fname = f"{cochlea.replace('_', '-')}_GFP_SGN-v2_object-measures.tsv"
+        intensity_file = os.path.join(INTENSITY_ROOT, fname)
+        assert os.path.exists(intensity_file), intensity_file
+        intensity_table = pd.read_csv(intensity_file, sep="\t")
+        values = values.merge(intensity_table, on="label_id")
+
+        chreef_data[cochlea] = values
+
+    with open(cache_path, "wb") as f:
+        pickle.dump(chreef_data, f)
+    return chreef_data
+
+
+def analyze_transduction(chreef_data):
+    breakpoint()
+    pass
+
+
+def main():
+    chreef_data = download_data()
+    analyze_transduction(chreef_data)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/figures/plot_fig2.py b/scripts/figures/plot_fig2.py
new file mode 100644
index 0000000..ea51dec
--- /dev/null
+++ b/scripts/figures/plot_fig2.py
@@ -0,0 +1,314 @@
+import argparse
+import os
+from glob import glob
+
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+from util import literature_reference_values
+
+png_dpi = 300
+
+
+def fig_02c(save_path, plot=False, all_versions=False):
+    """Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net vs. manual),
+    IHC (distance U-Net vs. manual), and synapse detection (U-Net).
+ """ + # precision, recall, f1-score + sgn_unet = [0.887, 0.88, 0.884] + sgn_annotator = [0.95, 0.849, 0.9] + + ihc_v4b = [0.91, 0.819, 0.862] + ihc_v4c = [0.905, 0.831, 0.866] + ihc_v4c_filter = [0.919, 0.775, 0.841] + + ihc_annotator = [0.958, 0.956, 0.957] + syn_unet = [0.931, 0.905, 0.918] + + # This is the version with IHC v4b segmentation: + # 4th version of the network with optimized segmentation params + version_1 = [sgn_unet, sgn_annotator, ihc_v4b, ihc_annotator, syn_unet] + settings_1 = ["automatic", "manual", "automatic", "manual", "automatic"] + + # This is the version with IHC v4c segmentation: + # 4th version of the network with optimized segmentation params and split of falsely merged IHCs + version_2 = [sgn_unet, sgn_annotator, ihc_v4c, ihc_annotator, syn_unet] + settings_2 = ["automatic", "manual", "automatic", "manual", "automatic"] + + # This is the version with IHC v4c + filter segmentation: + # 4th version of the network with optimized segmentation params and split of falsely merged IHCs + # + filtering out IHCs with zero mapped synapses. + version_3 = [sgn_unet, sgn_annotator, ihc_v4c_filter, ihc_annotator, syn_unet] + settings_3 = ["automatic", "manual", "automatic", "manual", "automatic"] + + if all_versions: + versions = [version_1, version_2, version_3] + settings = [settings_1, settings_2, settings_3] + save_suffix = ["_v4b", "_v4c", "_v4c_filter"] + save_paths = [save_path + i for i in save_suffix] + else: + versions = [version_2] + settings = [settings_2] + save_suffix = ["_v4c"] + save_paths = [save_path + i for i in save_suffix] + + for version, setting, save_path in zip(versions, settings, save_paths): + precision = [i[0] for i in version] + recall = [i[1] for i in version] + f1score = [i[2] for i in version] + + descr_y = 0.72 + + # Convert setting labels to numerical x positions + x = np.array([0.8, 1.2, 1.8, 2.2, 3]) + offset = 0.08 # horizontal shift for scatter separation + + # Plot + plt.figure(figsize=(8, 5)) + + main_label_size = 20 + sub_label_size = 16 + main_tick_size = 12 + legendsize = 16 + + plt.scatter(x - offset, precision, label="Precision", marker="o", s=80) + plt.scatter(x, recall, label="Recall", marker="^", s=80) + plt.scatter(x + offset, f1score, label="F1-score", marker="*", s=80) + + plt.text(1, descr_y, "SGN", fontsize=main_label_size, horizontalalignment="center") + plt.text(2, descr_y, "IHC", fontsize=main_label_size, horizontalalignment="center") + plt.text(3, descr_y, "Synapse", fontsize=main_label_size, horizontalalignment="center") + + # Labels and formatting + plt.xticks(x, setting, fontsize=sub_label_size) + plt.yticks(fontsize=main_tick_size) + plt.ylabel("Value", fontsize=main_label_size) + plt.ylim(0.76, 1) + plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.11), + ncol=3, fancybox=True, shadow=False, framealpha=0.8, fontsize=legendsize) + plt.grid(axis="y", linestyle="--", alpha=0.5) + + plt.tight_layout() + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) + + if plot: + plt.show() + else: + plt.close() + + +# Load the synapse counts for all IHCs from the relevant tables. 
+def _load_ribbon_synapse_counts(): + tables = glob("ihc_counts/*M_LR*.tsv") + syn_counts = [] + for tab in tables: + x = pd.read_csv(tab, sep="\t") + syn_counts.extend(x["synapse_count"].values.tolist()) + return syn_counts + + +def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_synapses=False): + """Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values. + """ + main_tick_size = 16 + main_label_size = 24 + + rows = 1 + columns = 3 if plot_average_ribbon_synapses else 2 + + sgn_values = [11153, 11398, 10333, 11820] + ihc_v4b_values = [836, 808, 796, 901] + ihc_v4c_values = [712, 710, 721, 675] + ihc_v4c_filtered_values = [562, 647, 626, 628] + + if all_versions: + ihc_list = [ihc_v4b_values, ihc_v4c_values, ihc_v4c_filtered_values] + suffixes = ["_v4b", "_v4c", "_v4c_filtered"] + assert not plot_average_ribbon_synapses + else: + ihc_list = [ihc_v4c_values] + suffixes = ["_v4c"] + + for (ihc_values, suffix) in zip(ihc_list, suffixes): + fig, axes = plt.subplots(rows, columns, figsize=(columns*4, rows*4)) + ax = axes.flatten() + + save_path_new = save_path + suffix + ax[0].boxplot(sgn_values) + ax[1].boxplot(ihc_values) + + # Labels and formatting + ax[0].set_xticklabels(["SGN"], fontsize=main_label_size) + + ylim0 = 9500 + ylim1 = 12500 + y_ticks = [i for i in range(10000, 12000 + 1, 1000)] + + ax[0].set_ylabel("Count per cochlea", fontsize=main_label_size) + ax[0].set_yticks(y_ticks) + ax[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[0].set_ylim(ylim0, ylim1) + ax[0].yaxis.set_ticks_position("left") + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + ax[0].set_xlim(xmin, xmax) + lower_y, upper_y = literature_reference_values("SGN") + ax[0].hlines([lower_y, upper_y], xmin, xmax) + ax[0].text(1, upper_y + 100, "literature", color="C0", fontsize=main_tick_size, ha="center") + ax[0].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) + + ylim0 = 500 + ylim1 = 950 + y_ticks = [i for i in range(500, 900 + 1, 100)] + + ax[1].set_xticklabels(["IHC"], fontsize=main_label_size) + ax[1].set_yticks(y_ticks) + ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[1].set_ylim(ylim0, ylim1) + if not plot_average_ribbon_synapses: + ax[1].yaxis.tick_right() + ax[1].yaxis.set_ticks_position("right") + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + lower_y, upper_y = literature_reference_values("IHC") + ax[1].set_xlim(xmin, xmax) + ax[1].hlines([lower_y, upper_y], xmin, xmax) + ax[1].text(1, upper_y + 20, "literature", color="C0", fontsize=main_tick_size, ha="center") + ax[1].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) + + if plot_average_ribbon_synapses: + ribbon_synapse_counts = _load_ribbon_synapse_counts() + # ylim0 = 4.9 + # ylim1 = 25.1 + y_ticks = [0, 10, 20, 30, 40, 50] + + ax[2].boxplot(ribbon_synapse_counts) + ax[2].set_xticklabels(["Ribbon Syn. 
per IHC"], fontsize=main_label_size) + ax[2].set_yticks(y_ticks) + ax[2].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + # ax[2].set_ylim(ylim0, ylim1) + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + lower_y, upper_y = literature_reference_values("synapse") + ax[2].set_xlim(xmin, xmax) + ax[2].hlines([lower_y, upper_y], xmin, xmax) + ax[2].text(1, upper_y + 2, "literature", color="C0", fontsize=main_tick_size, ha="center") + ax[2].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) + + plt.tight_layout() + plt.savefig(save_path_new, dpi=png_dpi) + + if plot: + plt.show() + else: + plt.close() + + +def fig_02d_02(save_path, filter_zeros=True, plot=False): + """Bar plot showing the distribution of synapse markers per IHC segmentation average over multiple clochleae. + """ + cochleae = ["M_LR_000226_L", "M_LR_000226_R", "M_LR_000227_L", "M_LR_000227_R"] + ihc_version = "ihc_counts_v4b" + synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}" + + max_dist = 3 + bins = 10 + cap = 30 + plot_density = False + + results = [] + for cochlea in cochleae: + synapse_file = os.path.join(synapse_dir, f"ihc_count_{cochlea}.tsv") + ihc_table = pd.read_csv(synapse_file, sep="\t") + syn_per_ihc = list(ihc_table["synapse_count"]) + if filter_zeros: + syn_per_ihc = [s for s in syn_per_ihc if s != 0] + results.append(syn_per_ihc) + + results = [np.clip(r, 0, cap) for r in results] + + # Define bins (shared for all datasets) + bins = np.linspace(0, 30, 11) # 29 bins + + # Compute histogram (relative) for each dataset + histograms = [] + for data in results: + counts, _ = np.histogram(data, bins=bins, density=plot_density) + histograms.append(counts) + + histograms = np.array(histograms) + + # Compute mean and std for each bin across datasets + mean_counts = histograms.mean(axis=0) + std_counts = histograms.std(axis=0) + + # Get bin centers for plotting + bin_centers = 0.5 * (bins[1:] + bins[:-1]) + + # Plot + plt.figure(figsize=(8, 5)) + plt.bar(bin_centers, mean_counts, width=(bins[1] - bins[0]), yerr=std_counts, alpha=0.7, capsize=4, + label="Mean ± Std Dev", edgecolor="black") + + main_label_size = 20 + main_tick_size = 16 + legendsize = 16 + + # Labels and formatting + x_ticks = [i for i in range(0, cap + 1, 5)] + if plot_density: + y_ticks = [i * 0.02 for i in range(0, 10, 2)] + else: + y_ticks = [i for i in range(0, 300, 50)] + + plt.xticks(x_ticks, fontsize=main_tick_size) + plt.yticks(y_ticks, fontsize=main_tick_size) + if plot_density: + plt.ylabel("Proportion of IHCs [%]", fontsize=main_label_size) + else: + plt.ylabel("Number of IHCs", fontsize=main_label_size) + plt.xlabel(f"Ribbon Synapses per IHC @ {max_dist} µm", fontsize=main_label_size) + + plt.title("Average Synapses per IHC for a Dataset of 4 Cochleae") + + plt.grid(axis="y", linestyle="--", alpha=0.5) + plt.legend(fontsize=legendsize) + plt.tight_layout() + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) + + if plot: + plt.show() + else: + plt.close() + + +def main(): + parser = argparse.ArgumentParser(description="Generate plots for Fig 2 of the cochlea paper.") + parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig2") + parser.add_argument("--plot", action="store_true") + args = parser.parse_args() + + os.makedirs(args.figure_dir, exist_ok=True) + + # Panel C: Evaluation of the segmentation results: + 
fig_02c(save_path=os.path.join(args.figure_dir, "fig_02c"), plot=args.plot, all_versions=False) + + # Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC + fig_02d_01(save_path=os.path.join(args.figure_dir, "fig_02d"), plot=args.plot, plot_average_ribbon_synapses=True) + + # Alternative version of synapse distribution for panel D. + # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02"), plot=args.plot) + # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4c"), filter_zeros=False, plot=plot) + # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4c_filtered"), filter_zeros=True, plot=plot) + # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4b"), filter_zeros=True, plot=args.plot) + + +if __name__ == "__main__": + main() diff --git a/scripts/figures/plot_fig3.py b/scripts/figures/plot_fig3.py new file mode 100644 index 0000000..dff146c --- /dev/null +++ b/scripts/figures/plot_fig3.py @@ -0,0 +1,169 @@ +import argparse +import os +import imageio.v3 as imageio +from glob import glob + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib import cm, colors + +from util import sliding_runlength_sum, frequency_mapping + +INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3/frequency_mapping" + +png_dpi = 300 + + +def fig_03a(save_path): + import napari + + path = os.path.join(INPUT_ROOT, "frequencies_IHC_v4c.tif") + vol = imageio.imread(path) + + # Create the colormap + fig, ax = plt.subplots(figsize=(6, 1.3)) + fig.subplots_adjust(bottom=0.5) + + freq_min = np.min(np.nonzero(vol)) + freq_max = vol.max() + norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True) + cmap = plt.get_cmap("viridis") + + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal") + cb.set_label("Frequency [kHz]") + plt.title("Tonotopic Mapping: IHCs") + plt.tight_layout() + out_path = os.path.join(save_path) + plt.savefig(out_path) + + # Show the image in napari for rendering. + v = napari.Viewer() + v.add_image(vol, colormap="viridis") + napari.run() + + +def fig_03b(save_path): + import napari + + path = os.path.join(INPUT_ROOT, "frequencies_SGN_v2.tif") + vol = imageio.imread(path) + + # Create the colormap + fig, ax = plt.subplots(figsize=(6, 1.3)) + fig.subplots_adjust(bottom=0.5) + + freq_min = np.min(np.nonzero(vol)) + freq_max = vol.max() + norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True) + cmap = plt.get_cmap("viridis") + + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal") + cb.set_label("Frequency [kHz]") + plt.title("Tonotopic Mapping: SGNs") + plt.tight_layout() + out_path = os.path.join(save_path) + plt.savefig(out_path) + + # Show the image in napari for rendering. + v = napari.Viewer() + v.add_image(vol, colormap="viridis") + napari.run() + + +def fig_03c_rl(save_path, plot=False): + tables = glob("./ihc_counts/ihc_count_M_LR*.tsv") + fig, ax = plt.subplots(figsize=(8, 4)) + + width = 50 # micron + + for tab_path in tables: + # TODO map to alias + alias = os.path.basename(tab_path)[10:-4].replace("_", "").replace("0", "") + tab = pd.read_csv(tab_path, sep="\t") + run_length = tab["run_length"].values + syn_count = tab["synapse_count"].values + + # Compute the running sum of 10 micron. 
+        run_length, syn_count_running = sliding_runlength_sum(run_length, syn_count, width=width)
+        ax.plot(run_length, syn_count_running, label=alias)
+
+    ax.set_xlabel("Length [µm]")
+    ax.set_ylabel("Synapse Count")
+    ax.set_title(f"Ribbon Syn. per IHC: Running sum @ {width} µm")
+    ax.legend(title="cochlea")
+    plt.tight_layout()
+
+    plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
+    if plot:
+        plt.show()
+    else:
+        plt.close()
+
+
+def fig_03c_octave(save_path, plot=False):
+    tables = glob("./ihc_counts/ihc_count_M_LR*.tsv")
+
+    result = {"cochlea": [], "octave_band": [], "value": []}
+    for tab_path in tables:
+        # TODO map to alias
+        alias = os.path.basename(tab_path)[10:-4].replace("_", "").replace("0", "")
+        tab = pd.read_csv(tab_path, sep="\t")
+        freq = tab["frequency"].values
+        syn_count = tab["synapse_count"].values
+
+        # Bin the synapse counts into octave bands.
+        octave_binned = frequency_mapping(freq, syn_count, animal="mouse")
+
+        result["cochlea"].extend([alias] * len(octave_binned))
+        result["octave_band"].extend(octave_binned.axes[0].values.tolist())
+        result["value"].extend(octave_binned.values.tolist())
+
+    result = pd.DataFrame(result)
+    bin_labels = pd.unique(result["octave_band"])
+    band_to_x = {band: i for i, band in enumerate(bin_labels)}
+    result["x_pos"] = result["octave_band"].map(band_to_x)
+
+    fig, ax = plt.subplots(figsize=(8, 4))
+    for name, grp in result.groupby("cochlea"):
+        ax.scatter(grp["x_pos"], grp["value"], label=name, s=60, alpha=0.8)
+
+    ax.set_xticks(range(len(bin_labels)))
+    ax.set_xticklabels(bin_labels)
+    ax.set_xlabel("Octave band (kHz)")
+
+    ax.set_ylabel("Average Ribbon Synapse Count per IHC")
+    ax.set_title("Ribbon synapse count per octave band")
+    ax.legend(title="Cochlea")
+
+    plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
+    if plot:
+        plt.show()
+    else:
+        plt.close()
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate plots for Fig 3 of the cochlea paper.")
+    parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig3")
+    parser.add_argument("--plot", action="store_true")
+    args = parser.parse_args()
+
+    os.makedirs(args.figure_dir, exist_ok=True)
+
+    # Panel A: Tonotopic mapping of IHCs (rendering in napari)
+    # fig_03a(save_path=os.path.join(args.figure_dir, "fig_03a.png"))
+
+    # Panel B: Tonotopic mapping of SGNs (rendering in napari)
+    # fig_03b(save_path=os.path.join(args.figure_dir, "fig_03b.png"))
+
+    # Panel C: Spatial distribution of synapses across the cochlea.
+    # We have two options: running sum over the runlength or per octave band
+    fig_03c_rl(save_path=os.path.join(args.figure_dir, "fig_03c_runlength.png"), plot=args.plot)
+    fig_03c_octave(save_path=os.path.join(args.figure_dir, "fig_03c_octave.png"), plot=args.plot)
+
+    # TODO: Panel D: Spatial distribution of SGN sub-types.
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/figures/plot_fig4.py b/scripts/figures/plot_fig4.py
new file mode 100644
index 0000000..3b74c4f
--- /dev/null
+++ b/scripts/figures/plot_fig4.py
@@ -0,0 +1,319 @@
+import argparse
+import json
+import os
+import pickle
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
+
+from util import frequency_mapping, literature_reference_values
+
+INTENSITY_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/measurements"  # noqa
+
+# The cochleae for the ChReef analysis.
+COCHLEAE = [
+    "M_LR_000143_L",
+    "M_LR_000144_L",
+    "M_LR_000145_L",
+    "M_LR_000153_L",
+    "M_LR_000155_L",
+    "M_LR_000189_L",
+    "M_LR_000143_R",
+    "M_LR_000144_R",
+    "M_LR_000145_R",
+    "M_LR_000153_R",
+    "M_LR_000155_R",
+    "M_LR_000189_R",
+]
+
+png_dpi = 300
+
+
+def get_chreef_data():
+    s3 = create_s3_target()
+    source_name = "SGN_v2"
+
+    cache_path = "./chreef_data.pkl"
+    if os.path.exists(cache_path):
+        with open(cache_path, "rb") as f:
+            return pickle.load(f)
+
+    chreef_data = {}
+    for cochlea in COCHLEAE:
+        print("Processing cochlea:", cochlea)
+        content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8")
+        info = json.loads(content.read())
+        sources = info["sources"]
+
+        # Load the seg table and filter the compartments.
+        source = sources[source_name]["segmentation"]
+        rel_path = source["tableData"]["tsv"]["relativePath"]
+        table_content = s3.open(os.path.join(BUCKET_NAME, cochlea, rel_path, "default.tsv"), mode="rb")
+        table = pd.read_csv(table_content, sep="\t")
+
+        # May need to be adjusted for some cochleae.
+        table = table[table.component_labels == 1]
+        # The relevant values for analysis.
+        try:
+            values = table[["label_id", "length[µm]", "frequency[kHz]", "marker_labels"]]
+        except KeyError:
+            print("Could not find the values for", cochlea, "it will be skipped.")
+            continue
+
+        fname = f"{cochlea.replace('_', '-')}_GFP_SGN-v2_object-measures.tsv"
+        intensity_file = os.path.join(INTENSITY_ROOT, fname)
+        assert os.path.exists(intensity_file), intensity_file
+        intensity_table = pd.read_csv(intensity_file, sep="\t")
+        values = values.merge(intensity_table, on="label_id")
+
+        chreef_data[cochlea] = values
+
+    with open(cache_path, "wb") as f:
+        pickle.dump(chreef_data, f)
+    return chreef_data
+
+
+def group_lr(names_lr, values):
+    assert len(names_lr) == len(values)
+    names = []
+    values_left, values_right = {}, {}
+    for name_lr, val in zip(names_lr, values):
+        name, side = name_lr[:-1], name_lr[-1]
+        names.append(name)
+        if side == "R":
+            values_right[name] = val
+        elif side == "L":
+            values_left[name] = val
+        else:
+            raise RuntimeError
+    names = sorted(list(set(names)))
+
+    values_left = [values_left.get(name, np.nan) for name in names]
+    values_right = [values_right.get(name, np.nan) for name in names]
+
+    return names, values_left, values_right
+
+
+def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False):
+    """Scatter plot showing the SGN counts of ChReef-treated cochleae compared to healthy ones.
+    """
+    # Previous version with hard-coded values.
+ # cochlea = ["M_LR_000144_L", "M_LR_000145_L", "M_LR_000151_R"] + # alias = ["c01", "c02", "c03"] + # sgns = [7796, 6119, 9225] + + # TODO map the cochlea name to its alias + alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()] + sgns = [len(vals) for vals in chreef_data.values()] + + if plot_by_side: + alias, sgns_left, sgns_right = group_lr(alias, sgns) + + x = np.arange(len(alias)) + + # Plot + plt.figure(figsize=(8, 5)) + + main_label_size = 20 + sub_label_size = 16 + main_tick_size = 12 + legendsize = 16 + + if plot_by_side: + plt.scatter(x, sgns_left, label="SGN count (Left)", marker="o", s=80) + plt.scatter(x, sgns_right, label="SGN count (Right)", marker="x", s=80) + else: + plt.scatter(x, sgns, label="SGN count", marker="o", s=80) + + # Labels and formatting + plt.xticks(x, alias, fontsize=sub_label_size) + plt.xlabel("Cochlea", fontsize=main_label_size) + plt.yticks(fontsize=main_tick_size) + plt.ylabel("SGN count per cochlea", fontsize=main_label_size) + plt.ylim(4000, 13800) + plt.legend(loc="best", fontsize=sub_label_size) + plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.11), + ncol=3, fancybox=True, shadow=False, framealpha=0.8, fontsize=legendsize) + + # set range of literature values + xmin = -0.5 + xmax = len(alias) - 0.5 + plt.xlim(xmin, xmax) + lower_y, upper_y = literature_reference_values("SGN") + plt.hlines([lower_y, upper_y], xmin, xmax) + plt.text(1, lower_y - 400, "literature", color="C0", fontsize=main_tick_size, ha="center") + plt.fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) + + sgn_values = [11153, 11398, 10333, 11820] + sgn_value = np.mean(sgn_values) + sgn_std = np.std(sgn_values) + + upper_y = sgn_value + 1.96 * sgn_std + lower_y = sgn_value - 1.96 * sgn_std + + plt.hlines([lower_y, upper_y], xmin, xmax, colors=["C1" for _ in range(2)]) + plt.text(1, upper_y + 100, "healthy cochleae (95% confidence interval)", + color="C1", fontsize=main_tick_size, ha="center") + plt.fill_between([xmin, xmax], lower_y, upper_y, color="C1", alpha=0.05, interpolate=True) + + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) + plt.tight_layout() + + if plot: + plt.show() + else: + plt.close() + + +def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=False): + """Transduction efficiency per cochlea. + """ + # TODO map the cochlea name to its alias + alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()] + + values = [] + for vals in chreef_data.values(): + if intensity: + intensities = vals["median"].values + values.append(intensities.mean()) + else: + # The marker labels don't make sense yet, they are in + # 0: unlabeled + # 1: positive + # 2: negative + # but they should all be either positive or negative. + # Or am I missing something? 
+ marker_labels = vals["marker_labels"].values + n_pos = (marker_labels == 1).sum() + n_neg = (marker_labels == 2).sum() + eff = float(n_pos) / (n_pos + n_neg) + values.append(eff) + + if plot_by_side: + alias, values_left, values_right = group_lr(alias, values) + + x = np.arange(len(alias)) + + # Plot + plt.figure(figsize=(8, 5)) + + main_label_size = 20 + sub_label_size = 16 + main_tick_size = 12 + legendsize = 16 + + label = "Intensity" if intensity else "Transduction efficiency" + if plot_by_side: + plt.scatter(x, values_left, label=f"{label} (Left)", marker="o", s=80) + plt.scatter(x, values_right, label=f"{label} (Right)", marker="x", s=80) + else: + plt.scatter(x, values, label=label, marker="o", s=80) + + # Labels and formatting + plt.xticks(x, alias, fontsize=sub_label_size) + plt.xlabel("Cochlea", fontsize=main_label_size) + plt.yticks(fontsize=main_tick_size) + plt.ylabel(label, fontsize=main_label_size) + plt.legend(loc="best", fontsize=sub_label_size) + plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.11), + ncol=3, fancybox=True, shadow=False, framealpha=0.8, fontsize=legendsize) + if not intensity: + plt.ylim(0.5, 1.05) + + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) + plt.tight_layout() + + if plot: + plt.show() + else: + plt.close() + + +def fig_04e(chreef_data, save_path, plot, intensity=False): + + result = {"cochlea": [], "octave_band": [], "value": []} + for name, values in chreef_data.items(): + # TODO map name to alias + alias = name.replace("_", "").replace("0", "") + + freq = values["frequency[kHz]"].values + if intensity: + intensity_values = values["median"].values + octave_binned = frequency_mapping(freq, intensity_values, animal="mouse") + else: + marker_labels = values["marker_labels"].values + octave_binned = frequency_mapping(freq, marker_labels, animal="mouse", transduction_efficiency=True) + + result["cochlea"].extend([alias] * len(octave_binned)) + result["octave_band"].extend(octave_binned.axes[0].values.tolist()) + result["value"].extend(octave_binned.values.tolist()) + + result = pd.DataFrame(result) + bin_labels = pd.unique(result["octave_band"]) + band_to_x = {band: i for i, band in enumerate(bin_labels)} + result["x_pos"] = result["octave_band"].map(band_to_x) + + fig, axes = plt.subplots(1, 2, figsize=(8, 4), sharex=True, sharey=True) + for name_lr, grp in result.groupby("cochlea"): + name, side = name_lr[:-1], name_lr[-1] + ax, marker = (axes[0], "o") if side == "L" else (axes[1], "x") + ax.scatter(grp["x_pos"], grp["value"], label=name, s=60, alpha=0.8, marker=marker) + + for ax in axes: + ax.set_xticks(range(len(bin_labels))) + ax.set_xticklabels(bin_labels) + ax.set_xlabel("Octave band (kHz)") + + if intensity: + axes[0].set_ylabel("Marker Intensity") + for ax, side in zip(axes, ("Left", "Right")): + ax.set_title(f"Intensity per octave band ({side})") + else: + axes[0].set_ylabel("Transduction Efficiency") + axes[0].set_ylim(0.5, 1.05) + for ax, side in zip(axes, ("Left", "Right")): + ax.set_title(f"Transduction efficiency per octave band ({side})") + + # FIXME make this uniform across the plots! 
+    axes[0].legend(title="Cochlea")
+    axes[1].legend(title="Cochlea")
+    plt.tight_layout()
+
+    plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
+    if plot:
+        plt.show()
+    else:
+        plt.close()
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Generate plots for Fig 4 of the cochlea paper.")
+    parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig4")
+    parser.add_argument("--plot", action="store_true")
+    args = parser.parse_args()
+
+    os.makedirs(args.figure_dir, exist_ok=True)
+
+    # Get the chreef data as a dictionary of cochlea name to measurements.
+    chreef_data = get_chreef_data()
+    # M_LR_000143_L is a complete outlier.
+    chreef_data.pop("M_LR_000143_L")
+
+    # Create the panels:
+
+    # C: The SGN count compared to reference values from literature and from healthy cochleae.
+    # Maybe remove literature reference from plot?
+    fig_04c(chreef_data, save_path=os.path.join(args.figure_dir, "fig_04c"), plot=args.plot, plot_by_side=True)
+
+    # D: The transduction efficiency. We also plot GFP intensities.
+    fig_04d(chreef_data, save_path=os.path.join(args.figure_dir, "fig_04d_transduction"), plot=args.plot, plot_by_side=True)  # noqa
+    fig_04d(chreef_data, save_path=os.path.join(args.figure_dir, "fig_04d_intensity"), plot=args.plot, plot_by_side=True, intensity=True)  # noqa
+
+    fig_04e(chreef_data, save_path=os.path.join(args.figure_dir, "fig_04e_transduction"), plot=args.plot)
+    fig_04e(chreef_data, save_path=os.path.join(args.figure_dir, "fig_04e_intensity"), plot=args.plot, intensity=True)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/figures/plot_fig6.py b/scripts/figures/plot_fig6.py
new file mode 100644
index 0000000..c0cd5fe
--- /dev/null
+++ b/scripts/figures/plot_fig6.py
@@ -0,0 +1,88 @@
+import argparse
+import os
+
+import matplotlib.pyplot as plt
+
+png_dpi = 300
+
+
+def fig_06a(save_path, plot=False):
+    """Box plot showing the counts for SGN and IHC per gerbil cochlea in comparison to literature values.
+ """ + main_tick_size = 12 + main_label_size = 16 + + rows = 1 + columns = 2 + + fig, axes = plt.subplots(rows, columns, figsize=(columns*4, rows*4)) + ax = axes.flatten() + + sgn_values = [20050, 21995] + ihc_values = [1100] + + ax[0].boxplot(sgn_values) + ax[1].boxplot(ihc_values) + + # Labels and formatting + ax[0].set_xticklabels(["SGN"], fontsize=main_label_size) + + ylim0 = 12000 + ylim1 = 22500 + y_ticks = [i for i in range(ylim0, ylim1 + 1, 2000)] + + ax[0].set_ylabel('Count per cochlea', fontsize=main_label_size) + ax[0].set_yticks(y_ticks) + ax[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[0].set_ylim(ylim0, ylim1) + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + ax[0].set_xlim(xmin, xmax) + upper_y = 15000 + lower_y = 13000 + ax[0].hlines([lower_y, upper_y], xmin, xmax) + ax[0].text(1, upper_y + 100, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center") + ax[0].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True) + + ylim0 = 800 + ylim1 = 1400 + y_ticks = [i for i in range(ylim0, ylim1 + 1, 100)] + + ax[1].set_xticklabels(["IHC"], fontsize=main_label_size) + + ax[1].set_ylabel('Count per cochlea', fontsize=main_label_size) + ax[1].set_yticks(y_ticks) + ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[1].set_ylim(ylim0, ylim1) + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + ax[1].set_xlim(xmin, xmax) + upper_y = 1200 + lower_y = 1000 + ax[1].hlines([lower_y, upper_y], xmin, xmax) + ax[1].text(1, upper_y + 10, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center") + ax[1].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True) + + plt.tight_layout() + plt.savefig(save_path, dpi=png_dpi) + if plot: + plt.show() + else: + plt.close() + + +def main(): + parser = argparse.ArgumentParser(description="Generate plots for Fig 2 of the cochlea paper.") + parser.add_argument("figure_dir", type=str, help="Output directory for plots.", default="./panels") + args = parser.parse_args() + plot = False + + fig_06a(save_path=os.path.join(args.figure_dir, "fig_06a"), plot=plot) + + +if __name__ == "__main__": + main() diff --git a/scripts/figures/util.py b/scripts/figures/util.py new file mode 100644 index 0000000..f9dbc26 --- /dev/null +++ b/scripts/figures/util.py @@ -0,0 +1,73 @@ +import pandas as pd +import numpy as np + + +# Define the animal specific octave bands. +def _get_mapping(animal): + if animal == "mouse": + bin_edges = [0, 1, 2, 4, 8, 16, 32, 64, np.inf] + bin_labels = [ + "<1 k", "1–2 k", "2–4 k", "4–8 k", "8–16 k", "16–32 k", "32–64 k", ">64 k" + ] + elif animal == "gerbil": + bin_edges = [0, 0.5, 1, 2, 4, 8, 16, 32, np.inf] + bin_labels = [ + "<0.5 k", "0.5–1 k", "1–2 k", "2–4 k", "4–8 k", "8–16 k", "16–32 k", ">32 k" + ] + else: + raise ValueError + assert len(bin_edges) == len(bin_labels) + 1 + return bin_edges, bin_labels + + +def frequency_mapping(frequencies, values, animal="mouse", transduction_efficiency=False): + # Get the mapping of frequencies to octave bands for the given species. + bin_edges, bin_labels = _get_mapping(animal) + + # Construct the data frame with octave bands. + df = pd.DataFrame({"freq_khz": frequencies, "value": values}) + df["octave_band"] = pd.cut( + df["freq_khz"], bins=bin_edges, labels=bin_labels, right=False + ) + + if transduction_efficiency: # We compute the transduction efficiency per band. 
+ num_pos = df[df["value"] == 1].groupby("octave_band", observed=False).size() + num_tot = df[df["value"].isin([1, 2])].groupby("octave_band", observed=False).size() + value_by_band = (num_pos / num_tot).reindex(bin_labels) + else: # Otherwise, aggregate the values over the octave band using the mean. + value_by_band = ( + df.groupby("octave_band", observed=True)["value"] + .mean() + .reindex(bin_labels) # keep octave order even if a bin is empty + ) + return value_by_band + + +def sliding_runlength_sum(run_length, values, width): + assert len(run_length) == len(values) + # Create data frame and sort it. + df = pd.DataFrame({"run_length": run_length, "value": values}) + df = df.sort_values("run_length").reset_index(drop=True).copy() + + x = df["run_length"].to_numpy() + y = df["value"].to_numpy() + + cumsum = np.cumsum(y) + start_idx = np.searchsorted(x, x - width, side="left") + window_sum = cumsum - np.concatenate(([0], cumsum[:-1]))[start_idx] + assert len(window_sum) == len(x) + + return x, window_sum + + +# TODO determine these from Aleyna's table! +def literature_reference_values(structure): + if structure == "SGN": + lower_bound, upper_bound = 10000, 12000 + elif structure == "IHC": + lower_bound, upper_bound = 780, 850 + elif structure == "synapse": + lower_bound, upper_bound = 10, 25 + else: + raise ValueError + return lower_bound, upper_bound diff --git a/scripts/measurements/evaluate_marker_annotations.py b/scripts/measurements/evaluate_marker_annotations.py index 7e8605d..539e206 100644 --- a/scripts/measurements/evaluate_marker_annotations.py +++ b/scripts/measurements/evaluate_marker_annotations.py @@ -1,4 +1,5 @@ import argparse +import json import os from typing import List, Optional @@ -9,10 +10,25 @@ from flamingo_tools.segmentation.chreef_utils import localize_median_intensities, find_annotations MARKER_DIR = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/ChReef_PV-GFP/2025-07_PV_GFP_SGN" +# The cochlea for the CHReef analysis. +COCHLEAE = [ + "M_LR_000143_L", + "M_LR_000144_L", + "M_LR_000145_L", + "M_LR_000153_L", + "M_LR_000155_L", + "M_LR_000189_L", + "M_LR_000143_R", + "M_LR_000144_R", + "M_LR_000145_R", + "M_LR_000153_R", + "M_LR_000155_R", + "M_LR_000189_R", +] def get_length_fraction_from_center(table, center_str): - """ Get 'length_fraction' parameter for center coordinate by averaging nearby segmentation instances. + """Get 'length_fraction' parameter for center coordinate by averaging nearby segmentation instances. """ center_coord = tuple([int(c) for c in center_str.split("-")]) (cx, cy, cz) = center_coord @@ -32,10 +48,10 @@ def get_length_fraction_from_center(table, center_str): def apply_nearest_threshold(intensity_dic, table_seg, table_measurement): """Apply threshold to nearest segmentation instances. - Crop centers are transformed into the 'length fraction' parameter of the segmentation table. - This avoids issues with the spiral shape of the cochlea and maps the assignment onto the Rosenthal's canal. + Crop centers are transformed into the "length fraction" parameter of the segmentation table. + This avoids issues with the spiral shape of the cochlea and maps the assignment onto the Rosenthal"s canal. 
""" - # assign crop centers to length fraction of Rosenthal's canal + # assign crop centers to length fraction of Rosenthal"s canal lf_intensity = {} for key in intensity_dic.keys(): length_fraction = get_length_fraction_from_center(table_seg, key) @@ -76,13 +92,48 @@ def apply_nearest_threshold(intensity_dic, table_seg, table_measurement): return table_seg +def find_thresholds(cochlea_annotations, cochlea, data_seg, table_measurement): + # Find the median intensities by averaging the individual annotations for specific crops + annotation_dics = {} + annotated_centers = [] + for annotation_dir in cochlea_annotations: + print(f"Localizing threshold with median intensities for {os.path.basename(annotation_dir)}.") + annotation_dic = localize_median_intensities(annotation_dir, cochlea, data_seg, table_measurement) + annotated_centers.extend(annotation_dic["center_strings"]) + annotation_dics[annotation_dir] = annotation_dic + + annotated_centers = list(set(annotated_centers)) + intensity_dic = {} + # loop over all annotated blocks + for annotated_center in annotated_centers: + intensities = [] + # loop over annotated block from single user + for annotator_key in annotation_dics.keys(): + if annotated_center not in annotation_dics[annotator_key]["center_strings"]: + continue + else: + median_intensity = annotation_dics[annotator_key][annotated_center]["median_intensity"] + if median_intensity is None: + print(f"No threshold for {os.path.basename(annotator_key)} and crop {annotated_center}.") + else: + intensities.append(median_intensity) + if len(intensities) == 0: + print(f"No viable annotation for cochlea {cochlea} and crop {annotated_center}.") + else: + intensity_dic[annotated_center] = {"median_intensity": float(sum(intensities) / len(intensities))} + + return intensity_dic + + def evaluate_marker_annotation( - cochleae, + cochleae: List[str], output_dir: str, annotation_dirs: Optional[List[str]] = None, seg_name: str = "SGN_v2", marker_name: str = "GFP", -): + threshold_save_dir: Optional[str] = None, + force: bool = False, +) -> None: """Evaluate marker annotations of a single or multiple annotators. Segmentation instances are assigned a positive (1) or negative label (2) in form of the "marker_label" component of the output segmentation table. @@ -91,10 +142,12 @@ def evaluate_marker_annotation( Args: cochleae: List of cochlea - output_dir: Output directory for segmentation table with 'marker_label' in format __.tsv + output_dir: Output directory for segmentation table with "marker_label" in format __.tsv annotation_dirs: List of directories containing marker annotations by annotator(s). seg_name: Identifier for segmentation. marker_name: Identifier for marker stain. + threshold_save_dir: Optional directory for saving the thresholds. + force: Whether to overwrite already existing results. """ input_key = "s0" @@ -104,11 +157,17 @@ def evaluate_marker_annotation( annotation_dirs = [entry.path for entry in os.scandir(marker_dir) if os.path.isdir(entry) and "Results" in entry.name] + seg_string = "-".join(seg_name.split("_")) for cochlea in cochleae: + cochlea_str = "-".join(cochlea.split("_")) + out_path = os.path.join(output_dir, f"{cochlea_str}_{marker_name}_{seg_string}.tsv") + if os.path.exists(out_path) and not force: + continue + cochlea_annotations = [a for a in annotation_dirs if len(find_annotations(a, cochlea)["center_strings"]) != 0] print(f"Evaluating data for cochlea {cochlea} in {cochlea_annotations}.") - # get segmentation data + # Get the segmentation data and table. 
input_path = f"{cochlea}/images/ome-zarr/{seg_name}.ome.zarr" input_path, fs = get_s3_path(input_path) data_seg = read_image_data(input_path, input_key) @@ -118,58 +177,44 @@ def evaluate_marker_annotation( with fs.open(table_path_s3, "r") as f: table_seg = pd.read_csv(f, sep="\t") - seg_string = "-".join(seg_name.split("_")) table_measurement_path = f"{cochlea}/tables/{seg_name}/{marker_name}_{seg_string}_object-measures.tsv" table_path_s3, fs = get_s3_path(table_measurement_path) with fs.open(table_path_s3, "r") as f: table_measurement = pd.read_csv(f, sep="\t") - # find median intensities by averaging all individual annotations for specific crops - annotation_dics = {} - annotated_centers = [] - for annotation_dir in cochlea_annotations: - - annotation_dic = localize_median_intensities(annotation_dir, cochlea, data_seg, table_measurement) - annotated_centers.extend(annotation_dic["center_strings"]) - annotation_dics[annotation_dir] = annotation_dic - - annotated_centers = list(set(annotated_centers)) - intensity_dic = {} - # loop over all annotated blocks - for annotated_center in annotated_centers: - intensities = [] - # loop over annotated block from single user - for annotator_key in annotation_dics.keys(): - if annotated_center not in annotation_dics[annotator_key]["center_strings"]: - continue - else: - intensities.append(annotation_dics[annotator_key][annotated_center]["median_intensity"]) - intensity_dic[annotated_center] = {"median_intensity": float(sum(intensities) / len(intensities))} + # Find the threholds from the annotated blocks and save it if specified. + intensity_dic = find_thresholds(cochlea_annotations, cochlea, data_seg, table_measurement) + if threshold_save_dir is not None: + os.makedirs(threshold_save_dir, exist_ok=True) + threshold_out_path = os.path.join(threshold_save_dir, f"{cochlea_str}_{marker_name}_{seg_string}.json") + with open(threshold_out_path, "w") as f: + json.dump(intensity_dic, f, sort_keys=True, indent=4) + # Apply the threshold to all SGNs. table_seg = apply_nearest_threshold(intensity_dic, table_seg, table_measurement) - cochlea_str = "-".join(cochlea.split("_")) - out_path = os.path.join(output_dir, f"{cochlea_str}_{marker_name}_{seg_string}.tsv") + + # Save the table with positives / negatives for all SGNs. + os.makedirs(output_dir, exist_ok=True) table_seg.to_csv(out_path, sep="\t", index=False) def main(): parser = argparse.ArgumentParser( - description="Assign each segmentation instance a marker based on annotation thresholds.") - - parser.add_argument('-c', "--cochlea", type=str, nargs="+", required=True, - help="Cochlea(e) to process.") - parser.add_argument('-o', "--output", type=str, required=True, help="Output directory.") + description="Assign each segmentation instance a marker based on annotation thresholds." 
+ ) - parser.add_argument('-a', '--annotation_dirs', type=str, nargs="+", default=None, + parser.add_argument("-c", "--cochlea", type=str, nargs="+", default=COCHLEAE, help="Cochlea(e) to process.") + parser.add_argument("-o", "--output", type=str, required=True, help="Output directory.") + parser.add_argument("-a", "--annotation_dirs", type=str, nargs="+", default=None, help="Directories containing marker annotations.") + parser.add_argument("--threshold_save_dir", "-t") + parser.add_argument("-f", "--force", action="store_true") args = parser.parse_args() - evaluate_marker_annotation( - args.cochlea, args.output, args.annotation_dirs, + args.cochlea, args.output, args.annotation_dirs, threshold_save_dir=args.threshold_save_dir, force=args.force ) if __name__ == "__main__": - main() diff --git a/scripts/measurements/measure_synapses.py b/scripts/measurements/measure_synapses.py index f53ec7a..38fdcda 100644 --- a/scripts/measurements/measure_synapses.py +++ b/scripts/measurements/measure_synapses.py @@ -11,7 +11,8 @@ def check_project(plot=False, save_ihc_table=False, max_dist=None): s3 = create_s3_target() - cochleae = ['M_LR_000226_L', 'M_LR_000226_R', 'M_LR_000227_L', 'M_LR_000227_R', 'M_AMD_OTOF1_L'] + # cochleae = ["M_LR_000226_L", "M_LR_000226_R", "M_LR_000227_L", "M_LR_000227_R", "M_AMD_OTOF1_L"] + cochleae = ["M_LR_000226_L", "M_LR_000226_R", "M_LR_000227_L", "M_LR_000227_R"] results = {} for cochlea in cochleae: @@ -19,7 +20,7 @@ def check_project(plot=False, save_ihc_table=False, max_dist=None): ihc_table_name = "IHC_v4c" component_id = [1] - if cochlea == 'M_AMD_OTOF1_L': + if cochlea == "M_AMD_OTOF1_L": synapse_table_name = "synapse_v3_ihc_v4b" ihc_table_name = "IHC_v4b" component_id = [3, 11] @@ -61,6 +62,13 @@ def check_project(plot=False, save_ihc_table=False, max_dist=None): print("@ max dist:", max_dist) print() + run_length_dict = { + ihc_id: run_length for ihc_id, run_length in zip(ihc_table.label_id.values, ihc_table["length[µm]"].values) + } + frequency_dict = { + ihc_id: freq for ihc_id, freq in zip(ihc_table.label_id.values, ihc_table["frequency[kHz]"].values) + } + if save_ihc_table: ihc_to_count = {ihc_id: count for ihc_id, count in zip(ihc_ids, syn_per_ihc)} unmatched_ihcs = np.setdiff1d(valid_ihcs, ihc_ids) @@ -71,6 +79,8 @@ def check_project(plot=False, save_ihc_table=False, max_dist=None): "snyapse_table": [synapse_table_name for _ in list(ihc_to_count.values())], "ihc_table": [ihc_table_name for _ in list(ihc_to_count.values())], "max_dist": [max_dist for _ in list(ihc_to_count.values())], + "run_length": [run_length_dict[ihc_id] for ihc_id in ihc_to_count.keys()], + "frequency": [frequency_dict[ihc_id] for ihc_id in ihc_to_count.keys()] }) os.makedirs(OUTPUT_FOLDER, exist_ok=True) output_path = os.path.join(OUTPUT_FOLDER, f"ihc_count_{cochlea}.tsv")