Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions scripts/figures/plot_fig2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,116 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tifffile
from matplotlib import colors
from skimage.segmentation import find_boundaries

from util import literature_reference_values

png_dpi = 300


def scramble_instance_labels(arr):
"""Scramble indexes of instance segmentation to avoid neighboring colors.
"""
unique = list(np.unique(arr)[1:])
rng = np.random.default_rng(seed=42)
new_list = rng.uniform(1, len(unique) + 1, size=(len(unique)))
new_arr = np.zeros(arr.shape)
for old_id, new_id in zip(unique, new_list):
new_arr[arr == old_id] = new_id
return new_arr


def plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba=[0, 0, 0, 0.5], plot=False):
seg = tifffile.imread(seg_path)
if len(seg.shape) == 3:
seg = seg[10, xlim1:xlim2, ylim1:ylim2]
else:
seg = seg[xlim1:xlim2, ylim1:ylim2]

img = tifffile.imread(img_path)
img = img[10, xlim1:xlim2, ylim1:ylim2]

# create color map with random distribution for coloring instance segmentation
unique = list(np.unique(seg)[1:])
n_instances = len(unique)

seg = scramble_instance_labels(seg)

rng = np.random.default_rng(seed=42) # fixed seed for reproducibility
colors_array = rng.uniform(0, 1, size=(n_instances, 4)) # RGBA values in [0,1]
colors_array[:, 3] = 1.0 # full alpha
colors_array[0, 3] = 0.0 # make label 0 transparent (background)
cmap = colors.ListedColormap(colors_array)

boundaries = find_boundaries(seg, mode="inner")
boundary_overlay = np.zeros((*boundaries.shape, 4))

boundary_overlay[boundaries] = boundary_rgba # RGBA = black

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(img, cmap="gray")
ax.imshow(seg, cmap=cmap, alpha=0.5, interpolation="nearest")
ax.imshow(boundary_overlay)
ax.axis("off")
plt.tight_layout()
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)

if plot:
plt.show()
else:
plt.close()


def fig_02b_sgn(save_dir, plot=False):
"""Plot crops of SGN segmentation of CochleaNet, Cellpose and micro-sam.
"""
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
val_sgn_dir = f"{cochlea_dir}/predictions/val_sgn"
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationSGNs/for_consensus_annotation"

crop_name = "MLR169R_PV_z3420_allturns_full"
img_path = os.path.join(image_dir, f"{crop_name}.tif")

xlim1 = 2000
xlim2 = 2500
ylim1 = 3100
ylim2 = 3600
boundary_rgba = [1, 1, 1, 0.5]

for seg_net in ["distance_unet", "cellpose-sam", "micro-sam"]:
save_path = os.path.join(save_dir, f"fig_02b_sgn_{seg_net}.png")
seg_dir = os.path.join(val_sgn_dir, seg_net)
seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")

plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)


def fig_02b_ihc(save_dir, plot=False):
"""Plot crops of IHC segmentation of CochleaNet, Cellpose and micro-sam.
"""
cochlea_dir = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet"
val_sgn_dir = f"{cochlea_dir}/predictions/val_ihc"
image_dir = f"{cochlea_dir}/AnnotatedImageCrops/F1ValidationIHCs"

crop_name = "MLR226L_VGlut3_z1200_3turns_full"
img_path = os.path.join(image_dir, f"{crop_name}.tif")

xlim1 = 1900
xlim2 = 2400
ylim1 = 2000
ylim2 = 2500
boundary_rgba = [1, 1, 1, 0.5]

for seg_net in ["distance_unet_v4b", "cellpose-sam", "micro-sam"]:
save_path = os.path.join(save_dir, f"fig_02b_ihc_{seg_net}.png")
seg_dir = os.path.join(val_sgn_dir, seg_net)
seg_path = os.path.join(seg_dir, f"{crop_name}_seg.tif")

plot_seg_crop(img_path, seg_path, save_path, xlim1, xlim2, ylim1, ylim2, boundary_rgba, plot=plot)


