diff --git a/scripts/figures/plot_fig2.py b/scripts/figures/plot_fig2.py index 9b357ae..a897090 100644 --- a/scripts/figures/plot_fig2.py +++ b/scripts/figures/plot_fig2.py @@ -4,12 +4,116 @@ import numpy as np import pandas as pd import matplotlib.pyplot as plt +import tifffile +from matplotlib import colors +from skimage.segmentation import find_boundaries from util import literature_reference_values png_dpi = 300 +def scramble_instance_labels(arr): + """Scramble indexes of instance segmentation to avoid neighboring colors. + """ + unique = list(np.unique(arr)[1:]) + rng = np.random.default_rng(seed=42) + new_list = rng.uniform(1, len(unique) + 1, size=(len(unique))) + new_arr = np.zeros(arr.shape) + for old_id, new_id in zip(unique, new_list): + new_arr[arr == old_id] = new_id + return new_arr + + +def plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba=[0, 0, 0, 0.5], plot=False): + seg = tifffile.imread(seg_path) + if len(seg.shape) == 3: + seg = seg[10, xlim1:xlim2, ylim1:ylim2] + else: + seg = seg[xlim1:xlim2, ylim1:ylim2] + + img = tifffile.imread(img_path) + img = img[10, xlim1:xlim2, ylim1:ylim2] + + # create color map with random distribution for coloring instance segmentation + unique = list(np.unique(seg)[1:]) + n_instances = len(unique) + + seg = scramble_instance_labels(seg) + + rng = np.random.default_rng(seed=42) # fixed seed for reproducibility + colors_array = rng.uniform(0, 1, size=(n_instances, 4)) # RGBA values in [0,1] + colors_array[:, 3] = 1.0 # full alpha + colors_array[0, 3] = 0.0 # make label 0 transparent (background) + cmap = colors.ListedColormap(colors_array) + + boundaries = find_boundaries(seg, mode="inner") + boundary_overlay = np.zeros((*boundaries.shape, 4)) + + boundary_overlay[boundaries] = boundary_rgba # RGBA = black + + fig, ax = plt.subplots(figsize=(6, 6)) + ax.imshow(img, cmap="gray") + ax.imshow(seg, cmap=cmap, alpha=0.5, interpolation="nearest") + ax.imshow(boundary_overlay) + ax.axis("off") + plt.tight_layout() + plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) + + if plot: + plt.show() + else: + plt.close() + + +def fig_02b_sgn(save_dir, plot=False): + """Plot crops of SGN segmentation of CochleaNet, Cellpose and micro-sam. + """ + cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet" + val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn" + image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation" + + crop_name = "MLR169R_PV_z3420_allturns_full" + img_path = os.path.join(image_dir, f"{crop_name}.tif") + + xlim1 = 2000 + xlim2 = 2500 + ylim1 = 3100 + ylim2 = 3600 + boundary_rgba = [1, 1, 1, 0.5] + + for seg_net in ["distance_unet", "cellpose-sam", "micro-sam"]: + save_path = os.path.join(save_dir, f"fig_02b_sgn_{seg_net}.png") + seg_dir = os.path.join(val_sgn_dir, seg_net) + seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif") + + plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot) + + +def fig_02b_ihc(save_dir, plot=False): + """Plot crops of IHC segmentation of CochleaNet, Cellpose and micro-sam. + """ + cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet" + val_sgn_dir = f"{cochlea_dir}/predictions/val_ihc" + image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs" + + crop_name = "MLR226L_VGlut3_z1200_3turns_full" + img_path = os.path.join(image_dir, f"{crop_name}.tif") + + xlim1 = 1900 + xlim2 = 2400 + ylim1 = 2000 + ylim2 = 2500 + boundary_rgba = [1, 1, 1, 0.5] + + for seg_net in ["distance_unet_v4b", "cellpose-sam", "micro-sam"]: + save_path = os.path.join(save_dir, f"fig_02b_ihc_{seg_net}.png") + seg_dir = os.path.join(val_sgn_dir, seg_net) + seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif") + + plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot) + + def fig_02c(save_path, plot=False, all_versions=False): """Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual), IHC (distance U-Net, manual), and synapse detection (U-Net). @@ -299,6 +403,8 @@ def main(): os.makedirs(args.figure_dir, exist_ok=True) # Panel C: Evaluation of the segmentation results: + fig_02b_sgn(save_dir=args.figure_dir, plot=args.plot) + fig_02b_ihc(save_dir=args.figure_dir, plot=args.plot) fig_02c(save_path=os.path.join(args.figure_dir, "fig_02c"), plot=args.plot, all_versions=False) # Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC diff --git a/scripts/figures/plot_fig3.py b/scripts/figures/plot_fig3.py index dff146c..0e8e8c8 100644 --- a/scripts/figures/plot_fig3.py +++ b/scripts/figures/plot_fig3.py @@ -1,7 +1,6 @@ import argparse import os import imageio.v3 as imageio -from glob import glob import matplotlib.pyplot as plt import numpy as np @@ -72,7 +71,10 @@ def fig_03b(save_path): def fig_03c_rl(save_path, plot=False): - tables = glob("./ihc_counts/ihc_count_M_LR*.tsv") + ihc_version = "ihc_counts_v4c" + synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}" + tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name] + fig, ax = plt.subplots(figsize=(8, 4)) width = 50 # micron @@ -102,7 +104,9 @@ def fig_03c_rl(save_path, plot=False): def fig_03c_octave(save_path, plot=False): - tables = glob("./ihc_counts/ihc_count_M_LR*.tsv") + ihc_version = "ihc_counts_v4c" + synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}" + tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name] result = {"cochlea": [], "octave_band": [], "value": []} for tab_path in tables: @@ -134,7 +138,7 @@ def fig_03c_octave(save_path, plot=False): ax.set_ylabel("Average Ribbon Synapse Count per IHC") ax.set_title("Ribbon synapse count per octave band") - ax.legend(title="Cochlea") + plt.legend(title="Cochlea") plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi) if plot: diff --git a/scripts/figures/plot_fig4.py b/scripts/figures/plot_fig4.py index 698d954..8b4dcd8 100644 --- a/scripts/figures/plot_fig4.py +++ b/scripts/figures/plot_fig4.py @@ -13,42 +13,21 @@ INTENSITY_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/measurements2" # noqa # The cochlea for the CHReef analysis. -COCHLEAE = [ - "M_LR_000143_L", - "M_LR_000144_L", - "M_LR_000145_L", - "M_LR_000153_L", - "M_LR_000155_L", - "M_LR_000189_L", - "M_LR_000143_R", - "M_LR_000144_R", - "M_LR_000145_R", - "M_LR_000153_R", - "M_LR_000155_R", - "M_LR_000189_R", -] - -COCHLEAE_GERBIL = [ - "G_EK_000049_L", - "G_EK_000049_R", -] - - -COCHLEAE_ALIAS = { - "M_LR_000143_L": "M0L", - "M_LR_000144_L": "M05L", - "M_LR_000145_L": "M06L", - "M_LR_000153_L": "M07L", - "M_LR_000155_L": "M08L", - "M_LR_000189_L": "M09L", - "M_LR_000143_R": "M0R", - "M_LR_000144_R": "M05R", - "M_LR_000145_R": "M06R", - "M_LR_000153_R": "M07R", - "M_LR_000155_R": "M08R", - "M_LR_000189_R": "M09R", - "G_EK_000049_L": "G1L", - "G_EK_000049_R": "G1R", +COCHLEAE_DICT = { + "M_LR_000143_L": {"alias": "M0L", "component": [1]}, + "M_LR_000144_L": {"alias": "M05L", "component": [1]}, + "M_LR_000145_L": {"alias": "M06L", "component": [1]}, + "M_LR_000153_L": {"alias": "M07L", "component": [1]}, + "M_LR_000155_L": {"alias": "M08L", "component": [1, 2, 3]}, + "M_LR_000189_L": {"alias": "M09L", "component": [1]}, + "M_LR_000143_R": {"alias": "M0R", "component": [1]}, + "M_LR_000144_R": {"alias": "M05R", "component": [1]}, + "M_LR_000145_R": {"alias": "M06R", "component": [1]}, + "M_LR_000153_R": {"alias": "M07R", "component": [1]}, + "M_LR_000155_R": {"alias": "M08R", "component": [1]}, + "M_LR_000189_R": {"alias": "M09R", "component": [1]}, + "G_EK_000049_L": {"alias": "G1L", "component": [1, 3, 4, 5]}, + "G_EK_000049_R": {"alias": "G1R", "component": [1, 2]}, } png_dpi = 300 @@ -60,10 +39,10 @@ def get_chreef_data(animal="mouse"): if animal == "mouse": cache_path = "./chreef_data.pkl" - cochleae = COCHLEAE + cochleae = [key for key in COCHLEAE_DICT.keys() if "M_" in key] else: cache_path = "./chreef_data_gerbil.pkl" - cochleae = COCHLEAE_GERBIL + cochleae = [key for key in COCHLEAE_DICT.keys() if "G_" in key] if os.path.exists(cache_path): with open(cache_path, "rb") as f: @@ -83,7 +62,9 @@ def get_chreef_data(animal="mouse"): table = pd.read_csv(table_content, sep="\t") # May need to be adjusted for some cochleae. - table = table[table.component_labels == 1] + component_labels = COCHLEAE_DICT[cochlea]["component"] + print(cochlea, component_labels) + table = table[table.component_labels.isin(component_labels)] # The relevant values for analysis. try: values = table[["label_id", "length[µm]", "frequency[kHz]", "marker_labels"]] @@ -136,7 +117,7 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr # TODO have central function for alias for all plots? if use_alias: - alias = [COCHLEAE_ALIAS[k] for k in chreef_data.keys()] + alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()] else: alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()] @@ -206,7 +187,7 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa """Transduction efficiency per cochlea. """ if use_alias: - alias = [COCHLEAE_ALIAS[k] for k in chreef_data.keys()] + alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()] else: alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()] @@ -237,7 +218,7 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa main_label_size = 20 sub_label_size = 16 main_tick_size = 12 - legendsize = 16 + legendsize = 16 if intensity else 12 label = "Intensity" if intensity else "Transduction efficiency" if plot_by_side: @@ -274,7 +255,7 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali result = {"cochlea": [], "octave_band": [], "value": []} for name, values in chreef_data.items(): if use_alias: - alias = COCHLEAE_ALIAS[name] + alias = COCHLEAE_DICT[name]["alias"] else: alias = name.replace("_", "").replace("0", "") diff --git a/scripts/figures/plot_fig6.py b/scripts/figures/plot_fig6.py index c0cd5fe..2c0a2c7 100644 --- a/scripts/figures/plot_fig6.py +++ b/scripts/figures/plot_fig6.py @@ -3,6 +3,8 @@ import matplotlib.pyplot as plt +from util import literature_reference_values_gerbil + png_dpi = 300 @@ -27,9 +29,10 @@ def fig_06a(save_path, plot=False): # Labels and formatting ax[0].set_xticklabels(["SGN"], fontsize=main_label_size) - ylim0 = 12000 - ylim1 = 22500 - y_ticks = [i for i in range(ylim0, ylim1 + 1, 2000)] + ylim0 = 14000 + ylim1 = 30000 + ytick_gap = 4000 + y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] ax[0].set_ylabel('Count per cochlea', fontsize=main_label_size) ax[0].set_yticks(y_ticks) @@ -40,19 +43,18 @@ def fig_06a(save_path, plot=False): xmin = 0.5 xmax = 1.5 ax[0].set_xlim(xmin, xmax) - upper_y = 15000 - lower_y = 13000 + lower_y, upper_y = literature_reference_values_gerbil("SGN") ax[0].hlines([lower_y, upper_y], xmin, xmax) - ax[0].text(1, upper_y + 100, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center") + ax[0].text(1, upper_y + 100, "literature reference", color='C0', fontsize=main_tick_size, ha="center") ax[0].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True) - ylim0 = 800 + ylim0 = 900 ylim1 = 1400 - y_ticks = [i for i in range(ylim0, ylim1 + 1, 100)] + ytick_gap = 200 + y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)] ax[1].set_xticklabels(["IHC"], fontsize=main_label_size) - ax[1].set_ylabel('Count per cochlea', fontsize=main_label_size) ax[1].set_yticks(y_ticks) ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size) ax[1].set_ylim(ylim0, ylim1) @@ -61,10 +63,9 @@ def fig_06a(save_path, plot=False): xmin = 0.5 xmax = 1.5 ax[1].set_xlim(xmin, xmax) - upper_y = 1200 - lower_y = 1000 + lower_y, upper_y = literature_reference_values_gerbil("IHC") ax[1].hlines([lower_y, upper_y], xmin, xmax) - ax[1].text(1, upper_y + 10, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center") + ax[1].text(1, upper_y + 10, "literature reference", color='C0', fontsize=main_tick_size, ha="center") ax[1].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True) plt.tight_layout() diff --git a/scripts/figures/util.py b/scripts/figures/util.py index 7dd3056..8cf6501 100644 --- a/scripts/figures/util.py +++ b/scripts/figures/util.py @@ -70,3 +70,15 @@ def literature_reference_values(structure): else: raise ValueError return lower_bound, upper_bound + + +def literature_reference_values_gerbil(structure): + if structure == "SGN": + lower_bound, upper_bound = 24700, 28450 + elif structure == "IHC": + lower_bound, upper_bound = 1081, 1081 + elif structure == "synapse": + lower_bound, upper_bound = 9.1, 20.7 + else: + raise ValueError + return lower_bound, upper_bound \ No newline at end of file diff --git a/scripts/prediction/run_prediction_distance_unet.py b/scripts/prediction/run_prediction_distance_unet.py index 296cb6b..8c13ecf 100644 --- a/scripts/prediction/run_prediction_distance_unet.py +++ b/scripts/prediction/run_prediction_distance_unet.py @@ -37,6 +37,8 @@ def main(): in which case the boundary distances are not used for the seeds.") parser.add_argument("--fg_threshold", default=0.5, type=float, help="The threshold applied to the foreground prediction for deriving the watershed mask.") + parser.add_argument("--distance_smoothing", default=0, type=float, + help="The sigma value for smoothing the distance predictions with a gaussian kernel.") args = parser.parse_args() @@ -78,7 +80,7 @@ def main(): seg_class=args.seg_class, center_distance_threshold=args.center_distance_threshold, boundary_distance_threshold=args.boundary_distance_threshold, - fg_threshold=args.fg_threshold, + fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing, ) abs_path = os.path.abspath(args.input) @@ -95,7 +97,7 @@ def main(): seg_class=args.seg_class, center_distance_threshold=args.center_distance_threshold, boundary_distance_threshold=args.boundary_distance_threshold, - fg_threshold=args.fg_threshold, + fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing, ) timer_output = os.path.join(args.output_folder, "timer.json")