diff --git a/flamingo_tools/segmentation/unet_prediction.py b/flamingo_tools/segmentation/unet_prediction.py index 63c76b4..902ad39 100644 --- a/flamingo_tools/segmentation/unet_prediction.py +++ b/flamingo_tools/segmentation/unet_prediction.py @@ -571,6 +571,8 @@ def run_unet_prediction_slurm( """ os.makedirs(output_folder, exist_ok=True) prediction_instances = int(prediction_instances) + if isinstance(scale, str): + scale = float(scale) slurm_task_id = os.environ.get("SLURM_ARRAY_TASK_ID") if s3 is not None: diff --git a/scripts/figures/plot_fig2.py b/scripts/figures/plot_fig2.py index b1bb2bb..295acd1 100644 --- a/scripts/figures/plot_fig2.py +++ b/scripts/figures/plot_fig2.py @@ -4,62 +4,189 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt -import tifffile -from matplotlib import colors -from skimage.segmentation import find_boundaries +import matplotlib.ticker as mticker +from matplotlib.lines import Line2D from util import literature_reference_values, SYNAPSE_DIR_ROOT -from util import prism_style, prism_cleanup_axes +from util import prism_style, prism_cleanup_axes, export_legend, custom_formatter_2 png_dpi = 300 FILE_EXTENSION = "png" +COLOR_P = "#9C5027" +COLOR_R = "#67279C" +COLOR_F = "#9C276F" +COLOR_T = "#279C52" -def scramble_instance_labels(arr): - """Scramble indexes of instance segmentation to avoid neighboring colors. +COLOR_MEASUREMENT = "#9C7427" +COLOR_LITERATURE = "#27339C" + + +def plot_legend_suppfig02(save_path): + """Plot common legend for figure 2c. + + Args: + save_path: save path to save legend. """ - unique = list(np.unique(arr)[1:]) - rng = np.random.default_rng(seed=42) - new_list = rng.uniform(1, len(unique) + 1, size=(len(unique))) - new_arr = np.zeros(arr.shape) - for old_id, new_id in zip(unique, new_list): - new_arr[arr == old_id] = new_id - return new_arr - - -def plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba=[0, 0, 0, 0.5], plot=False): - seg = tifffile.imread(seg_path) - if len(seg.shape) == 3: - seg = seg[10, xlim1:xlim2, ylim1:ylim2] - else: - seg = seg[xlim1:xlim2, ylim1:ylim2] + # Colors + color = [COLOR_P, COLOR_R, COLOR_F, COLOR_T] + label = ["Precision", "Recall", "F1-score", "Processing time"] + + fl = lambda c: Line2D([], [], lw=3, color=c) + handles = [fl(c) for c in color] + legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False) + export_legend(legend, save_path) + legend.remove() + plt.close() + + +def supp_fig_02(save_path, plot=False, segm="SGN", mode="precision"): + # SGN + value_dict = { + "SGN": { + "stardist": { + "label": "Stardist", + "precision": 0.706, + "recall": 0.630, + "f1-score": 0.628, + "marker": "o", + "runtime": 536.5, + "runtime_std": 148.4 + + }, + "micro_sam": { + "label": "µSAM", + "precision": 0.140, + "recall": 0.782, + "f1-score": 0.228, + "marker": "D", + "runtime": 407.5, + "runtime_std": 107.5 + }, + "cellpose_3": { + "label": "Cellpose 3", + "precision": 0.117, + "recall": 0.607, + "f1-score": 0.186, + "marker": "v", + "runtime": None, + "runtime_std": None + }, + "cellpose_sam": { + "label": "Cellpose-SAM", + "precision": 0.250, + "recall": 0.003, + "f1-score": 0.005, + "marker": "^", + "runtime": 167.9, + "runtime_std": 40.2 + }, + "distance_unet": { + "label": "CochleaNet", + "precision": 0.886, + "recall": 0.804, + "f1-score": 0.837, + "marker": "s", + "runtime": 168.8, + "runtime_std": 21.8 + }, + }, + "IHC": { + "micro_sam": { + "label": "µSAM", + "precision": 0.053, + "recall": 0.684, + "f1-score": 0.094, + "marker": "D", + "runtime": 445.6, + "runtime_std": 106.6 + }, + "cellpose_3": { + "label": "Cellpose 3", + "precision": 0.375, + "recall": 0.554, + "f1-score": 0.329, + "marker": "v", + "runtime": 30.1, + "runtime_std": 162.3 + }, + "cellpose_sam": { + "label": "Cellpose-SAM", + "precision": 0.636, + "recall": 0.025, + "f1-score": 0.047, + "marker": "^", + "runtime": None, + "runtime_std": None + }, + "distance_unet": { + "label": "CochleaNet", + "precision": 0.664, + "recall": 0.661, + "f1-score": 0.659, + "marker": "s", + "runtime": 65.7, + "runtime_std": 72.6 + }, + } + } + + # Convert setting labels to numerical x positions + offset = 0.08 # horizontal shift for scatter separation + + # Plot + fig, ax = plt.subplots(figsize=(8, 5)) - img = tifffile.imread(img_path) - img = img[10, xlim1:xlim2, ylim1:ylim2] + main_label_size = 20 + main_tick_size = 16 + + labels = [value_dict[segm][key]["label"] for key in value_dict[segm].keys()] + + if mode == "precision": + # Convert setting labels to numerical x positions + offset = 0.08 # horizontal shift for scatter separation + for num, key in enumerate(list(value_dict[segm].keys())): + precision = [value_dict[segm][key]["precision"]] + recall = [value_dict[segm][key]["recall"]] + f1score = [value_dict[segm][key]["f1-score"]] + marker = value_dict[segm][key]["marker"] + x_pos = num + 1 - # create color map with random distribution for coloring instance segmentation - unique = list(np.unique(seg)[1:]) - n_instances = len(unique) + plt.scatter([x_pos - offset], precision, label="Precision manual", color=COLOR_P, marker=marker, s=80) + plt.scatter([x_pos], recall, label="Recall manual", color=COLOR_R, marker=marker, s=80) + plt.scatter([x_pos + offset], f1score, label="F1-score manual", color=COLOR_F, marker=marker, s=80) - seg = scramble_instance_labels(seg) + # Labels and formatting + x_pos = np.arange(1, len(labels)+1) + plt.xticks(x_pos, labels, fontsize=16) + plt.yticks(fontsize=main_tick_size) + plt.ylabel("Value", fontsize=main_label_size) + plt.ylim(-0.1, 1) + # plt.legend(loc="lower right", fontsize=legendsize) + plt.grid(axis="y", linestyle="solid", alpha=0.5) - rng = np.random.default_rng(seed=42) # fixed seed for reproducibility - colors_array = rng.uniform(0, 1, size=(n_instances, 4)) # RGBA values in [0,1] - colors_array[:, 3] = 1.0 # full alpha - colors_array[0, 3] = 0.0 # make label 0 transparent (background) - cmap = colors.ListedColormap(colors_array) + elif mode == "runtime": + # Convert setting labels to numerical x positions + offset = 0.08 # horizontal shift for scatter separation + for num, key in enumerate(list(value_dict[segm].keys())): + runtime = [value_dict[segm][key]["runtime"]] + marker = value_dict[segm][key]["marker"] + x_pos = num + 1 - boundaries = find_boundaries(seg, mode="inner") - boundary_overlay = np.zeros((*boundaries.shape, 4)) + plt.scatter([x_pos], runtime, label="Runtime", color=COLOR_T, marker=marker, s=80) - boundary_overlay[boundaries] = boundary_rgba # RGBA = black + # Labels and formatting + x_pos = np.arange(1, len(labels)+1) + plt.xticks(x_pos, labels, fontsize=16) + plt.yticks(fontsize=main_tick_size) + plt.ylabel("Processing time [s]", fontsize=main_label_size) + plt.ylim(-0.1, 600) + # plt.legend(loc="lower right", fontsize=legendsize) + plt.grid(axis="y", linestyle="solid", alpha=0.5) - fig, ax = plt.subplots(figsize=(6, 6)) - ax.imshow(img, cmap="gray") - ax.imshow(seg, cmap=cmap, alpha=0.5, interpolation="nearest") - ax.imshow(boundary_overlay) - ax.axis("off") plt.tight_layout() + prism_cleanup_axes(ax) + if ".png" in save_path: plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) else: @@ -71,52 +198,40 @@ def plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, bou plt.close() -def fig_02a_sgn(save_dir, plot=False): - """Plot crops of SGN segmentation of CochleaNet, Cellpose and micro-sam. - """ - cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet" - val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn" - image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation" - - crop_name = "MLR169R_PV_z3420_allturns_full" - img_path = os.path.join(image_dir, f"{crop_name}.tif") - - xlim1 = 2000 - xlim2 = 2500 - ylim1 = 3100 - ylim2 = 3600 - boundary_rgba = [1, 1, 1, 0.5] - - for seg_net in ["distance_unet", "cellpose-sam", "micro-sam"]: - save_path = os.path.join(save_dir, f"fig_02b_sgn_{seg_net}.png") - seg_dir = os.path.join(val_sgn_dir, seg_net) - seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif") - - plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot) +def plot_legend_fig02c(save_path, plot_mode="shapes"): + """Plot common legend for figure 2c. - -def fig_02b_ihc(save_dir, plot=False): - """Plot crops of IHC segmentation of CochleaNet, Cellpose and micro-sam. + Args:. + save_path: save path to save legend. + plot_mode: Plot either 'shapes' or 'colors' of data points. """ - cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet" - val_sgn_dir = f"{cochlea_dir}/predictions/val_ihc" - image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs" - - crop_name = "MLR226L_VGlut3_z1200_3turns_full" - img_path = os.path.join(image_dir, f"{crop_name}.tif") + if plot_mode == "shapes": + # Shapes + color = ["black", "black"] + marker = ["o", "s"] + label = ["Manual", "Automatic"] + + f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0] + handles = [f(m, c) for (c, m) in zip(color, marker)] + legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False) + export_legend(legend, save_path) + legend.remove() + plt.close() - xlim1 = 1900 - xlim2 = 2400 - ylim1 = 2000 - ylim2 = 2500 - boundary_rgba = [1, 1, 1, 0.5] + elif plot_mode =="colors": + # Colors + color = [COLOR_P, COLOR_R, COLOR_F] + label = ["Precision", "Recall", "F1-score"] - for seg_net in ["distance_unet_v4b", "cellpose-sam", "micro-sam"]: - save_path = os.path.join(save_dir, f"fig_02b_ihc_{seg_net}.png") - seg_dir = os.path.join(val_sgn_dir, seg_net) - seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif") + fl = lambda c: Line2D([], [], lw=3, color=c) + handles = [fl(c) for c in color] + legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False) + export_legend(legend, save_path) + legend.remove() + plt.close() - plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot) + else: + raise ValueError("Choose either 'shapes' or 'colors' as plot_mode.") def fig_02c(save_path, plot=False, all_versions=False): @@ -128,88 +243,66 @@ def fig_02c(save_path, plot=False, all_versions=False): sgn_unet = [0.887, 0.88, 0.884] sgn_annotator = [0.95, 0.849, 0.9] - ihc_v4b = [0.91, 0.819, 0.862] ihc_v4c = [0.905, 0.831, 0.866] - ihc_v4c_filter = [0.919, 0.775, 0.841] ihc_annotator = [0.958, 0.956, 0.957] syn_unet = [0.931, 0.905, 0.918] - # This is the version with IHC v4b segmentation: - # 4th version of the network with optimized segmentation params - version_1 = [sgn_unet, sgn_annotator, ihc_v4b, ihc_annotator, syn_unet] - settings_1 = ["automatic", "manual", "automatic", "manual", "automatic"] + setting = ["SGN", "IHC", "Synapse"] # This is the version with IHC v4c segmentation: # 4th version of the network with optimized segmentation params and split of falsely merged IHCs - version_2 = [sgn_unet, sgn_annotator, ihc_v4c, ihc_annotator, syn_unet] - settings_2 = ["automatic", "manual", "automatic", "manual", "automatic"] - - # This is the version with IHC v4c + filter segmentation: - # 4th version of the network with optimized segmentation params and split of falsely merged IHCs - # + filtering out IHCs with zero mapped synapses. - version_3 = [sgn_unet, sgn_annotator, ihc_v4c_filter, ihc_annotator, syn_unet] - settings_3 = ["automatic", "manual", "automatic", "manual", "automatic"] - - if all_versions: - versions = [version_1, version_2, version_3] - settings = [settings_1, settings_2, settings_3] - save_suffix = ["_v4b", "_v4c", "_v4c_filter"] - save_paths = [save_path.split(".")[0] + i + "." + save_path.split(".")[1] for i in save_suffix] - else: - versions = [version_2] - settings = [settings_2] - save_suffix = ["_v4c"] - save_paths = [save_path.split(".")[0] + i + "." + save_path.split(".")[1] for i in save_suffix] + manual = [sgn_annotator, ihc_annotator] + automatic = [sgn_unet, ihc_v4c, syn_unet] - for version, setting, save_path in zip(versions, settings, save_paths): - precision = [i[0] for i in version] - recall = [i[1] for i in version] - f1score = [i[2] for i in version] + precision_manual = [i[0] for i in manual] + recall_manual = [i[1] for i in manual] + f1score_manual = [i[2] for i in manual] - descr_y = 0.72 + precision_automatic = [i[0] for i in automatic] + recall_automatic = [i[1] for i in automatic] + f1score_automatic = [i[2] for i in automatic] - # Convert setting labels to numerical x positions - x = np.array([0.8, 1.2, 1.8, 2.2, 3]) - offset = 0.08 # horizontal shift for scatter separation + # Convert setting labels to numerical x positions + x_manual = np.array([0.8, 1.8]) + x_automatic = np.array([1.2, 2.2, 3]) + offset = 0.08 # horizontal shift for scatter separation - # Plot - fig, ax = plt.subplots(figsize=(8, 5)) + # Plot + fig, ax = plt.subplots(figsize=(8, 4.5)) - main_label_size = 22 - sub_label_size = 16 - main_tick_size = 16 - legendsize = 18 + main_label_size = 20 + main_tick_size = 16 - plt.scatter(x - offset, precision, label="Precision", marker="o", s=80) - plt.scatter(x, recall, label="Recall", marker="^", s=80) - plt.scatter(x + offset, f1score, label="F1-score", marker="*", s=80) + plt.scatter(x_manual - offset, precision_manual, label="Precision manual", color=COLOR_P, marker="o", s=80) + plt.scatter(x_manual, recall_manual, label="Recall manual", color=COLOR_R, marker="o", s=80) + plt.scatter(x_manual + offset, f1score_manual, label="F1-score manual", color=COLOR_F, marker="o", s=80) - plt.text(1, descr_y, "SGN", fontsize=main_label_size, horizontalalignment="center") - plt.text(2, descr_y, "IHC", fontsize=main_label_size, horizontalalignment="center") - plt.text(3, descr_y, "Synapse", fontsize=main_label_size, horizontalalignment="center") + plt.scatter(x_automatic - offset, precision_automatic, label="Precision automatic", color=COLOR_P, marker="s", s=80) + plt.scatter(x_automatic, recall_automatic, label="Recall automatic", color=COLOR_R, marker="s", s=80) + plt.scatter(x_automatic + offset, f1score_automatic, label="F1-score automatic", color=COLOR_F, marker="s", s=80) - # Labels and formatting - plt.xticks(x, setting, fontsize=sub_label_size) - plt.yticks(fontsize=main_tick_size) - plt.ylabel("Value", fontsize=main_label_size) - plt.ylim(0.76, 1) - plt.legend(loc="lower right", - fontsize=legendsize) - plt.grid(axis="y", linestyle="--", alpha=0.5) + # Labels and formatting + plt.xticks([1, 2, 3], setting, fontsize=main_label_size) + plt.yticks(fontsize=main_tick_size) + ax.yaxis.set_major_formatter(mticker.FuncFormatter(custom_formatter_2)) + plt.ylabel("Value", fontsize=main_label_size) + plt.ylim(0.76, 1) + # plt.legend(loc="lower right", fontsize=legendsize) + plt.grid(axis="y", linestyle="solid", alpha=0.5) - plt.tight_layout() - prism_cleanup_axes(ax) + plt.tight_layout() + prism_cleanup_axes(ax) - if ".png" in save_path: - plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) - else: - plt.savefig(save_path, bbox_inches='tight', pad_inches=0) + if ".png" in save_path: + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) + else: + plt.savefig(save_path, bbox_inches='tight', pad_inches=0) - if plot: - plt.show() - else: - plt.close() + if plot: + plt.show() + else: + plt.close() # Load the synapse counts for all IHCs from the relevant tables. @@ -224,183 +317,102 @@ def _load_ribbon_synapse_counts(): return syn_counts -def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_synapses=False): +def fig_02d(save_path, plot=False, plot_average_ribbon_synapses=False): """Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values. """ - main_tick_size = 20 - main_label_size = 26 + prism_style() + main_tick_size = 16 + main_label_size = 20 rows = 1 columns = 3 if plot_average_ribbon_synapses else 2 sgn_values = [11153, 11398, 10333, 11820] - ihc_v4b_values = [836, 808, 796, 901] - ihc_v4c_values = [712, 710, 721, 675] - ihc_v4c_filtered_values = [562, 647, 626, 628] - - if all_versions: - ihc_list = [ihc_v4b_values, ihc_v4c_values, ihc_v4c_filtered_values] - suffixes = ["_v4b", "_v4c", "_v4c_filtered"] - assert not plot_average_ribbon_synapses - else: - ihc_list = [ihc_v4c_values] - suffixes = ["_v4c"] + ihc_values = [712, 710, 721, 675] + + fig, axes = plt.subplots(rows, columns, figsize=(10, 4.5)) + ax = axes.flatten() + box_plot = ax[0].boxplot(sgn_values, patch_artist=True, zorder=1) + for median in box_plot['medians']: + median.set_color(COLOR_MEASUREMENT) + for boxcolor in box_plot['boxes']: + boxcolor.set_facecolor("white") + + box_plot = ax[1].boxplot(ihc_values, patch_artist=True, zorder=1) + for median in box_plot['medians']: + median.set_color(COLOR_MEASUREMENT) + for boxcolor in box_plot['boxes']: + boxcolor.set_facecolor("white") - for (ihc_values, suffix) in zip(ihc_list, suffixes): - fig, axes = plt.subplots(rows, columns, figsize=(columns*4, rows*4)) - ax = axes.flatten() - - save_path_new = save_path.split(".")[0] + suffix + "." + save_path.split(".")[1] - ax[0].boxplot(sgn_values) - ax[1].boxplot(ihc_values) - - # Labels and formatting - ax[0].set_xticklabels(["SGN"], fontsize=main_label_size) - - ylim0 = 8500 - ylim1 = 12500 - y_ticks = [i for i in range(9000, 12000 + 1, 1000)] - - ax[0].set_ylabel("Count per cochlea", fontsize=main_label_size) - ax[0].set_yticks(y_ticks) - ax[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - ax[0].set_ylim(ylim0, ylim1) - ax[0].yaxis.set_ticks_position("left") + # Labels and formatting + ax[0].set_xticklabels(["SGN"], fontsize=main_label_size) + + ylim0 = 8500 + ylim1 = 12500 + y_ticks = [i for i in range(9000, 12000 + 1, 1000)] + + ax[0].set_ylabel("Count per cochlea", fontsize=main_label_size) + ax[0].set_yticks(y_ticks) + ax[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[0].set_ylim(ylim0, ylim1) + ax[0].yaxis.set_ticks_position("left") + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + ax[0].set_xlim(xmin, xmax) + lower_y, upper_y = literature_reference_values("SGN") + ax[0].hlines([lower_y, upper_y], xmin, xmax, color=COLOR_LITERATURE) + ax[0].text(1., lower_y + (upper_y - lower_y) * 0.2, "literature", + color=COLOR_LITERATURE, fontsize=main_label_size, ha="center") + ax[0].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) + + ylim0 = 600 + ylim1 = 800 + y_ticks = [i for i in range(600, 800 + 1, 100)] + + ax[1].set_xticklabels(["IHC"], fontsize=main_label_size) + ax[1].set_yticks(y_ticks) + ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[1].set_ylim(ylim0, ylim1) + if not plot_average_ribbon_synapses: + ax[1].yaxis.tick_right() + ax[1].yaxis.set_ticks_position("right") + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + lower_y, upper_y = literature_reference_values("IHC") + ax[1].set_xlim(xmin, xmax) + ax[1].hlines([lower_y, upper_y], xmin, xmax, color=COLOR_LITERATURE) + ax[1].fill_between([xmin, xmax], lower_y, upper_y, color=COLOR_LITERATURE, alpha=0.05, interpolate=True) + + if plot_average_ribbon_synapses: + ribbon_synapse_counts = _load_ribbon_synapse_counts() + ylim0 = -1 + ylim1 = 41 + y_ticks = [0, 10, 20, 30, 40, 50] + + box_plot = ax[2].boxplot(ribbon_synapse_counts, patch_artist=True, zorder=1) + for median in box_plot['medians']: + median.set_color(COLOR_MEASUREMENT) + for boxcolor in box_plot['boxes']: + boxcolor.set_facecolor("white") + + ax[2].set_xticklabels(["Synapses per IHC"], fontsize=main_label_size) + ax[2].set_yticks(y_ticks) + ax[2].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[2].set_ylim(ylim0, ylim1) # set range of literature values xmin = 0.5 xmax = 1.5 - ax[0].set_xlim(xmin, xmax) - lower_y, upper_y = literature_reference_values("SGN") - ax[0].hlines([lower_y, upper_y], xmin, xmax) - ax[0].text(1., lower_y + (upper_y - lower_y) * 0.2, "literature", - color="C0", fontsize=main_tick_size, ha="center") - ax[0].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) - - ylim0 = 600 - ylim1 = 800 - y_ticks = [i for i in range(600, 800 + 1, 100)] - - ax[1].set_xticklabels(["IHC"], fontsize=main_label_size) - ax[1].set_yticks(y_ticks) - ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - ax[1].set_ylim(ylim0, ylim1) - if not plot_average_ribbon_synapses: - ax[1].yaxis.tick_right() - ax[1].yaxis.set_ticks_position("right") + lower_y, upper_y = literature_reference_values("synapse") + ax[2].set_xlim(xmin, xmax) + ax[2].hlines([lower_y, upper_y], xmin, xmax, color=COLOR_LITERATURE) + ax[2].fill_between([xmin, xmax], lower_y, upper_y, color=COLOR_LITERATURE, alpha=0.05, interpolate=True) - # set range of literature values - xmin = 0.5 - xmax = 1.5 - lower_y, upper_y = literature_reference_values("IHC") - ax[1].set_xlim(xmin, xmax) - ax[1].hlines([lower_y, upper_y], xmin, xmax) - # ax[1].text(1.1, (lower_y + upper_y) // 2, "literature", color="C0", fontsize=main_tick_size, ha="left") - ax[1].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) - - if plot_average_ribbon_synapses: - ribbon_synapse_counts = _load_ribbon_synapse_counts() - ylim0 = -1 - ylim1 = 41 - y_ticks = [0, 10, 20, 30, 40, 50] - - ax[2].boxplot(ribbon_synapse_counts) - ax[2].set_xticklabels(["Ribbon Syn. per IHC"], fontsize=main_label_size) - ax[2].set_yticks(y_ticks) - ax[2].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - ax[2].set_ylim(ylim0, ylim1) - - # set range of literature values - xmin = 0.5 - xmax = 1.5 - lower_y, upper_y = literature_reference_values("synapse") - ax[2].set_xlim(xmin, xmax) - ax[2].hlines([lower_y, upper_y], xmin, xmax) - # ax[2].text(1.1, (lower_y + upper_y) // 2, "literature", color="C0", fontsize=main_tick_size, ha="left") - ax[2].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) - - plt.tight_layout() - - if ".png" in save_path: - plt.savefig(save_path_new, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) - else: - plt.savefig(save_path_new, bbox_inches='tight', pad_inches=0) - - if plot: - plt.show() - else: - plt.close() - - -def fig_02d_02(save_path, filter_zeros=True, plot=False): - """Bar plot showing the distribution of synapse markers per IHC segmentation average over multiple clochleae. - """ - cochleae = ["M_LR_000226_L", "M_LR_000226_R", "M_LR_000227_L", "M_LR_000227_R"] - ihc_version = "ihc_counts_v4b" - synapse_dir = os.path.join(SYNAPSE_DIR_ROOT, ihc_version) - - max_dist = 3 - bins = 10 - cap = 30 - plot_density = False - - results = [] - for cochlea in cochleae: - synapse_file = os.path.join(synapse_dir, f"ihc_count_{cochlea}.tsv") - ihc_table = pd.read_csv(synapse_file, sep="\t") - syn_per_ihc = list(ihc_table["synapse_count"]) - if filter_zeros: - syn_per_ihc = [s for s in syn_per_ihc if s != 0] - results.append(syn_per_ihc) - - results = [np.clip(r, 0, cap) for r in results] - - # Define bins (shared for all datasets) - bins = np.linspace(0, 30, 11) # 29 bins - - # Compute histogram (relative) for each dataset - histograms = [] - for data in results: - counts, _ = np.histogram(data, bins=bins, density=plot_density) - histograms.append(counts) - - histograms = np.array(histograms) - - # Compute mean and std for each bin across datasets - mean_counts = histograms.mean(axis=0) - std_counts = histograms.std(axis=0) - - # Get bin centers for plotting - bin_centers = 0.5 * (bins[1:] + bins[:-1]) - - # Plot - plt.figure(figsize=(8, 5)) - plt.bar(bin_centers, mean_counts, width=(bins[1] - bins[0]), yerr=std_counts, alpha=0.7, capsize=4, - label="Mean ± Std Dev", edgecolor="black") - - main_label_size = 20 - main_tick_size = 16 - legendsize = 16 - - # Labels and formatting - x_ticks = [i for i in range(0, cap + 1, 5)] - if plot_density: - y_ticks = [i * 0.02 for i in range(0, 10, 2)] - else: - y_ticks = [i for i in range(0, 300, 50)] - - plt.xticks(x_ticks, fontsize=main_tick_size) - plt.yticks(y_ticks, fontsize=main_tick_size) - if plot_density: - plt.ylabel("Proportion of IHCs [%]", fontsize=main_label_size) - else: - plt.ylabel("Number of IHCs", fontsize=main_label_size) - plt.xlabel(f"Ribbon Synapses per IHC @ {max_dist} µm", fontsize=main_label_size) - - plt.title("Average Synapses per IHC for a Dataset of 4 Cochleae") - - plt.grid(axis="y", linestyle="--", alpha=0.5) - plt.legend(fontsize=legendsize) + prism_cleanup_axes(axes) plt.tight_layout() if ".png" in save_path: @@ -422,22 +434,21 @@ def main(): os.makedirs(args.figure_dir, exist_ok=True) - # Panes A and B: Qualitative comparison of visualization results. - fig_02a_sgn(save_dir=args.figure_dir, plot=args.plot) - fig_02b_ihc(save_dir=args.figure_dir, plot=args.plot) - # Panel C: Evaluation of the segmentation results: fig_02c(save_path=os.path.join(args.figure_dir, f"fig_02c.{FILE_EXTENSION}"), plot=args.plot, all_versions=False) + plot_legend_fig02c(os.path.join(args.figure_dir, f"fig_02c_legend_shapes.{FILE_EXTENSION}"), plot_mode="shapes") + plot_legend_fig02c(os.path.join(args.figure_dir, f"fig_02c_legend_colors.{FILE_EXTENSION}"), plot_mode="colors") # Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC - fig_02d_01(save_path=os.path.join(args.figure_dir, f"fig_02d.{FILE_EXTENSION}"), + fig_02d(save_path=os.path.join(args.figure_dir, f"fig_02d.{FILE_EXTENSION}"), plot=args.plot, plot_average_ribbon_synapses=True) - # Alternative version of synapse distribution for panel D. - # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02"), plot=args.plot) - # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4c"), filter_zeros=False, plot=plot) - # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4c_filtered"), filter_zeros=True, plot=plot) - # fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4b"), filter_zeros=True, plot=args.plot) + # Supplementary Figure 2: Comparing other methods in terms of segmentation accuracy and runtime + plot_legend_suppfig02(save_path=os.path.join(args.figure_dir, f"suppfig02_legend_colors.{FILE_EXTENSION}")) + supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn.{FILE_EXTENSION}"), segm="SGN") + supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc.{FILE_EXTENSION}"), segm="IHC") + supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_sgn_time.{FILE_EXTENSION}"), segm="SGN", mode="runtime") + supp_fig_02(save_path=os.path.join(args.figure_dir, f"figsupp_02_ihc_time.{FILE_EXTENSION}"), segm="IHC", mode="runtime") if __name__ == "__main__": diff --git a/scripts/figures/plot_fig3.py b/scripts/figures/plot_fig3.py index b3a5389..945d276 100644 --- a/scripts/figures/plot_fig3.py +++ b/scripts/figures/plot_fig3.py @@ -13,11 +13,10 @@ from matplotlib import cm, colors from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target -from util import sliding_runlength_sum, frequency_mapping, prism_style, prism_cleanup_axes, SYNAPSE_DIR_ROOT +from util import sliding_runlength_sum, frequency_mapping, SYNAPSE_DIR_ROOT +from util import prism_style, prism_cleanup_axes, export_legend -# INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3" INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/frequency_mapping/M_LR_000227_R/scale3" -FILE_EXTENSION = "png" TYPE_TO_CHANNEL = { "Type-Ia": "CR", @@ -32,10 +31,10 @@ # The cochlea for the CHReef analysis. COCHLEAE_DICT = { - "M_LR_000226_L": {"alias": "M01L", "component": [1]}, - "M_LR_000226_R": {"alias": "M01R", "component": [1]}, - "M_LR_000227_L": {"alias": "M02L", "component": [1]}, - "M_LR_000227_R": {"alias": "M02R", "component": [1]}, + "M_LR_000226_L": {"alias": "M01L", "component": [1], "color": "#9C5027"}, + "M_LR_000226_R": {"alias": "M01R", "component": [1], "color": "#279C52"}, + "M_LR_000227_L": {"alias": "M02L", "component": [1], "color": "#67279C"}, + "M_LR_000227_R": {"alias": "M02R", "component": [1], "color": "#27339C"}, } @@ -93,7 +92,7 @@ def get_tonotopic_data(): return pickle.load(f) -def _plot_colormap(vol, title, plot, save_path): +def _plot_colormap(vol, title, plot, save_path, cmap="viridis"): # before creating the figure: matplotlib.rcParams.update({ "font.size": 14, # base font size @@ -110,10 +109,16 @@ def _plot_colormap(vol, title, plot, save_path): freq_min = np.min(np.nonzero(vol)) freq_max = vol.max() - norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True) - cmap = plt.get_cmap("viridis") + # norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True) + norm = colors.LogNorm(vmin=freq_min, vmax=freq_max, clip=True) + tick_values = np.array([10, 20, 40, 80]) + + cmap = plt.get_cmap(cmap) - cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal") + cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal", + ticks=tick_values) + cb.ax.xaxis.set_major_formatter(matplotlib.ticker.ScalarFormatter()) + cb.ax.xaxis.set_minor_locator(matplotlib.ticker.NullLocator()) cb.set_label("Frequency [kHz]") plt.title(title) plt.tight_layout() @@ -127,19 +132,29 @@ def _plot_colormap(vol, title, plot, save_path): plt.close() -def fig_03a(save_path, plot, plot_napari): +def fig_03a(save_path, plot, plot_napari, cmap="viridis"): path_ihc = os.path.join(INPUT_ROOT, "frequencies_IHC_v4c.tif") path_sgn = os.path.join(INPUT_ROOT, "frequencies_SGN_v2.tif") sgn = imageio.imread(path_sgn) ihc = imageio.imread(path_ihc) - _plot_colormap(sgn, title="Tonotopic Mapping", plot=plot, save_path=save_path) + _plot_colormap(sgn, title="Tonotopic Mapping", plot=plot, save_path=save_path, cmap=cmap) # Show the image in napari for rendering. if plot_napari: import napari + from napari.utils import Colormap + # cmap = plt.get_cmap(cmap) + mpl_cmap = plt.get_cmap(cmap) + + # Sample it into an array of RGBA values + colors = mpl_cmap(np.linspace(0, 1, 256)) + + # Wrap into napari Colormap + napari_cmap = Colormap(colors, name=f"{cmap}_custom") + v = napari.Viewer() - v.add_image(ihc, colormap="viridis") - v.add_image(sgn, colormap="viridis") + v.add_image(ihc, colormap=napari_cmap) + v.add_image(sgn, colormap=napari_cmap) napari.run() @@ -152,7 +167,13 @@ def fig_03c_rl(save_path, plot=False): width = 50 # micron - for tab_path in tables: + colors = ["#664970", # M01L + "#704954", # M01R + "#537049", # M02L + "#49705B", # M02R + ] + + for num, tab_path in enumerate(tables): # TODO map to alias alias = os.path.basename(tab_path)[10:-4].replace("_", "").replace("0", "") tab = pd.read_csv(tab_path, sep="\t") @@ -161,7 +182,7 @@ def fig_03c_rl(save_path, plot=False): # Compute the running sum of 10 micron. run_length, syn_count_running = sliding_runlength_sum(run_length, syn_count, width=width) - ax.plot(run_length, syn_count_running, label=alias) + ax.plot(run_length, syn_count_running, label=alias, color=colors[num]) ax.set_xlabel("Length [µm]") ax.set_ylabel("Synapse Count") @@ -180,19 +201,84 @@ def fig_03c_rl(save_path, plot=False): plt.close() -def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendlines=False): - prism_style() +def plot_legend_fig03c(save_path): + color_dict = {} + for key in COCHLEAE_DICT.keys(): + color_dict[COCHLEAE_DICT[key]["alias"]] = COCHLEAE_DICT[key]["color"] + + marker = ["o" for _ in color_dict] + label = list(color_dict.keys()) + color = [color_dict[key] for key in color_dict.keys()] + + f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0] + handles = [f(m, c) for (c, m) in zip(color, marker)] + legend = plt.legend(handles, label, loc=3, ncol=2, framealpha=1, frameon=False) + export_legend(legend, save_path) + legend.remove() + plt.close() + + +def _get_trendline_dict(trend_dict,): + x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys()] + x_dict = {} + for num in range(len(x_sorted[0])): + x_dict[num] = {"pos": num, "values": []} + + for s in x_sorted: + for num, pos in enumerate(s): + x_dict[num]["values"].append(pos) + + y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys()] + y_dict = {} + for num in range(len(x_sorted[0])): + y_dict[num] = {"pos": num, "values": []} + + for num in range(len(x_sorted[0])): + y_dict[num]["mean"] = np.mean([y[num] for y in y_sorted_all]) + y_dict[num]["stdv"] = np.std([y[num] for y in y_sorted_all]) + return x_dict, y_dict + + +def _get_trendline_params(trend_dict): + x_dict, y_dict = _get_trendline_dict(trend_dict) + + x_values = [] + for key in x_dict.keys(): + x_values.append(min(x_dict[key]["values"])) + x_values.append(max(x_dict[key]["values"])) + + y_values_center = [] + y_values_upper = [] + y_values_lower = [] + for key in y_dict.keys(): + y_values_center.append(y_dict[key]["mean"]) + y_values_center.append(y_dict[key]["mean"]) + + y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"]) + y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"]) + + y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"]) + y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"]) + + return x_values, y_values_center, y_values_upper, y_values_lower + + +def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendline=False): ihc_version = "ihc_counts_v4c" + prism_style() tables = glob(os.path.join(SYNAPSE_DIR_ROOT, ihc_version, "ihc_count_M_LR*.tsv")) assert len(tables) == 4, len(tables) + label_size = 20 result = {"cochlea": [], "octave_band": [], "value": []} + color_dict = {} for name, values in tonotopic_data.items(): if use_alias: alias = COCHLEAE_DICT[name]["alias"] else: alias = name.replace("_", "").replace("0", "") + color_dict[alias] = COCHLEAE_DICT[name]["color"] freq = values["frequency[kHz]"].values syn_count = values["syn_per_IHC"].values octave_binned = frequency_mapping(freq, syn_count, animal="mouse") @@ -207,12 +293,18 @@ def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendl result["x_pos"] = result["octave_band"].map(band_to_x) fig, ax = plt.subplots(figsize=(8, 4)) + + offset = 0.08 + y_values = [] trend_dict = {} - for name, grp in result.groupby("cochlea"): - ax.scatter(grp["x_pos"], grp["value"], label=name, s=60, alpha=0.8) + for num, (name, grp) in enumerate(result.groupby("cochlea")): + x_sorted = grp["x_pos"] + x_positions = [x - len(grp["x_pos"]) // 2 * offset + offset * num for x in grp["x_pos"]] + ax.scatter(x_positions, grp["value"], marker="o", label=name, s=80, alpha=1, color=color_dict[name]) + + # y_values.append(list(grp["value"])) - if trendlines: - x_positions = grp["x_pos"] + if trendline: sorted_idx = np.argsort(x_positions) x_sorted = np.array(x_positions)[sorted_idx] y_sorted = np.array(grp["value"])[sorted_idx] @@ -220,41 +312,47 @@ def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendl "y_sorted": y_sorted, } - if trendlines: - def get_trendline_values(trend_dict): - x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys()][0] - y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys()] - y_sorted = [] - for num in range(len(x_sorted)): - y_sorted.append(np.mean([y[num] for y in y_sorted_all])) - return x_sorted, y_sorted - - # Trendline left - x_sorted, y_sorted = get_trendline_values(trend_dict) + ax.set_xticks(range(len(bin_labels))) + ax.set_xticklabels(bin_labels) + ax.set_xlabel("Octave band [kHz]", fontsize=label_size) - trend, = ax.plot( + # central line + if trendline: + #mean, std = _get_trendline_params(y_values) + x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict) + trend_center, = ax.plot( x_sorted, y_sorted, linestyle="dotted", - color="grey", - alpha=0.7 + color="gray", + alpha=0.6, + linewidth=3, + zorder=2 ) + # y_sorted_upper = np.array(mean) + np.array(std) + # y_sorted_lower = np.array(mean) - np.array(std) + # upper and lower standard deviation + trend_upper, = ax.plot( + x_sorted, + y_sorted_upper, + linestyle="solid", + color="gray", + alpha=0.08, + zorder=0 + ) + trend_lower, = ax.plot( + x_sorted, + y_sorted_lower, + linestyle="solid", + color="gray", + alpha=0.08, + zorder=0 + ) + plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, + color="gray", alpha=0.05, interpolate=True) - # trendline_legend = ax.legend(handles=[trend], loc='lower center') - # trendline_legend = ax.legend( - # handles=[trend], - # labels=["Trendline"], - # loc="upper left" - # ) - # # Add the legend manually to the Axes. - # ax.add_artist(trendline_legend) - - ax.set_xticks(range(len(bin_labels))) - ax.set_xticklabels(bin_labels) - ax.set_xlabel("Octave band (kHz)") - - ax.set_ylabel("Average Ribbon Synapse Count per IHC", fontsize=10) - plt.legend(title="Cochlea") + ax.set_ylabel("Ribbon synapses per IHC") + plt.tight_layout() prism_cleanup_axes(ax) if ".png" in save_path: @@ -308,9 +406,6 @@ def fig_03d_fraction(save_path, plot): results["fraction"].append(subtype_fraction) results["cochlea"].append(cochlea) - # coexpr = np.logical_and(subtype_table.iloc[:, 0].values, subtype_table.iloc[:, 1].values) - # print("Co-expression:", coexpr.sum()) - results = pd.DataFrame(results) fig, ax = plt.subplots() for cochlea, group in results.groupby("cochlea"): @@ -338,22 +433,23 @@ def fig_03d_octave(save_path, plot): def main(): parser = argparse.ArgumentParser(description="Generate plots for Fig 3 of the cochlea paper.") parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig3") + parser.add_argument("--napari", action="store_true", help="Visualize tonotopic mapping in napari.") parser.add_argument("--plot", action="store_true") args = parser.parse_args() os.makedirs(args.figure_dir, exist_ok=True) tonotopic_data = get_tonotopic_data() - # Panel A: Tonotopic mapping of SGNs and IHCs (rendering in napari + heatmap) - # fig_03a(save_path=os.path.join(args.figure_dir, f"fig_03a_cmap.{FILE_EXTENSION}"), - # plot=args.plot, plot_napari=True) + # Panel C: Tonotopic mapping of SGNs and IHCs (rendering in napari + heatmap) + cmap = "plasma" + fig_03a(save_path=os.path.join(args.figure_dir, f"fig_03a_cmap_{cmap}.{FILE_EXTENSION}"), + plot=args.plot, plot_napari=args.napari, cmap=cmap) - # Panel C: Spatial distribution of synapses across the cochlea. - # We have two options: running sum over the runlength or per octave band - # fig_03c_rl(save_path=os.path.join(args.figure_dir, f"fig_03c_runlength.{FILE_EXTENSION}"), plot=args.plot) + # Panel C: Spatial distribution of synapses across the cochlea (running sum per octave band) fig_03c_octave(tonotopic_data=tonotopic_data, save_path=os.path.join(args.figure_dir, f"fig_03c_octave.{FILE_EXTENSION}"), - plot=args.plot, trendlines=True) + plot=args.plot, trendline=True) + plot_legend_fig03c(save_path=os.path.join(args.figure_dir, f"fig_03c_legend.{FILE_EXTENSION}")) # Panel D: Spatial distribution of SGN sub-types. # fig_03d_fraction(save_path=os.path.join(args.figure_dir, f"fig_03d_fraction.{FILE_EXTENSION}"), plot=args.plot) diff --git a/scripts/figures/plot_fig4.py b/scripts/figures/plot_fig4.py index 91a2f7b..8692d73 100644 --- a/scripts/figures/plot_fig4.py +++ b/scripts/figures/plot_fig4.py @@ -4,11 +4,12 @@ import pickle import matplotlib.pyplot as plt +import matplotlib.ticker as mticker import numpy as np import pandas as pd from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target -from util import frequency_mapping, prism_style, prism_cleanup_axes # , literature_reference_values +from util import frequency_mapping, prism_style, prism_cleanup_axes, export_legend, custom_formatter_1 # from statsmodels.nonparametric.smoothers_lowess import lowess @@ -32,9 +33,39 @@ "G_EK_000049_R": {"alias": "G1R", "component": [1, 2]}, } +COLORS_ANIMAL = { + "M05": "#9C5027", + "M06": "#279C52", + "M07": "#67279C", + "M08": "#27339C", + "M09": "#9C276F" +} + +COLORS_LEFT = { + "M05R": "#A600FF", + "M06R": "#8F00DB", + "M07R": "#7D1DB1", + "M08R": "#672D86", + "M09R": "#4C2E5C" +} + +COLORS_RIGHT = { + "M05L": "#FF0063", + "M06L": "#DB0063", + "M07L": "#B11D60", + "M08L": "#862D55", + "M09L": "#5C2E43" +} + FILE_EXTENSION = "png" png_dpi = 300 +COLOR_LEFT = "#8E00DB" +COLOR_RIGHT = "#DB0063" +COLOR_UNTREATED = "#DB7B00" +MARKER_LEFT = "o" +MARKER_RIGHT = "^" + def get_chreef_data(animal="mouse"): s3 = create_s3_target() @@ -110,13 +141,116 @@ def group_lr(names_lr, values): return names, values_left, values_right -def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=True): +def plot_legend(chreef_data, save_path, grouping="side_mono", use_alias=True, + alignment="horizontal"): + """Plot common legend for figures 4c, 4d, and 4e. + + Args: + chreef_data: Data of ChReef cochleae. + save_path: save path to save legend. + grouping: Grouping for cochleae. + "side_mono" for division in Injected and Non-Injected. + "side_multi" for division per cochlea. + "animal" for division per animal. + use_alias: Use alias. + """ + if use_alias: + alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()] + else: + alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()] + + sgns = [len(vals) for vals in chreef_data.values()] + alias, values_left, values_right = group_lr(alias, sgns) + + colors = ["crimson", "purple", "gold"] + if grouping == "side_mono": + colors = [COLOR_LEFT, COLOR_RIGHT] + labels = ["Injected", "Non-Injected"] + markers = [MARKER_LEFT, MARKER_RIGHT] + ncol = 2 + + elif grouping == "side_multi": + colors = [] + labels = [] + markers = [] + keys_left = list(COLORS_LEFT.keys()) + keys_right = list(COLORS_RIGHT.keys()) + for num in range(len(COLORS_LEFT)): + colors.append(COLORS_LEFT[keys_left[num]]) + colors.append(COLORS_RIGHT[keys_right[num]]) + labels.append(f"{alias[num]}L") + labels.append(f"{alias[num]}R") + markers.append(MARKER_LEFT) + markers.append(MARKER_RIGHT) + if alignment == "vertical": + colors = colors[::2] + colors[1::2] + labels = labels[::2] + labels[1::2] + markers = markers[::2] + markers[1::2] + ncol = 2 + else: + ncol = 5 + + elif grouping == "animal": + colors = [] + labels = [] + markers = [] + ncol = 5 + keys_animal = list(COLORS_ANIMAL.keys()) + for num in range(len(COLORS_ANIMAL)): + colors.append(COLORS_ANIMAL[keys_animal[num]]) + colors.append(COLORS_ANIMAL[keys_animal[num]]) + labels.append(f"{alias[num]}L") + labels.append(f"{alias[num]}R") + markers.append(MARKER_LEFT) + markers.append(MARKER_RIGHT) + if alignment == "vertical": + colors = colors[::2] + colors[1::2] + labels = labels[::2] + labels[1::2] + markers = markers[::2] + markers[1::2] + ncol = 2 + else: + ncol = 5 + + else: + raise ValueError("Choose a correct 'grouping' parameter.") + + f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0] + handles = [f(marker, color) for (color, marker) in zip(colors, markers)] + legend = plt.legend(handles, labels, loc=3, ncol=ncol, framealpha=1, frameon=False) + + export_legend(legend, save_path) + legend.remove() + plt.close() + + +def plot_legend_fig05e_gerbil(save_path): + """Plot common legend for figure 5e gerbil. + + Args: + chreef_data: Data of ChReef cochleae. + save_path: save path to save legend. + grouping: Grouping for cochleae. + "side_mono" for division in Injected and Non-Injected. + "side_multi" for division per cochlea. + "animal" for division per animal. + use_alias: Use alias. + """ + # Shapes + color = [COLOR_LEFT, COLOR_RIGHT] + marker = ["o", "^"] + label = ["G1L", "G1R"] + + f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0] + handles = [f(m, c) for (c, m) in zip(color, marker)] + legend = plt.legend(handles, label, loc=3, ncol=len(label), framealpha=1, frameon=False) + export_legend(legend, save_path) + legend.remove() + plt.close() + + +def fig_04c(chreef_data, save_path, plot=False, grouping="side_mono", use_alias=True): """Box plot showing the SGN counts of ChReef treated cochleae compared to healthy ones. """ - # Previous version with hard-coded values. - # cochlea = ["M_LR_000144_L", "M_LR_000145_L", "M_LR_000151_R"] - # alias = ["c01", "c02", "c03"] - # sgns = [7796, 6119, 9225] prism_style() # TODO have central function for alias for all plots? @@ -127,42 +261,71 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr sgns = [len(vals) for vals in chreef_data.values()] - if plot_by_side: - alias, sgns_left, sgns_right = group_lr(alias, sgns) - - x = np.arange(len(alias)) + alias, values_left, values_right = group_lr(alias, sgns) # Plot - fig, ax = plt.subplots(figsize=(5, 5)) + fig, ax = plt.subplots(figsize=(4, 5)) main_label_size = 20 sub_label_size = 16 main_tick_size = 16 - legendsize = 12 - if plot_by_side: - plt.scatter(x, sgns_left, label="Injected", marker="o", s=80) - plt.scatter(x, sgns_right, label="Non-Injected", marker="x", s=80) + offset = 0.08 + x_left = 1 + x_right = 2 + y_ticks = [i for i in range(6000, 13000, 2000)] + + x_pos_inj = [x_left - len(values_left) // 2 * offset + offset * i for i in range(len(values_left))] + x_pos_non = [x_right - len(values_right) // 2 * offset + offset * i for i in range(len(values_right))] + + # lines between cochleae of same animal + for num, (left, right) in enumerate(zip(values_left, values_right)): + ax.plot( + [x_pos_inj[num], x_pos_non[num]], + [left, right], + linestyle="solid", + color="grey", + alpha=0.4, + zorder=0 + ) + + if grouping == "side_mono": + plt.scatter(x_pos_inj, values_left, label="Injected", + color=COLOR_LEFT, marker=MARKER_LEFT, s=80, zorder=1) + plt.scatter(x_pos_non, values_right, label="Non-Injected", + color=COLOR_RIGHT, marker=MARKER_RIGHT, s=80, zorder=1) + + elif grouping == "side_multi": + for num, key in enumerate(COLORS_LEFT.keys()): + plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}L", + color=COLORS_LEFT[key], marker=MARKER_LEFT, s=80, zorder=1) + for num, key in enumerate(COLORS_RIGHT.keys()): + plt.scatter(x_pos_non[num], values_right[num], label=f"{alias[num]}R", + color=COLORS_RIGHT[key], marker=MARKER_RIGHT, s=80, zorder=1) + + elif grouping == "animal": + for num, key in enumerate(COLORS_ANIMAL.keys()): + plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}", + color=COLORS_ANIMAL[key], marker=MARKER_LEFT, s=80, zorder=1) + plt.scatter(x_pos_non[num], values_right[num], + color=COLORS_ANIMAL[key], marker=MARKER_RIGHT, s=80, zorder=1) + else: - plt.scatter(x, sgns, label="SGN count", marker="o", s=80) + raise ValueError("Choose a correct 'grouping' parameter.") # Labels and formatting - plt.xticks(x, alias, fontsize=sub_label_size) - plt.yticks(fontsize=main_tick_size) + plt.xticks([x_left, x_right], ["Injected", "Non-\nInjected"], fontsize=sub_label_size) + for label in plt.gca().get_xticklabels(): + label.set_verticalalignment('center') + ax.tick_params(axis='x', which='major', pad=16) + plt.yticks(y_ticks, fontsize=main_tick_size) plt.ylabel("SGN count per cochlea", fontsize=main_label_size) - plt.ylim(4000, 15800) - plt.legend(loc="upper right", fontsize=legendsize) + plt.ylim(5000, 14000) - xmin = -0.5 - xmax = len(alias) - 0.5 + xmin = 0.5 + xmax = 2.5 plt.xlim(xmin, xmax) - # set range of literature values - # lower_y, upper_y = literature_reference_values("SGN") - # plt.hlines([lower_y, upper_y], xmin, xmax) - # plt.text(1.5, lower_y - 400, "literature", color="C0", fontsize=main_tick_size, ha="center") - # plt.fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) - sgn_values = [11153, 11398, 10333, 11820] sgn_value = np.mean(sgn_values) sgn_std = np.std(sgn_values) @@ -170,10 +333,12 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr upper_y = sgn_value + 1.96 * sgn_std lower_y = sgn_value - 1.96 * sgn_std - plt.hlines([lower_y, upper_y], xmin, xmax, colors=["C1" for _ in range(2)]) - plt.text(2, upper_y + 200, "untreated cochleae\n(95% confidence interval)", - color="C1", fontsize=14, ha="center") - plt.fill_between([xmin, xmax], lower_y, upper_y, color="C1", alpha=0.05, interpolate=True) + c_untreated = COLOR_UNTREATED + + plt.hlines([lower_y, upper_y], xmin, xmax, colors=[c_untreated for _ in range(2)], zorder=-1) + plt.text((xmin + xmax) / 2, upper_y + 200, "untreated cochleae\n(95% confidence interval)", + color=c_untreated, fontsize=11, ha="center") + plt.fill_between([xmin, xmax], lower_y, upper_y, color=c_untreated, alpha=0.05, interpolate=True) plt.tight_layout() @@ -190,7 +355,8 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr plt.close() -def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=False, gerbil=False, use_alias=True): +def fig_04d(chreef_data, save_path, plot=False, grouping="animal", + intensity=False, gerbil=False, use_alias=True): """Transduction efficiency per cochlea. """ prism_style() @@ -215,37 +381,79 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa eff = float(n_pos) / (n_pos + n_neg) values.append(eff) - if plot_by_side: - alias, values_left, values_right = group_lr(alias, values) - - x = np.arange(len(alias)) + alias, values_left, values_right = group_lr(alias, values) # Plot - fig, ax = plt.subplots(figsize=(5, 5)) + fig, ax = plt.subplots(figsize=(4, 5)) main_label_size = 20 sub_label_size = 16 main_tick_size = 16 - legendsize = 12 - label = "Intensity" if intensity else "Transduction efficiency" + label = "Intensity" if intensity else "Expression efficiency" + x_left = 1 + x_right = 2 + offset = 0.08 + + x_pos_inj = [x_left - len(values_left) // 2 * offset + offset * i for i in range(len(values_left))] + x_pos_non = [x_right - len(values_right) // 2 * offset + offset * i for i in range(len(values_right))] + + if grouping == "side_mono": + plt.scatter(x_pos_inj, values_left, label="Injected", + color=COLOR_LEFT, marker=MARKER_LEFT, s=80, zorder=1) + plt.scatter(x_pos_non, values_right, label="Non-Injected", + color=COLOR_RIGHT, marker=MARKER_RIGHT, s=80, zorder=1) + + elif grouping == "side_multi": + for num, key in enumerate(COLORS_LEFT.keys()): + plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}L", + color=COLORS_LEFT[key], marker=MARKER_LEFT, s=80, zorder=1) + for num, key in enumerate(COLORS_RIGHT.keys()): + plt.scatter(x_pos_non[num], values_right[num], label=f"{alias[num]}R", + color=COLORS_RIGHT[key], marker=MARKER_RIGHT, s=80, zorder=1) + + elif grouping == "animal": + for num, key in enumerate(COLORS_ANIMAL.keys()): + plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}", + color=COLORS_ANIMAL[key], marker=MARKER_LEFT, s=80, zorder=1) + plt.scatter(x_pos_non[num], values_right[num], + color=COLORS_ANIMAL[key], marker=MARKER_RIGHT, s=80, zorder=1) - if plot_by_side: - plt.scatter(x, values_left, label="Injected", marker="o", s=80) - plt.scatter(x, values_right, label="Non-Injected", marker="x", s=80) else: - plt.scatter(x, values, label=label, marker="o", s=80) + raise ValueError("Choose a correct 'grouping' parameter.") + + # lines between cochleae of same animal + for num, (left, right) in enumerate(zip(values_left, values_right)): + ax.plot( + [x_pos_inj[num], x_pos_non[num]], + [left, right], + linestyle="solid", + color="grey", + alpha=0.4, + zorder=0 + ) - # Labels and formatting - plt.xticks(x, alias, fontsize=sub_label_size) - plt.yticks(fontsize=main_tick_size) - plt.ylabel(label, fontsize=main_label_size) - plt.legend(loc="upper right", fontsize=legendsize) if not intensity: if gerbil: - plt.ylim(0.3, 1.05) + plt.ylim(0.25, 0.65) + plt.yticks(np.arange(0.3, 0.7, 0.1), fontsize=main_tick_size) else: - plt.ylim(0.5, 1.05) + plt.ylim(0.65, 1.05) + plt.yticks(np.arange(0.7, 1, 0.1), fontsize=main_tick_size) + + # Labels and formatting + plt.xticks([x_left, x_right], ["Injected", "Non-\nInjected"], fontsize=sub_label_size) + for la in plt.gca().get_xticklabels(): + la.set_verticalalignment('center') + ax.tick_params(axis='x', which='major', pad=16) + plt.ylabel(label, fontsize=main_label_size) + ax.yaxis.set_major_formatter(mticker.FuncFormatter(custom_formatter_1)) + + xmin = 0.5 + xmax = 2.5 + plt.xlim(xmin, xmax) + + # plt.legend(loc="upper right", fontsize=legendsize) plt.tight_layout() prism_cleanup_axes(ax) @@ -261,10 +469,63 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa plt.close() -def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_alias=True, trendlines=False): +def _get_trendline_dict(trend_dict, side): + x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side] + x_dict = {} + for num in range(len(x_sorted[0])): + x_dict[num] = {"pos": num, "values": []} + + for s in x_sorted: + for num, pos in enumerate(s): + x_dict[num]["values"].append(pos) + + y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side] + y_dict = {} + for num in range(len(x_sorted[0])): + y_dict[num] = {"pos": num, "values": []} + + for num in range(len(x_sorted[0])): + y_dict[num]["mean"] = np.mean([y[num] for y in y_sorted_all]) + y_dict[num]["stdv"] = np.std([y[num] for y in y_sorted_all]) + return x_dict, y_dict + + +def _get_trendline_params(trend_dict, side): + x_dict, y_dict = _get_trendline_dict(trend_dict, side) + + x_values = [] + for key in x_dict.keys(): + x_values.append(min(x_dict[key]["values"])) + x_values.append(max(x_dict[key]["values"])) + + y_values_center = [] + y_values_upper = [] + y_values_lower = [] + for key in y_dict.keys(): + y_values_center.append(y_dict[key]["mean"]) + y_values_center.append(y_dict[key]["mean"]) + + y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"]) + y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"]) + + y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"]) + y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"]) + + return x_values, y_values_center, y_values_upper, y_values_lower + + +def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, + use_alias=True, trendlines=False, grouping="side_mono", + trendline_std=False): prism_style() + if gerbil: + animal = "gerbil" + else: + animal = "mouse" + result = {"cochlea": [], "octave_band": [], "value": []} + aliases = [] for name, values in chreef_data.items(): if use_alias: alias = COCHLEAE_DICT[name]["alias"] @@ -274,68 +535,122 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali freq = values["frequency[kHz]"].values if intensity: intensity_values = values["median"].values - octave_binned = frequency_mapping(freq, intensity_values, animal="mouse") + octave_binned = frequency_mapping(freq, intensity_values, animal=animal) else: marker_labels = values["marker_labels"].values - octave_binned = frequency_mapping(freq, marker_labels, animal="mouse", transduction_efficiency=True) + octave_binned = frequency_mapping(freq, marker_labels, animal=animal, transduction_efficiency=True) result["cochlea"].extend([alias] * len(octave_binned)) result["octave_band"].extend(octave_binned.axes[0].values.tolist()) result["value"].extend(octave_binned.values.tolist()) + aliases.append(alias) + + if gerbil: + values = [] + for vals in chreef_data.values(): + if intensity: + intensities = vals["median"].values + values.append(intensities.mean()) + else: + # marker labels + # 0: unlabeled - no median intensity in object_measures table + # 1: positive + # 2: negative + marker_labels = vals["marker_labels"].values + n_pos = (marker_labels == 1).sum() + n_neg = (marker_labels == 2).sum() + eff = float(n_pos) / (n_pos + n_neg) + values.append(eff) + alias, values_left, values_right = group_lr(aliases, values) + print(f"Average expression efficiency left: {round(values_left[0], 4)}") + print(f"Average expression efficiency right: {round(values_right[0], 4)}") result = pd.DataFrame(result) bin_labels = pd.unique(result["octave_band"]) band_to_x = {band: i for i, band in enumerate(bin_labels)} result["x_pos"] = result["octave_band"].map(band_to_x) - fig, ax = plt.subplots(figsize=(8, 4)) + fig, ax = plt.subplots(figsize=(8, 5)) sub_tick_label_size = 8 - tick_label_size = 12 - label_size = 12 + tick_label_size = 14 + yaxis_tick_size = 16 + label_size = 20 legend_size = 8 if intensity: band_label_offset_y = 0.09 else: band_label_offset_y = 0.09 if gerbil: - ax.set_ylim(0.05, 1.05) + ymin = 0.1 + ymax = 0.81 + ax.set_ylim(0.05, 0.95) else: + ymin = 0.5 + ymax = 1.01 ax.set_ylim(0.45, 1.05) # Offsets within each octave band - offset_map = {"L": -0.15, "R": 0.15} + offset_map = {"L": -0.2, "R": 0.2} # Assign a color to each cochlea (ignoring side) cochleas = sorted({name_lr[:-1] for name_lr in result["cochlea"].unique()}) - colors = plt.cm.tab10.colors # pick a colormap - color_map = {cochlea: colors[i % len(colors)] for i, cochlea in enumerate(cochleas)} + + if grouping == "side_mono": + colors_l = [COLOR_LEFT for _ in range(5)] + colors_r = [COLOR_RIGHT for _ in range(5)] + + elif grouping == "side_multi": + colors_l = [COLORS_LEFT[key] for key in COLORS_LEFT.keys()] + colors_r = [COLORS_RIGHT[key] for key in COLORS_RIGHT.keys()] + + elif grouping == "animal": + colors_l = [COLORS_ANIMAL[key] for key in COLORS_ANIMAL.keys()] + colors_r = [COLORS_ANIMAL[key] for key in COLORS_ANIMAL.keys()] + + else: + raise ValueError("Choose a correct 'grouping' parameter.") + + color_map = {} + count_l = 0 + count_r = 0 + for num, (name_lr, grp) in enumerate(result.groupby("cochlea")): + name, side = name_lr[:-1], name_lr[-1] + if side == "L": + color_map[name_lr] = colors_l[count_l] + count_l += 1 + else: + color_map[name_lr] = colors_r[count_r] + count_r += 1 + if len(cochleas) == 1: - color_map = {"L": colors[0], "R": colors[1]} + color_map = {"L": colors_l[0], "R": colors_r[1]} # Track which cochlea names we have already added to the legend legend_added = set() + offset = 0.018 trend_dict = {} - for name_lr, grp in result.groupby("cochlea"): + for num, (name_lr, grp) in enumerate(result.groupby("cochlea")): name, side = name_lr[:-1], name_lr[-1] if len(cochleas) == 1: label_name = name_lr color = color_map[side] else: label_name = name - color = color_map[name] + color = color_map[name_lr] - x_positions = grp["x_pos"] + offset_map[side] + x_positions = grp["x_pos"] + offset_map[side] - len(cochleas) / 2 * offset + offset * num ax.scatter( x_positions, grp["value"], label=label_name if label_name not in legend_added else None, s=60, alpha=0.8, - marker="o" if side == "L" else "x", + marker=MARKER_LEFT if side == "L" else MARKER_RIGHT, color=color, + zorder=1 ) if name not in legend_added: @@ -351,46 +666,137 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali } if trendlines: - def get_trendline_values(trend_dict, side): - x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side][0] - y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side] - y_sorted = [] - for num in range(len(x_sorted)): - y_sorted.append(np.mean([y[num] for y in y_sorted_all])) - return x_sorted, y_sorted - - # Trendline left - x_sorted, y_sorted = get_trendline_values(trend_dict, "L") - - trend_l, = ax.plot( - x_sorted, - y_sorted, - linestyle="dotted", - color="grey", - alpha=0.7 - ) + trendline_width = 3 + if not gerbil: + x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict, "L") + + if grouping == "animal": + color_trend_l = "gray" + color_trend_r = "gray" + else: + color_trend_l = COLOR_LEFT + color_trend_r = COLOR_RIGHT + + # central line + trend_l, = ax.plot( + x_sorted, + y_sorted, + linestyle="dashed", + color=color_trend_l, + alpha=0.6, + linewidth=trendline_width, + zorder=2 + ) + + if trendline_std: + # upper and lower standard deviation + trend_l_upper, = ax.plot( + x_sorted, + y_sorted_upper, + linestyle="solid", + color=color_trend_l, + alpha=0.08, + zorder=0 + ) + trend_l_lower, = ax.plot( + x_sorted, + y_sorted_lower, + linestyle="solid", + color=color_trend_l, + alpha=0.08, + zorder=0 + ) + plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, + color=COLOR_LEFT, alpha=0.05, interpolate=True) + + # Trendline Non-Injected (Right) + x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict, "R") + # central line + trend_r, = ax.plot( + x_sorted, + y_sorted, + linestyle="dotted", + color=color_trend_r, + alpha=0.7, + linewidth=trendline_width, + zorder=0 + ) + + if trendline_std: + # upper and lower standard deviation + trend_r_upper, = ax.plot( + x_sorted, + y_sorted_upper, + linestyle="solid", + color=color_trend_r, + alpha=0.08, + zorder=0 + ) + trend_r_lower, = ax.plot( + x_sorted, + y_sorted_lower, + linestyle="solid", + color=color_trend_r, + alpha=0.08, + zorder=0 + ) + plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, + color=COLOR_RIGHT, alpha=0.05, interpolate=True) + + # Trendline legend + trendline_legend = ax.legend(handles=[trend_l, trend_r], loc='lower center') + trendline_legend = ax.legend( + handles=[trend_l, trend_r], + labels=["Injected", "Non-Injected"], + loc="lower left", + fontsize=legend_size, + title="Trendlines" + ) + # Add the legend manually to the Axes. + ax.add_artist(trendline_legend) - x_sorted, y_sorted = get_trendline_values(trend_dict, "R") - trend_r, = ax.plot( - x_sorted, - y_sorted, - linestyle="dashed", - color="grey", - alpha=0.7 - ) - trendline_legend = ax.legend(handles=[trend_l, trend_r], loc='lower center') - trendline_legend = ax.legend( - handles=[trend_l, trend_r], - labels=["Injected", "Non-Injected"], - loc="lower center", - fontsize=legend_size, - title="Trendlines" - ) - # Add the legend manually to the Axes. - ax.add_artist(trendline_legend) + else: + x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == "L"][0] + y_left = [values_left[0], values_left[0]] + y_right = [values_right[0], values_right[0]] + xlim_left, xlim_right = ax.get_xlim() + if grouping == "animal": + color_trend_l = "gray" + color_trend_r = "gray" + else: + color_trend_l = COLOR_LEFT + color_trend_r = COLOR_RIGHT + trend_l, = ax.plot( + [xlim_left, xlim_right], + y_left, + linestyle="dotted", + color=color_trend_l, + alpha=0.7, + zorder=0 + ) + x_offset = 0.5 + y_offset = 0.01 + ax.text(xlim_left + x_offset, y_left[0] + y_offset, "mean", + color=color_trend_l, fontsize=tick_label_size, ha="center") + ax.text(xlim_left + x_offset, y_right[0] + y_offset, "mean", + color=color_trend_r, fontsize=tick_label_size, ha="center") + x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == "R"][0] + plt.xlim(xlim_left, xlim_right) + trend_r, = ax.plot( + [xlim_left, xlim_right], + y_right, + linestyle="dashed", + color=color_trend_r, + alpha=0.7, + zorder=0 + ) # Create combined tick positions & labels main_ticks = range(len(bin_labels)) + ax.yaxis.set_major_formatter(mticker.FuncFormatter(custom_formatter_1)) + plt.yticks(np.arange(ymin, ymax, 0.1), fontsize=yaxis_tick_size) + plt.grid(axis="y", linestyle="solid", alpha=0.5) + # add a final tick for label '>64k' ax.set_xticks([pos + offset_map["L"] for pos in main_ticks] + [pos + offset_map["R"] for pos in main_ticks]) @@ -401,16 +807,14 @@ def get_trendline_values(trend_dict, side): ax.text(i, ax.get_ylim()[0] - band_label_offset_y*(ax.get_ylim()[1]-ax.get_ylim()[0]), label, ha='center', va='top', fontsize=tick_label_size, fontweight='bold') - ax.set_xlabel("Octave band (kHz)", fontsize=label_size) + ax.set_xlabel("Octave band [kHz]", fontsize=label_size) ax.xaxis.set_label_coords(.5, -.16) if intensity: ax.set_ylabel("Marker Intensity", fontsize=label_size) ax.set_title("Intensity per octave band (Left/Right)") else: - ax.set_ylabel("Transduction Efficiency", fontsize=label_size) - - ax.legend(title="Cochlea", fontsize=legend_size) + ax.set_ylabel("Expression efficiency", fontsize=label_size) plt.tight_layout() prism_cleanup_axes(ax) @@ -444,44 +848,38 @@ def main(): chreef_data.pop("M_LR_000143_R") # Create the panels: + grouping = "animal" + plot_legend(chreef_data, grouping=grouping, + save_path=os.path.join(args.figure_dir, f"fig_04_legend_{grouping}.{FILE_EXTENSION}")) # C: The SGN count compared to reference values from literature and healthy # Maybe remove literature reference from plot? fig_04c(chreef_data, save_path=os.path.join(args.figure_dir, f"fig_04c.{FILE_EXTENSION}"), - plot=args.plot, plot_by_side=True, use_alias=use_alias) + plot=args.plot, grouping=grouping, use_alias=use_alias) - # D: The transduction efficiency. We also plot GFP intensities. + # D: The expression efficiency. We also plot GFP intensities. fig_04d(chreef_data, save_path=os.path.join(args.figure_dir, f"fig_04d_transduction.{FILE_EXTENSION}"), - plot=args.plot, plot_by_side=True, use_alias=use_alias) - # fig_04d(chreef_data, - # save_path=os.path.join(args.figure_dir, f"fig_04d_intensity.{FILE_EXTENSION}"), - # plot=args.plot, plot_by_side=True, intensity=True, use_alias=use_alias) + plot=args.plot, grouping=grouping, use_alias=use_alias) + # E: The expression efficiency per octave band. + # trendlines without standard deviation fig_04e(chreef_data, save_path=os.path.join(args.figure_dir, f"fig_04e_transduction.{FILE_EXTENSION}"), - plot=args.plot, use_alias=use_alias, trendlines=True) - # fig_04e(chreef_data, - # save_path=os.path.join(args.figure_dir, f"fig_04e_intensity.{FILE_EXTENSION}"), - # plot=args.plot, intensity=True, use_alias=use_alias) + plot=args.plot, grouping=grouping, use_alias=use_alias, trendlines=True) + # trendlines with standard deviation + fig_04e(chreef_data, + save_path=os.path.join(args.figure_dir, f"fig_04e_transduction_std.{FILE_EXTENSION}"), + plot=args.plot, grouping=grouping, use_alias=use_alias, trendlines=True, trendline_std=True) + # Figures for gerbil (Figure 5) chreef_data_gerbil = get_chreef_data(animal="gerbil") - fig_04d(chreef_data_gerbil, - save_path=os.path.join(args.figure_dir, f"fig_04d_gerbil_transduction.{FILE_EXTENSION}"), - plot=args.plot, plot_by_side=True, gerbil=True, use_alias=use_alias) - - # fig_04d(chreef_data_gerbil, - # save_path=os.path.join(args.figure_dir, f"fig_04d_gerbil_intensity.{FILE_EXTENSION}"), - # plot=args.plot, plot_by_side=True, intensity=True, use_alias=use_alias) - fig_04e(chreef_data_gerbil, - save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_transduction.{FILE_EXTENSION}"), - plot=args.plot, gerbil=True, use_alias=use_alias) + save_path=os.path.join(args.figure_dir, f"fig_05e_gerbil_transduction.{FILE_EXTENSION}"), + plot=args.plot, gerbil=True, use_alias=use_alias, trendlines=True) - # fig_04e(chreef_data_gerbil, - # save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_intensity.{FILE_EXTENSION}"), - # plot=args.plot, intensity=True, use_alias=use_alias) + plot_legend_fig05e_gerbil(save_path=os.path.join(args.figure_dir, f"fig_05e_gerbil_legend.{FILE_EXTENSION}")) if __name__ == "__main__": diff --git a/scripts/figures/plot_fig5.py b/scripts/figures/plot_fig5.py index b98eb6b..3cd3491 100644 --- a/scripts/figures/plot_fig5.py +++ b/scripts/figures/plot_fig5.py @@ -6,111 +6,203 @@ import pandas as pd import matplotlib.pyplot as plt -from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target +from flamingo_tools.s3_utils import BUCKET_NAME +from util import literature_reference_values_gerbil, prism_cleanup_axes, prism_style, SYNAPSE_DIR_ROOT from util import SYNAPSE_DIR_ROOT -from plot_fig4 import get_chreef_data FILE_EXTENSION = "png" png_dpi = 300 +COLOR_LEFT = "#8E00DB" +COLOR_RIGHT = "#DB0063" +MARKER_LEFT = "o" +MARKER_RIGHT = "^" +COLOR_MEASUREMENT = "#9C7427" +COLOR_LITERATURE = "#27339C" +COLOR_UNTREATED = "#DB7B00" +# Load the synapse counts for all IHCs from the relevant tables. def _load_ribbon_synapse_counts(): - # TODO update the version! - ihc_version = "ihc_counts_v4b" - table_path = os.path.join(SYNAPSE_DIR_ROOT, ihc_version, "ihc_count_M_AMD_OTOF1_L.tsv") - x = pd.read_csv(table_path, sep="\t") - syn_counts = x.synapse_count.values.tolist() + ihc_version = "ihc_counts_v6" + synapse_dir = os.path.join(SYNAPSE_DIR_ROOT, ihc_version) + tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_G_" in entry.name] + print(f"Synapse count for tables {tables}.") + syn_counts = [] + for tab in tables: + x = pd.read_csv(tab, sep="\t") + syn_counts.extend(x["synapse_count"].values.tolist()) return syn_counts def fig_05c(save_path, plot=False): - """Bar plot showing the IHC count and distribution of synapse markers per IHC segmentation over OTOF cochlea. + """Box plot showing the counts for SGN and IHC per gerbil cochlea in comparison to literature values. """ - # TODO update the alias. - # For MOTOF1L - alias = "M10L" - + main_tick_size = 20 main_label_size = 20 - main_tick_size = 12 - htext_size = 10 + prism_style() - ribbon_synapse_counts = _load_ribbon_synapse_counts() + rows = 1 + columns = 3 - rows, columns = 1, 2 - fig, axes = plt.subplots(rows, columns, figsize=(columns*4, rows*4)) + fig, ax = plt.subplots(rows, columns, figsize=(8.5, 4.5)) - # - # Create the plot for IHCs. - # - ihc_values = [len(ribbon_synapse_counts)] + sgn_values = [18541] + ihc_values = [1180] - ylim0 = 600 - ylim1 = 800 - y_ticks = [i for i in range(600, 800 + 1, 100)] + ax[0].scatter([1], sgn_values, color=COLOR_MEASUREMENT, marker="x", s=100) + ax[1].scatter([1], ihc_values, color=COLOR_MEASUREMENT, marker="x", s=100) - axes[0].set_ylabel("IHC count", fontsize=main_label_size) - axes[0].set_yticks(y_ticks) - axes[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - axes[0].set_ylim(ylim0, ylim1) + # Labels and formatting + ax[0].set_xticks([1]) + ax[0].set_xticklabels(["SGN"], fontsize=main_label_size) - axes[0].boxplot(ihc_values) - axes[0].set_xticklabels([alias], fontsize=main_label_size) + ylim0 = 14000 + ylim1 = 30000 + ytick_gap = 4000 + y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] - # Set the reference values for healthy cochleae + ax[0].set_ylabel('Count per cochlea', fontsize=main_label_size) + ax[0].set_yticks(y_ticks) + ax[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[0].set_ylim(ylim0, ylim1) + + # set range of literature values xmin = 0.5 xmax = 1.5 - ihc_reference_values = [712, 710, 721, 675] # MLR226L, MLR226R, MLR227L, MLR227R + ax[0].set_xlim(xmin, xmax) + lower_y, upper_y = literature_reference_values_gerbil("SGN") + ax[0].hlines([lower_y, upper_y], xmin, xmax, color=COLOR_LITERATURE) + ax[0].text(1, upper_y - 2000, "literature", color=COLOR_LITERATURE, fontsize=main_tick_size, ha="center") + ax[0].fill_between([xmin, xmax], lower_y, upper_y, color=COLOR_LITERATURE, alpha=0.05, interpolate=True) + + ylim0 = 900 + ylim1 = 1400 + ytick_gap = 200 + y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] - ihc_value = np.mean(ihc_reference_values) - ihc_std = np.std(ihc_reference_values) + ax[1].set_xticks([1]) + ax[1].set_xticklabels(["IHC"], fontsize=main_label_size) - upper_y = ihc_value + 1.96 * ihc_std - lower_y = ihc_value - 1.96 * ihc_std + ax[1].set_yticks(y_ticks) + ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[1].set_ylim(ylim0, ylim1) - axes[0].hlines([lower_y, upper_y], xmin, xmax, colors=["C1" for _ in range(2)]) - axes[0].text(1, upper_y + 10, "healthy cochleae", color="C1", fontsize=main_tick_size, ha="center") - axes[0].fill_between([xmin, xmax], lower_y, upper_y, color="C1", alpha=0.05, interpolate=True) + # set range of literature values + xmin = 0.5 + xmax = 1.5 + ax[1].set_xlim(xmin, xmax) + lower_y, upper_y = literature_reference_values_gerbil("IHC") + ax[1].hlines([lower_y, upper_y], xmin, xmax, color=COLOR_LITERATURE) + ax[1].fill_between([xmin, xmax], lower_y, upper_y, color=COLOR_LITERATURE, alpha=0.05, interpolate=True) - # - # Create the plot for ribbon synapse distribution. - # + ribbon_synapse_counts = _load_ribbon_synapse_counts() ylim0 = -1 - ylim1 = 24 - y_ticks = [i for i in range(0, 25, 5)] + ylim1 = 80 + ytick_gap = 20 + y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] + + box_plot = ax[2].boxplot(ribbon_synapse_counts, patch_artist=True) + for median in box_plot['medians']: + median.set_color(COLOR_MEASUREMENT) + for boxcolor in box_plot['boxes']: + boxcolor.set_facecolor("white") + + ax[2].set_xticklabels(["Synapses per IHC"], fontsize=main_label_size) + ax[2].set_yticks(y_ticks) + ax[2].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) + ax[2].set_ylim(ylim0, ylim1) + + # set range of literature values + xmin = 0.5 + xmax = 1.5 + lower_y, upper_y = literature_reference_values_gerbil("synapse") + ax[2].set_xlim(xmin, xmax) + ax[2].hlines([lower_y, upper_y], xmin, xmax, color=COLOR_LITERATURE) + ax[2].fill_between([xmin, xmax], lower_y, upper_y, color=COLOR_LITERATURE, alpha=0.05, interpolate=True) - axes[1].set_ylabel("Ribbon Syn. per IHC", fontsize=main_label_size) - axes[1].set_yticks(y_ticks) - axes[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - axes[1].set_ylim(ylim0, ylim1) + plt.tight_layout() + prism_cleanup_axes(ax) - axes[1].boxplot(ribbon_synapse_counts) - axes[1].set_xticklabels([alias], fontsize=main_label_size) + if ".png" in save_path: + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) + else: + plt.savefig(save_path, bbox_inches='tight', pad_inches=0) + if plot: + plt.show() + else: + plt.close() - axes[1].yaxis.tick_right() - axes[1].yaxis.set_ticks_position("right") - axes[1].yaxis.set_label_position("right") - # Set the reference values for healthy cochleae +def fig_05d(save_path, plot=False): + """Box plot showing the SGN counts of ChReef treated cochleae compared to healthy ones. + """ + prism_style() + values_left = [11351] + values_right = [21995] + + # Plot + fig, ax = plt.subplots(figsize=(4, 5)) + + main_label_size = 20 + sub_label_size = 16 + main_tick_size = 16 + + offset = 0.08 + x_left = 1 + x_right = 2 + + x_pos_inj = [x_left - len(values_left) // 2 * offset + offset * i for i in range(len(values_left))] + x_pos_non = [x_right - len(values_right) // 2 * offset + offset * i for i in range(len(values_right))] + + # lines between cochleae of same animal + for num, (left, right) in enumerate(zip(values_left, values_right)): + ax.plot( + [x_pos_inj[num], x_pos_non[num]], + [left, right], + linestyle="solid", + color="grey", + alpha=0.4, + zorder=0 + ) + plt.scatter(x_pos_inj, values_left, label="Injected", + color=COLOR_LEFT, marker=MARKER_LEFT, s=80, zorder=1) + plt.scatter(x_pos_non, values_right, label="Non-Injected", + color=COLOR_RIGHT, marker=MARKER_RIGHT, s=80, zorder=1) + + # Labels and formatting + plt.xticks([x_left, x_right], ["Injected", "Non-\nInjected"], fontsize=sub_label_size) + for label in plt.gca().get_xticklabels(): + label.set_verticalalignment('center') + ax.tick_params(axis='x', which='major', pad=16) + + plt.ylim(10000, 24000) + y_ticks = [i for i in range(10000, 24000, 4000)] + + plt.yticks(y_ticks, fontsize=main_tick_size) + plt.ylabel("SGN count per cochlea", fontsize=main_label_size) xmin = 0.5 - xmax = 1.5 - syn_reference_values = [14.1, 12.7, 13.8, 13.4] # MLR226L, MLR226R, MLR227L, MLR227R + xmax = 2.5 + plt.xlim(xmin, xmax) + + sgn_values = [18541] # G_EK_000233_L + sgn_value = np.mean(sgn_values) + sgn_std = np.std(sgn_values) - syn_value = np.mean(syn_reference_values) - syn_std = np.std(syn_reference_values) + upper_y = sgn_value + 1.96 * sgn_std + lower_y = sgn_value - 1.96 * sgn_std - upper_y = syn_value + 1.96 * syn_std - lower_y = syn_value - 1.96 * syn_std + c_untreated = COLOR_UNTREATED - plt.hlines([lower_y, upper_y], xmin, xmax, colors=["C1" for _ in range(2)]) - plt.text( - 1.25, upper_y + 0.01*axes[1].get_ylim()[1]-axes[1].get_ylim()[0], "healthy cochleae", - color="C1", fontsize=htext_size, ha="center" - ) - plt.fill_between([xmin, xmax], lower_y, upper_y, color="C1", alpha=0.05, interpolate=True) + plt.hlines([lower_y, upper_y], xmin, xmax, colors=[c_untreated for _ in range(2)], zorder=-1) + plt.text((xmin + xmax) / 2, upper_y + 200, "untreated cochleae\n(95% confidence interval)", + color=c_untreated, fontsize=11, ha="center") + plt.fill_between([xmin, xmax], lower_y, upper_y, color=c_untreated, alpha=0.05, interpolate=True) - # Save and plot the figure. plt.tight_layout() + + prism_cleanup_axes(ax) + if ".png" in save_path: plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) else: @@ -122,33 +214,6 @@ def fig_05c(save_path, plot=False): plt.close() -# TODO -def fig_05d(save_path, plot): - if False: - s3 = create_s3_target() - - # Intensity distribution for OTOF - cochlea = "M_AMD_OTOF1_L" - content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8") - info = json.loads(content.read()) - sources = info["sources"] - - # Load the seg table and filter the compartments. - source_name = "IHC_v4c" - source = sources[source_name]["segmentation"] - rel_path = source["tableData"]["tsv"]["relativePath"] - table_content = s3.open(os.path.join(BUCKET_NAME, cochlea, rel_path, "default.tsv"), mode="rb") - table = pd.read_csv(table_content, sep="\t") - print(table) - - # TODO would need the new intensity subtracted data here. - # Reference: intensity distributions for ChReef - chreef_data = get_chreef_data() - for cochlea, tab in chreef_data.items(): - plt.hist(tab["median"]) - plt.show() - - def main(): parser = argparse.ArgumentParser(description="Generate plots for Fig 5 of the cochlea paper.") parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig5") @@ -157,8 +222,8 @@ def main(): os.makedirs(args.figure_dir, exist_ok=True) - # Panel C: Monitoring of the Syn / IHC loss - # fig_05c(save_path=os.path.join(args.figure_dir, "fig_05c"), plot=args.plot) + # Panel C: The number of SGNs, IHCs and average number of ribbon synapses per IHC + fig_05c(save_path=os.path.join(args.figure_dir, "fig_05c"), plot=args.plot) # Panel D: Tonotopic mapping of the intensities. fig_05d(save_path=os.path.join(args.figure_dir, f"fig_05d.{FILE_EXTENSION}"), plot=args.plot) diff --git a/scripts/figures/plot_fig6.py b/scripts/figures/plot_fig6.py index 3736339..681a4e8 100644 --- a/scripts/figures/plot_fig6.py +++ b/scripts/figures/plot_fig6.py @@ -1,106 +1,258 @@ import argparse +import json +import numpy as np import os +import pickle import pandas as pd import matplotlib.pyplot as plt -from util import literature_reference_values_gerbil, prism_cleanup_axes, prism_style, SYNAPSE_DIR_ROOT +from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target +from util import prism_cleanup_axes, prism_style +from util import frequency_mapping, export_legend FILE_EXTENSION = "png" png_dpi = 300 - -# Load the synapse counts for all IHCs from the relevant tables. -def _load_ribbon_synapse_counts(): - ihc_version = "ihc_counts_v6" - synapse_dir = os.path.join(SYNAPSE_DIR_ROOT, ihc_version) - tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_G_" in entry.name] - syn_counts = [] - for tab in tables: - x = pd.read_csv(tab, sep="\t") - syn_counts.extend(x["synapse_count"].values.tolist()) - return syn_counts - - -def fig_06b(save_path, plot=False): - """Box plot showing the counts for SGN and IHC per gerbil cochlea in comparison to literature values. - """ - main_tick_size = 20 - main_label_size = 20 +INTENSITY_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/LaVision-OTOF" # noqa + +# The cochlea for the CHReef analysis. +COCHLEAE_DICT = { + "LaVision-OTOF23R": {"alias": "01", "component": [4, 18, 7], "color": "#9C5027"}, + "LaVision-OTOF25R": {"alias": "02", "component": [1], "color": "#67279C"}, +} + + +def get_otof_data(): + s3 = create_s3_target() + source_name = "IHC_LOWRES-v3" + + cache_path = "./otof_data.pkl" + cochleae = [key for key in COCHLEAE_DICT.keys()] + + if os.path.exists(cache_path): + with open(cache_path, "rb") as f: + return pickle.load(f) + + chreef_data = {} + for cochlea in cochleae: + print("Processsing cochlea:", cochlea) + content = s3.open(f"{BUCKET_NAME}/{cochlea}/dataset.json", mode="r", encoding="utf-8") + info = json.loads(content.read()) + sources = info["sources"] + + # Load the seg table and filter the compartments. + source = sources[source_name]["segmentation"] + rel_path = source["tableData"]["tsv"]["relativePath"] + table_content = s3.open(os.path.join(BUCKET_NAME, cochlea, rel_path, "default.tsv"), mode="rb") + table = pd.read_csv(table_content, sep="\t") + print(table.columns) + + # May need to be adjusted for some cochleae. + component_labels = COCHLEAE_DICT[cochlea]["component"] + print(cochlea, component_labels) + table = table[table.component_labels.isin(component_labels)] + # The relevant values for analysis. + try: + values = table[["label_id", "length[µm]", "frequency[kHz]", "frequency-mueller[kHz]", + "expression_classification"]] + except KeyError: + print("Could not find the values for", cochlea, "it will be skippped.") + continue + + fname = f"{cochlea.replace('_', '-')}_rbOtof_IHC-LOWRES-v3_object-measures.tsv" + intensity_file = os.path.join(INTENSITY_ROOT, fname) + assert os.path.exists(intensity_file), intensity_file + intensity_table = pd.read_csv(intensity_file, sep="\t") + values = values.merge(intensity_table, on="label_id") + + chreef_data[cochlea] = values + + with open(cache_path, "wb") as f: + pickle.dump(chreef_data, f) + with open(cache_path, "rb") as f: + return pickle.load(f) + + +def plot_legend_fig06e(save_path): + color_dict = {} + for key in COCHLEAE_DICT.keys(): + color_dict[COCHLEAE_DICT[key]["alias"]] = COCHLEAE_DICT[key]["color"] + + marker = ["o" for _ in color_dict] + label = list(color_dict.keys()) + color = [color_dict[key] for key in color_dict.keys()] + + f = lambda m, c: plt.plot([], [], marker=m, color=c, ls="none")[0] + handles = [f(m, c) for (c, m) in zip(color, marker)] + legend = plt.legend(handles, label, loc=3, ncol=2, framealpha=1, frameon=False) + export_legend(legend, save_path) + legend.remove() + plt.close() + + +def _get_trendline_dict(trend_dict,): + x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys()] + x_dict = {} + for num in range(len(x_sorted[0])): + x_dict[num] = {"pos": num, "values": []} + + for s in x_sorted: + for num, pos in enumerate(s): + x_dict[num]["values"].append(pos) + + y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys()] + y_dict = {} + for num in range(len(x_sorted[0])): + y_dict[num] = {"pos": num, "values": []} + + for num in range(len(x_sorted[0])): + y_dict[num]["mean"] = np.mean([y[num] for y in y_sorted_all]) + y_dict[num]["stdv"] = np.std([y[num] for y in y_sorted_all]) + return x_dict, y_dict + + +def _get_trendline_params(trend_dict): + x_dict, y_dict = _get_trendline_dict(trend_dict) + + x_values = [] + for key in x_dict.keys(): + x_values.append(min(x_dict[key]["values"])) + x_values.append(max(x_dict[key]["values"])) + + y_values_center = [] + y_values_upper = [] + y_values_lower = [] + for key in y_dict.keys(): + y_values_center.append(y_dict[key]["mean"]) + y_values_center.append(y_dict[key]["mean"]) + + y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"]) + y_values_upper.append(y_dict[key]["mean"] + y_dict[key]["stdv"]) + + y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"]) + y_values_lower.append(y_dict[key]["mean"] - y_dict[key]["stdv"]) + + return x_values, y_values_center, y_values_upper, y_values_lower + + +def fig_06e_octave(otof_data, save_path, plot=False, use_alias=True, trendline_mode=None, mapping="default"): prism_style() - - rows = 1 - columns = 3 - - fig, ax = plt.subplots(rows, columns, figsize=(columns*3, rows*4)) - - sgn_values = [20050, 21995] - ihc_values = [1100] - - ax[0].boxplot(sgn_values) - ax[1].boxplot(ihc_values) - - # Labels and formatting - ax[0].set_xticklabels(["SGN"], fontsize=main_label_size) - - ylim0 = 14000 - ylim1 = 30000 - ytick_gap = 4000 - y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] - - ax[0].set_ylabel('Count per cochlea', fontsize=main_label_size) - ax[0].set_yticks(y_ticks) - ax[0].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - ax[0].set_ylim(ylim0, ylim1) - - # set range of literature values - xmin = 0.5 - xmax = 1.5 - ax[0].set_xlim(xmin, xmax) - lower_y, upper_y = literature_reference_values_gerbil("SGN") - ax[0].hlines([lower_y, upper_y], xmin, xmax) - ax[0].text(1, upper_y - 2000, "literature", color='C0', fontsize=main_tick_size, ha="center") - ax[0].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True) - - ylim0 = 900 - ylim1 = 1400 - ytick_gap = 200 - y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] - - ax[1].set_xticklabels(["IHC"], fontsize=main_label_size) - - ax[1].set_yticks(y_ticks) - ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - ax[1].set_ylim(ylim0, ylim1) - - # set range of literature values - xmin = 0.5 - xmax = 1.5 - ax[1].set_xlim(xmin, xmax) - lower_y, upper_y = literature_reference_values_gerbil("IHC") - ax[1].hlines([lower_y, upper_y], xmin, xmax) - ax[1].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True) - - ribbon_synapse_counts = _load_ribbon_synapse_counts() - ylim0 = -1 - ylim1 = 80 - ytick_gap = 20 - y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] - - ax[2].boxplot(ribbon_synapse_counts) - ax[2].set_xticklabels(["Ribbon Syn. per IHC"], fontsize=main_label_size) - ax[2].set_yticks(y_ticks) - ax[2].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) - ax[2].set_ylim(ylim0, ylim1) - - # set range of literature values - xmin = 0.5 - xmax = 1.5 - lower_y, upper_y = literature_reference_values_gerbil("synapse") - ax[2].set_xlim(xmin, xmax) - ax[2].hlines([lower_y, upper_y], xmin, xmax) - ax[2].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True) - + label_size = 20 + tick_label_size = 14 + + result = {"cochlea": [], "octave_band": [], "value": []} + expression_eff_dic = {} + color_dict = {} + for name, values in otof_data.items(): + if use_alias: + alias = COCHLEAE_DICT[name]["alias"] + else: + alias = name.replace("_", "").replace("0", "") + + color_dict[alias] = COCHLEAE_DICT[name]["color"] + if mapping == "default": + freq = values["frequency[kHz]"].values + elif mapping == "mueller": + freq = values["frequency-mueller[kHz]"].values + else: + raise ValueError("Choose either 'default' or 'mueller' for tonotopic mapping.") + marker_labels = values["expression_classification"].values + marker_pos = len([1 for i in marker_labels if i == 1]) + marker_neg = len([1 for i in marker_labels if i == 2]) + expression_eff = marker_pos / (marker_pos + marker_neg) + print(f"Cochlea {name}, average expression efficiency {expression_eff}") + octave_binned = frequency_mapping(freq, marker_labels, animal="mouse", transduction_efficiency=True) + + result["cochlea"].extend([alias] * len(octave_binned)) + result["octave_band"].extend(octave_binned.axes[0].values.tolist()) + result["value"].extend(octave_binned.values.tolist()) + expression_eff_dic[alias] = expression_eff + + result = pd.DataFrame(result) + bin_labels = pd.unique(result["octave_band"]) + band_to_x = {band: i for i, band in enumerate(bin_labels)} + result["x_pos"] = result["octave_band"].map(band_to_x) + + fig, ax = plt.subplots(figsize=(8, 4)) + + offset = 0.08 + trend_dict = {} + for num, (name, grp) in enumerate(result.groupby("cochlea")): + x_sorted = grp["x_pos"] + x_positions = [x - len(grp["x_pos"]) // 2 * offset + offset * num for x in grp["x_pos"]] + ax.scatter(x_positions, grp["value"], marker="o", label=name, s=80, alpha=1, color=color_dict[name]) + + # y_values.append(list(grp["value"])) + + if trendline_mode == "filled": + sorted_idx = np.argsort(x_positions) + x_sorted = np.array(x_positions)[sorted_idx] + y_sorted = np.array(grp["value"])[sorted_idx] + trend_dict[name] = {"x_sorted": x_sorted, + "y_sorted": y_sorted, + } + # central line + if trendline_mode == "filled": + # mean, std = _get_trendline_params(y_values) + x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict) + trend_center, = ax.plot( + x_sorted, + y_sorted, + linestyle="dotted", + color="gray", + alpha=0.6, + linewidth=3, + zorder=2 + ) + # y_sorted_upper = np.array(mean) + np.array(std) + # y_sorted_lower = np.array(mean) - np.array(std) + # upper and lower standard deviation + trend_upper, = ax.plot( + x_sorted, + y_sorted_upper, + linestyle="solid", + color="gray", + alpha=0.08, + zorder=0 + ) + trend_lower, = ax.plot( + x_sorted, + y_sorted_lower, + linestyle="solid", + color="gray", + alpha=0.08, + zorder=0 + ) + plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, + color="gray", alpha=0.05, interpolate=True) + + elif trendline_mode == "mean": + xlim_left, xlim_right = ax.get_xlim() + y_offset = [0.01, -0.04] + x_offset = 0.5 + plt.xlim(xlim_left, xlim_right) + for num, key in enumerate(color_dict.keys()): + color = color_dict[key] + expression_eff = expression_eff_dic[key] + + ax.text(xlim_left + x_offset, expression_eff + y_offset[num], "mean", + color=color, fontsize=tick_label_size, ha="center") + trend_r, = ax.plot( + [xlim_left, xlim_right], + [expression_eff, expression_eff], + linestyle="dashed", + color=color, + alpha=0.7, + zorder=0 + ) + + ax.set_xticks(range(len(bin_labels))) + ax.set_xticklabels(bin_labels) + ax.set_xlabel("Octave band [kHz]", fontsize=label_size) + + ax.set_ylabel("Expression efficiency") + # plt.legend(title="Cochlea") plt.tight_layout() prism_cleanup_axes(ax) @@ -108,6 +260,7 @@ def fig_06b(save_path, plot=False): plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) else: plt.savefig(save_path, bbox_inches='tight', pad_inches=0) + if plot: plt.show() else: @@ -131,12 +284,17 @@ def fig_06d(save_path, plot=False): def main(): parser = argparse.ArgumentParser(description="Generate plots for Fig 6 of the cochlea paper.") - parser.add_argument("figure_dir", type=str, help="Output directory for plots.", default="./panels") + parser.add_argument("-f", "--figure_dir", type=str, help="Output directory for plots.", default="./panels") args = parser.parse_args() plot = False - fig_06b(save_path=os.path.join(args.figure_dir, f"fig_06b.{FILE_EXTENSION}"), plot=plot) - fig_06d(save_path=os.path.join(args.figure_dir, f"fig_06d.{FILE_EXTENSION}"), plot=plot) + tonotopic_mapping = "mueller" + otof_data = get_otof_data() + plot_legend_fig06e(save_path=os.path.join(args.figure_dir, f"fig_06e_legend.{FILE_EXTENSION}")) + fig_06e_octave(otof_data, save_path=os.path.join(args.figure_dir, f"fig_06e.{FILE_EXTENSION}"), plot=plot, + trendline_mode="mean", mapping=tonotopic_mapping) + + # fig_06d(save_path=os.path.join(args.figure_dir, f"fig_06d.{FILE_EXTENSION}"), plot=plot) if __name__ == "__main__": diff --git a/scripts/figures/util.py b/scripts/figures/util.py index 9d1b832..02c1bc7 100644 --- a/scripts/figures/util.py +++ b/scripts/figures/util.py @@ -5,6 +5,7 @@ # Directory with synapse measurement tables SYNAPSE_DIR_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses" # SYNAPSE_DIR_ROOT = "./synapses" +png_dpi = 300 def ax_prism_boxplot(ax, data, positions=None, color="tab:blue"): @@ -31,6 +32,42 @@ def ax_prism_boxplot(ax, data, positions=None, color="tab:blue"): return bp +prism_palette = [ + "#4E79A7", # blue + "#F28E2B", # orange + "#E15759", # red + "#76B7B2", # teal + "#59A14F", # green + "#EDC948", # yellow + "#B07AA1", # purple + "#FF9DA7", # pink + "#9C755F", # brown + "#BAB0AC" # gray +] + + +def custom_formatter_1(x, pos): + if np.isclose(x, 1.0): + return '1' # no decimal + else: + return f"{x:.1f}" + + +def custom_formatter_2(x, pos): + if np.isclose(x, 1.0): + return '1' # no decimal + else: + return f"{x:.2f}" + + +def export_legend(legend, filename="legend.png"): + legend.axes.axis("off") + fig = legend.figure + fig.canvas.draw() + bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + fig.savefig(filename, bbox_inches=bbox, dpi=png_dpi) + + def prism_style(): plt.style.use("default") # reset any active styles plt.rcParams.update({ @@ -44,6 +81,7 @@ def prism_style(): "axes.linewidth": 1.2, "axes.labelsize": 14, "axes.labelweight": "bold", + "axes.prop_cycle": plt.cycler("color", prism_palette), # Ticks "xtick.direction": "out", @@ -93,12 +131,12 @@ def _get_mapping(animal): if animal == "mouse": bin_edges = [0, 2, 4, 8, 16, 32, 64, np.inf] bin_labels = [ - "<2 k", "2–4 k", "4–8 k", "8–16 k", "16–32 k", "32–64 k", ">64 k" + "<2", "2–4", "4–8", "8–16", "16–32", "32–64", ">64" ] elif animal == "gerbil": bin_edges = [0, 0.5, 1, 2, 4, 8, 16, 32, np.inf] bin_labels = [ - "<0.5 k", "0.5–1 k", "1–2 k", "2–4 k", "4–8 k", "8–16 k", "16–32 k", ">32 k" + "<0.5", "0.5–1", "1–2", "2–4", "4–8", "8–16", "16–32", ">32" ] else: raise ValueError