def fig_02c(save_path, plot=False, all_versions=False):
"""Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual),
IHC (distance U-Net, manual), and synapse detection (U-Net).
Expand Down Expand Up @@ -299,6 +403,8 @@ def main():
os.makedirs(args.figure_dir, exist_ok=True)

# Panel C: Evaluation of the segmentation results:
fig_02b_sgn(save_dir=args.figure_dir, plot=args.plot)
fig_02b_ihc(save_dir=args.figure_dir, plot=args.plot)
fig_02c(save_path=os.path.join(args.figure_dir, "fig_02c"), plot=args.plot, all_versions=False)

# Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC
Expand Down
12 changes: 8 additions & 4 deletions scripts/figures/plot_fig3.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import argparse
import os
import imageio.v3 as imageio
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -72,7 +71,10 @@ def fig_03b(save_path):


def fig_03c_rl(save_path, plot=False):
tables = glob("./ihc_counts/ihc_count_M_LR*.tsv")
ihc_version = "ihc_counts_v4c"
synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name]

fig, ax = plt.subplots(figsize=(8, 4))

width = 50 # micron
Expand Down Expand Up @@ -102,7 +104,9 @@ def fig_03c_rl(save_path, plot=False):


def fig_03c_octave(save_path, plot=False):
tables = glob("./ihc_counts/ihc_count_M_LR*.tsv")
ihc_version = "ihc_counts_v4c"
synapse_dir = f"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/synapses/{ihc_version}"
tables = [entry.path for entry in os.scandir(synapse_dir) if "ihc_count_M_LR" in entry.name]

result = {"cochlea": [], "octave_band": [], "value": []}
for tab_path in tables:
Expand Down Expand Up @@ -134,7 +138,7 @@ def fig_03c_octave(save_path, plot=False):

ax.set_ylabel("Average Ribbon Synapse Count per IHC")
ax.set_title("Ribbon synapse count per octave band")
ax.legend(title="Cochlea")
plt.legend(title="Cochlea")

plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
if plot:
Expand Down
67 changes: 24 additions & 43 deletions scripts/figures/plot_fig4.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,42 +13,21 @@
INTENSITY_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/tables/measurements2" # noqa

# The cochlea for the CHReef analysis.
COCHLEAE = [
"M_LR_000143_L",
"M_LR_000144_L",
"M_LR_000145_L",
"M_LR_000153_L",
"M_LR_000155_L",
"M_LR_000189_L",
"M_LR_000143_R",
"M_LR_000144_R",
"M_LR_000145_R",
"M_LR_000153_R",
"M_LR_000155_R",
"M_LR_000189_R",
]

COCHLEAE_GERBIL = [
"G_EK_000049_L",
"G_EK_000049_R",
]


COCHLEAE_ALIAS = {
"M_LR_000143_L": "M0L",
"M_LR_000144_L": "M05L",
"M_LR_000145_L": "M06L",
"M_LR_000153_L": "M07L",
"M_LR_000155_L": "M08L",
"M_LR_000189_L": "M09L",
"M_LR_000143_R": "M0R",
"M_LR_000144_R": "M05R",
"M_LR_000145_R": "M06R",
"M_LR_000153_R": "M07R",
"M_LR_000155_R": "M08R",
"M_LR_000189_R": "M09R",
"G_EK_000049_L": "G1L",
"G_EK_000049_R": "G1R",
COCHLEAE_DICT = {
"M_LR_000143_L": {"alias": "M0L", "component": [1]},
"M_LR_000144_L": {"alias": "M05L", "component": [1]},
"M_LR_000145_L": {"alias": "M06L", "component": [1]},
"M_LR_000153_L": {"alias": "M07L", "component": [1]},
"M_LR_000155_L": {"alias": "M08L", "component": [1, 2, 3]},
"M_LR_000189_L": {"alias": "M09L", "component": [1]},
"M_LR_000143_R": {"alias": "M0R", "component": [1]},
"M_LR_000144_R": {"alias": "M05R", "component": [1]},
"M_LR_000145_R": {"alias": "M06R", "component": [1]},
"M_LR_000153_R": {"alias": "M07R", "component": [1]},
"M_LR_000155_R": {"alias": "M08R", "component": [1]},
"M_LR_000189_R": {"alias": "M09R", "component": [1]},
"G_EK_000049_L": {"alias": "G1L", "component": [1, 3, 4, 5]},
"G_EK_000049_R": {"alias": "G1R", "component": [1, 2]},
}

png_dpi = 300
Expand All @@ -60,10 +39,10 @@ def get_chreef_data(animal="mouse"):

if animal == "mouse":
cache_path = "./chreef_data.pkl"
cochleae = COCHLEAE
cochleae = [key for key in COCHLEAE_DICT.keys() if "M_" in key]
else:
cache_path = "./chreef_data_gerbil.pkl"
cochleae = COCHLEAE_GERBIL
cochleae = [key for key in COCHLEAE_DICT.keys() if "G_" in key]

if os.path.exists(cache_path):
with open(cache_path, "rb") as f:
Expand All @@ -83,7 +62,9 @@ def get_chreef_data(animal="mouse"):
table = pd.read_csv(table_content, sep="\t")

# May need to be adjusted for some cochleae.
table = table[table.component_labels == 1]
component_labels = COCHLEAE_DICT[cochlea]["component"]
print(cochlea, component_labels)
table = table[table.component_labels.isin(component_labels)]
# The relevant values for analysis.
try:
values = table[["label_id", "length[µm]", "frequency[kHz]", "marker_labels"]]
Expand Down Expand Up @@ -136,7 +117,7 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr

# TODO have central function for alias for all plots?
if use_alias:
alias = [COCHLEAE_ALIAS[k] for k in chreef_data.keys()]
alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()]
else:
alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()]

Expand Down Expand Up @@ -206,7 +187,7 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa
"""Transduction efficiency per cochlea.
"""
if use_alias:
alias = [COCHLEAE_ALIAS[k] for k in chreef_data.keys()]
alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()]
else:
alias = [name.replace("_", "").replace("0", "") for name in chreef_data.keys()]

Expand Down Expand Up @@ -237,7 +218,7 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa
main_label_size = 20
sub_label_size = 16
main_tick_size = 12
legendsize = 16
legendsize = 16 if intensity else 12

label = "Intensity" if intensity else "Transduction efficiency"
if plot_by_side:
Expand Down Expand Up @@ -274,7 +255,7 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali
result = {"cochlea": [], "octave_band": [], "value": []}
for name, values in chreef_data.items():
if use_alias:
alias = COCHLEAE_ALIAS[name]
alias = COCHLEAE_DICT[name]["alias"]
else:
alias = name.replace("_", "").replace("0", "")

Expand Down
25 changes: 13 additions & 12 deletions scripts/figures/plot_fig6.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import matplotlib.pyplot as plt

from util import literature_reference_values_gerbil

png_dpi = 300


Expand All @@ -27,9 +29,10 @@ def fig_06a(save_path, plot=False):
# Labels and formatting
ax[0].set_xticklabels(["SGN"], fontsize=main_label_size)

ylim0 = 12000
ylim1 = 22500
y_ticks = [i for i in range(ylim0, ylim1 + 1, 2000)]
ylim0 = 14000
ylim1 = 30000
ytick_gap = 4000
y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)]

ax[0].set_ylabel('Count per cochlea', fontsize=main_label_size)
ax[0].set_yticks(y_ticks)
Expand All @@ -40,19 +43,18 @@ def fig_06a(save_path, plot=False):
xmin = 0.5
xmax = 1.5
ax[0].set_xlim(xmin, xmax)
upper_y = 15000
lower_y = 13000
lower_y, upper_y = literature_reference_values_gerbil("SGN")
ax[0].hlines([lower_y, upper_y], xmin, xmax)
ax[0].text(1, upper_y + 100, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center")
ax[0].text(1, upper_y + 100, "literature reference", color='C0', fontsize=main_tick_size, ha="center")
ax[0].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True)

ylim0 = 800
ylim0 = 900
ylim1 = 1400
y_ticks = [i for i in range(ylim0, ylim1 + 1, 100)]
ytick_gap = 200
y_ticks = [i for i in range((((ylim0 - 1) // ytick_gap) + 1) * ytick_gap, ylim1 + 1, ytick_gap)]

ax[1].set_xticklabels(["IHC"], fontsize=main_label_size)

ax[1].set_ylabel('Count per cochlea', fontsize=main_label_size)
ax[1].set_yticks(y_ticks)
ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
ax[1].set_ylim(ylim0, ylim1)
Expand All @@ -61,10 +63,9 @@ def fig_06a(save_path, plot=False):
xmin = 0.5
xmax = 1.5
ax[1].set_xlim(xmin, xmax)
upper_y = 1200
lower_y = 1000
lower_y, upper_y = literature_reference_values_gerbil("IHC")
ax[1].hlines([lower_y, upper_y], xmin, xmax)
ax[1].text(1, upper_y + 10, "literature reference (WIP)", color='C0', fontsize=main_tick_size, ha="center")
ax[1].text(1, upper_y + 10, "literature reference", color='C0', fontsize=main_tick_size, ha="center")
ax[1].fill_between([xmin, xmax], lower_y, upper_y, color='C0', alpha=0.05, interpolate=True)

plt.tight_layout()
Expand Down
12 changes: 12 additions & 0 deletions scripts/figures/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,15 @@ def literature_reference_values(structure):
else:
raise ValueError
return lower_bound, upper_bound


def literature_reference_values_gerbil(structure):
if structure == "SGN":
lower_bound, upper_bound = 24700, 28450
elif structure == "IHC":
lower_bound, upper_bound = 1081, 1081
elif structure == "synapse":
lower_bound, upper_bound = 9.1, 20.7
else:
raise ValueError
return lower_bound, upper_bound
6 changes: 4 additions & 2 deletions scripts/prediction/run_prediction_distance_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ def main():
in which case the boundary distances are not used for the seeds.")
parser.add_argument("--fg_threshold", default=0.5, type=float,
help="The threshold applied to the foreground prediction for deriving the watershed mask.")
parser.add_argument("--distance_smoothing", default=0, type=float,
help="The sigma value for smoothing the distance predictions with a gaussian kernel.")

args = parser.parse_args()

Expand Down Expand Up @@ -78,7 +80,7 @@ def main():
seg_class=args.seg_class,
center_distance_threshold=args.center_distance_threshold,
boundary_distance_threshold=args.boundary_distance_threshold,
fg_threshold=args.fg_threshold,
fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing,
)

abs_path = os.path.abspath(args.input)
Expand All @@ -95,7 +97,7 @@ def main():
seg_class=args.seg_class,
center_distance_threshold=args.center_distance_threshold,
boundary_distance_threshold=args.boundary_distance_threshold,
fg_threshold=args.fg_threshold,
fg_threshold=args.fg_threshold, distance_smoothing=args.distance_smoothing,
)
timer_output = os.path.join(args.output_folder, "timer.json")

Expand Down
Loading