Skip to content

Commit c51f624

Browse files
Update plotting scripts
1 parent 0a58bcf commit c51f624

File tree

5 files changed

+316
-68
lines changed

5 files changed

+316
-68
lines changed

scripts/figures/plot_fig2.py

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
import argparse
22
import os
3+
from glob import glob
34

45
import numpy as np
56
import pandas as pd
67
import matplotlib.pyplot as plt
78

9+
from util import literature_reference_values
10+
811
png_dpi = 300
912

1013

11-
def fig_02c(save_path, plot=False):
14+
def fig_02c(save_path, plot=False, all_versions=False):
1215
"""Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual),
1316
IHC (distance U-Net, manual), and synapse detection (U-Net).
1417
"""
15-
setting = ["U-Net", "manual", "U-Net", "manual", "U-Net"]
16-
1718
# precision, recall, f1-score
1819
sgn_unet = [0.887, 0.88, 0.884]
1920
sgn_annotator = [0.95, 0.849, 0.9]
@@ -25,19 +26,32 @@ def fig_02c(save_path, plot=False):
2526
ihc_annotator = [0.958, 0.956, 0.957]
2627
syn_unet = [0.931, 0.905, 0.918]
2728

29+
# This is the version with IHC v4b segmentation:
30+
# 4th version of the network with optimized segmentation params
2831
version_1 = [sgn_unet, sgn_annotator, ihc_v4b, ihc_annotator, syn_unet]
29-
settings_1 = ["SGN_v2", "manual", "IHC_v4b", "manual", "U-Net"]
32+
settings_1 = ["automatic", "manual", "automatic", "manual", "automatic"]
3033

34+
# This is the version with IHC v4c segmentation:
35+
# 4th version of the network with optimized segmentation params and split of falsely merged IHCs
3136
version_2 = [sgn_unet, sgn_annotator, ihc_v4c, ihc_annotator, syn_unet]
32-
settings_2 = ["SGN_v2", "manual", "IHC_v4c", "manual", "U-Net"]
37+
settings_2 = ["automatic", "manual", "automatic", "manual", "automatic"]
3338

39+
# This is the version with IHC v4c + filter segmentation:
40+
# 4th version of the network with optimized segmentation params and split of falsely merged IHCs
41+
# + filtering out IHCs with zero mapped synapses.
3442
version_3 = [sgn_unet, sgn_annotator, ihc_v4c_filter, ihc_annotator, syn_unet]
35-
settings_3 = ["SGN_v2", "manual", "IHC_v4c_filter", "manual", "U-Net"]
43+
settings_3 = ["automatic", "manual", "automatic", "manual", "automatic"]
3644

37-
versions = [version_1, version_2, version_3]
38-
settings = [settings_1, settings_2, settings_3]
39-
save_suffix = ["v4b", "v4c", "v4c_filter"]
40-
save_paths = [save_path + i for i in save_suffix]
45+
if all_versions:
46+
versions = [version_1, version_2, version_3]
47+
settings = [settings_1, settings_2, settings_3]
48+
save_suffix = ["_v4b", "_v4c", "_v4c_filter"]
49+
save_paths = [save_path + i for i in save_suffix]
50+
else:
51+
versions = [version_2]
52+
settings = [settings_2]
53+
save_suffix = ["_v4c"]
54+
save_paths = [save_path + i for i in save_suffix]
4155

4256
for version, setting, save_path in zip(versions, settings, save_paths):
4357
precision = [i[0] for i in version]
@@ -84,22 +98,37 @@ def fig_02c(save_path, plot=False):
8498
plt.close()
8599

86100

87-
def fig_02d_01(save_path, plot=False):
101+
# Load the synapse counts for all IHCs from the relevant tables.
102+
def _load_ribbon_synapse_counts():
103+
tables = glob("ihc_counts/*M_LR*.tsv")
104+
syn_counts = []
105+
for tab in tables:
106+
x = pd.read_csv(tab, sep="\t")
107+
syn_counts.extend(x["synapse_count"].values.tolist())
108+
return syn_counts
109+
110+
111+
def fig_02d_01(save_path, plot=False, all_versions=False, plot_average_ribbon_synapses=False):
88112
"""Box plot showing the counts for SGN and IHC per (mouse) cochlea in comparison to literature values.
89113
"""
90114
main_tick_size = 16
91115
main_label_size = 24
92116

93117
rows = 1
94-
columns = 2
118+
columns = 3 if plot_average_ribbon_synapses else 2
95119

96120
sgn_values = [11153, 11398, 10333, 11820]
97121
ihc_v4b_values = [836, 808, 796, 901]
98122
ihc_v4c_values = [712, 710, 721, 675]
99123
ihc_v4c_filtered_values = [562, 647, 626, 628]
100124

101-
ihc_list = [ihc_v4b_values, ihc_v4c_values, ihc_v4c_filtered_values]
102-
suffixes = ["_v4b", "_v4c", "_v4c_filtered"]
125+
if all_versions:
126+
ihc_list = [ihc_v4b_values, ihc_v4c_values, ihc_v4c_filtered_values]
127+
suffixes = ["_v4b", "_v4c", "_v4c_filtered"]
128+
assert not plot_average_ribbon_synapses
129+
else:
130+
ihc_list = [ihc_v4c_values]
131+
suffixes = ["_v4c"]
103132

104133
for (ihc_values, suffix) in zip(ihc_list, suffixes):
105134
fig, axes = plt.subplots(rows, columns, figsize=(columns*4, rows*4))
@@ -126,34 +155,46 @@ def fig_02d_01(save_path, plot=False):
126155
xmin = 0.5
127156
xmax = 1.5
128157
ax[0].set_xlim(xmin, xmax)
129-
upper_y = 12000
130-
lower_y = 10000
158+
lower_y, upper_y = literature_reference_values("SGN")
131159
ax[0].hlines([lower_y, upper_y], xmin, xmax)
132-
ax[0].text(1, upper_y + 100, "literature reference (WIP)", color="C0", fontsize=main_tick_size, ha="center")
160+
ax[0].text(1, upper_y + 100, "literature", color="C0", fontsize=main_tick_size, ha="center")
133161
ax[0].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True)
134162

135163
ylim0 = 500
136164
ylim1 = 950
137165
y_ticks = [i for i in range(500, 900 + 1, 100)]
138166

139167
ax[1].set_xticklabels(["IHC"], fontsize=main_label_size)
140-
141168
ax[1].set_yticks(y_ticks)
142169
ax[1].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
143170
ax[1].set_ylim(ylim0, ylim1)
144-
ax[1].yaxis.tick_right()
145-
ax[1].yaxis.set_ticks_position("right")
171+
if not plot_average_ribbon_synapses:
172+
ax[1].yaxis.tick_right()
173+
ax[1].yaxis.set_ticks_position("right")
146174

147175
# set range of literature values
148176
xmin = 0.5
149177
xmax = 1.5
178+
lower_y, upper_y = literature_reference_values("IHC")
150179
ax[1].set_xlim(xmin, xmax)
151-
upper_y = 850
152-
lower_y = 780
153180
ax[1].hlines([lower_y, upper_y], xmin, xmax)
154-
ax[1].text(1, lower_y - 10, "literature reference (WIP)", color="C0", fontsize=main_tick_size, ha="center")
181+
ax[1].text(1, upper_y + 20, "literature", color="C0", fontsize=main_tick_size, ha="center")
155182
ax[1].fill_between([xmin, xmax], lower_y, upper_y, color="C0", alpha=0.05, interpolate=True)
156183

184+
if plot_average_ribbon_synapses:
185+
ribbon_synapse_counts = _load_ribbon_synapse_counts()
186+
# ylim0 = 4.9
187+
# ylim1 = 25.1
188+
y_ticks = [0, 10, 20, 30, 40, 50]
189+
190+
ax[2].boxplot(ribbon_synapse_counts)
191+
ax[2].set_xticklabels(["Ribbon Syn. per IHC"], fontsize=main_label_size)
192+
ax[2].set_yticks(y_ticks)
193+
ax[2].set_yticklabels(y_ticks, rotation=0, fontsize=main_tick_size)
194+
# ax[2].set_ylim(ylim0, ylim1)
195+
196+
# TODO range of values from literature.
197+
157198
plt.tight_layout()
158199
plt.savefig(save_path_new, dpi=png_dpi)
159200

@@ -249,12 +290,17 @@ def main():
249290

250291
os.makedirs(args.figure_dir, exist_ok=True)
251292

252-
fig_02c(save_path=os.path.join(args.figure_dir, "fig_02c"), plot=args.plot)
253-
fig_02d_01(save_path=os.path.join(args.figure_dir, "fig_02d_01"), plot=args.plot)
254-
fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02"), plot=args.plot)
293+
# Panel C: Evaluation of the segmentation results:
294+
fig_02c(save_path=os.path.join(args.figure_dir, "fig_02c"), plot=args.plot, all_versions=False)
295+
296+
# Panel D: The number of SGNs, IHCs and average number of ribbon synapses per IHC
297+
fig_02d_01(save_path=os.path.join(args.figure_dir, "fig_02d"), plot=args.plot, plot_average_ribbon_synapses=True)
298+
299+
# Alternative version of synapse distribution for panel D.
300+
# fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02"), plot=args.plot)
255301
# fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4c"), filter_zeros=False, plot=plot)
256302
# fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4c_filtered"), filter_zeros=True, plot=plot)
257-
fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4b"), filter_zeros=True, plot=args.plot)
303+
# fig_02d_02(save_path=os.path.join(args.figure_dir, "fig_02d_02_v4b"), filter_zeros=True, plot=args.plot)
258304

259305

260306
if __name__ == "__main__":

scripts/figures/plot_fig3.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import argparse
2+
import os
3+
import imageio.v3 as imageio
4+
from glob import glob
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pandas as pd
9+
from matplotlib import cm, colors
10+
11+
from util import sliding_runlength_sum, frequency_mapping
12+
13+
INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3/frequency_mapping"
14+
15+
png_dpi = 300
16+
17+
18+
def fig_03a(save_path):
19+
import napari
20+
21+
path = os.path.join(INPUT_ROOT, "frequencies_IHC_v4c.tif")
22+
vol = imageio.imread(path)
23+
24+
# Create the colormap
25+
fig, ax = plt.subplots(figsize=(6, 1.3))
26+
fig.subplots_adjust(bottom=0.5)
27+
28+
freq_min = np.min(np.nonzero(vol))
29+
freq_max = vol.max()
30+
norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True)
31+
cmap = plt.get_cmap("viridis")
32+
33+
cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal")
34+
cb.set_label("Frequency [kHz]")
35+
plt.title("Tonotopic Mapping: IHCs")
36+
plt.tight_layout()
37+
out_path = os.path.join(save_path)
38+
plt.savefig(out_path)
39+
40+
# Show the image in napari for rendering.
41+
v = napari.Viewer()
42+
v.add_image(vol, colormap="viridis")
43+
napari.run()
44+
45+
46+
def fig_03b(save_path):
47+
import napari
48+
49+
path = os.path.join(INPUT_ROOT, "frequencies_SGN_v2.tif")
50+
vol = imageio.imread(path)
51+
52+
# Create the colormap
53+
fig, ax = plt.subplots(figsize=(6, 1.3))
54+
fig.subplots_adjust(bottom=0.5)
55+
56+
freq_min = np.min(np.nonzero(vol))
57+
freq_max = vol.max()
58+
norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True)
59+
cmap = plt.get_cmap("viridis")
60+
61+
cb = plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), cax=ax, orientation="horizontal")
62+
cb.set_label("Frequency [kHz]")
63+
plt.title("Tonotopic Mapping: SGNs")
64+
plt.tight_layout()
65+
out_path = os.path.join(save_path)
66+
plt.savefig(out_path)
67+
68+
# Show the image in napari for rendering.
69+
v = napari.Viewer()
70+
v.add_image(vol, colormap="viridis")
71+
napari.run()
72+
73+
74+
def fig_03c_rl(save_path, plot=False):
75+
tables = glob("./ihc_counts/ihc_count_M_LR*.tsv")
76+
fig, ax = plt.subplots(figsize=(8, 4))
77+
78+
width = 50 # micron
79+
80+
for tab_path in tables:
81+
# TODO map to alias
82+
alias = os.path.basename(tab_path)[10:-4].replace("_", "").replace("0", "")
83+
tab = pd.read_csv(tab_path, sep="\t")
84+
run_length = tab["run_length"].values
85+
syn_count = tab["synapse_count"].values
86+
87+
# Compute the running sum of 10 micron.
88+
run_length, syn_count_running = sliding_runlength_sum(run_length, syn_count, width=width)
89+
ax.plot(run_length, syn_count_running, label=alias)
90+
91+
ax.set_xlabel("Length [µm]")
92+
ax.set_ylabel("Synapse Count")
93+
ax.set_title(f"Ribbon Syn. per IHC: Runnig sum @ {width} µm")
94+
ax.legend(title="cochlea")
95+
plt.tight_layout()
96+
97+
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
98+
if plot:
99+
plt.show()
100+
else:
101+
plt.close()
102+
103+
104+
def fig_03c_octave(save_path, plot=False):
105+
tables = glob("./ihc_counts/ihc_count_M_LR*.tsv")
106+
107+
result = {"cochlea": [], "octave_band": [], "value": []}
108+
for tab_path in tables:
109+
# TODO map to alias
110+
alias = os.path.basename(tab_path)[10:-4].replace("_", "").replace("0", "")
111+
tab = pd.read_csv(tab_path, sep="\t")
112+
freq = tab["frequency"].values
113+
syn_count = tab["synapse_count"].values
114+
115+
# Compute the running sum of 10 micron.
116+
octave_binned = frequency_mapping(freq, syn_count, animal="mouse")
117+
118+
result["cochlea"].extend([alias] * len(octave_binned))
119+
result["octave_band"].extend(octave_binned.axes[0].values.tolist())
120+
result["value"].extend(octave_binned.values.tolist())
121+
122+
result = pd.DataFrame(result)
123+
bin_labels = pd.unique(result["octave_band"])
124+
band_to_x = {band: i for i, band in enumerate(bin_labels)}
125+
result["x_pos"] = result["octave_band"].map(band_to_x)
126+
127+
fig, ax = plt.subplots(figsize=(8, 4))
128+
for name, grp in result.groupby("cochlea"):
129+
ax.scatter(grp["x_pos"], grp["value"], label=name, s=60, alpha=0.8)
130+
131+
ax.set_xticks(range(len(bin_labels)))
132+
ax.set_xticklabels(bin_labels)
133+
ax.set_xlabel("Octave band (kHz)")
134+
135+
ax.set_ylabel("Average Ribbon Synapse Count per IHC")
136+
ax.set_title("Ribbon synapse count per octave band")
137+
ax.legend(title="Cochlea")
138+
139+
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
140+
if plot:
141+
plt.show()
142+
else:
143+
plt.close()
144+
145+
146+
def main():
147+
parser = argparse.ArgumentParser(description="Generate plots for Fig 3 of the cochlea paper.")
148+
parser.add_argument("--figure_dir", "-f", type=str, help="Output directory for plots.", default="./panels/fig3")
149+
parser.add_argument("--plot", action="store_true")
150+
args = parser.parse_args()
151+
152+
os.makedirs(args.figure_dir, exist_ok=True)
153+
154+
# Panel A: Tonotopic mapping of IHCs (rendering in napari)
155+
# fig_03a(save_path=os.path.join(args.figure_dir, "fig_03a.png"))
156+
157+
# Panel B: Tonotopic mapping of SGNs (rendering in napari)
158+
# fig_03b(save_path=os.path.join(args.figure_dir, "fig_03b.png"))
159+
160+
# Panel C: Spatial distribution of synapses across the cochlea.
161+
# We have two options: running sum over the runlength or per octave band
162+
fig_03c_rl(save_path=os.path.join(args.figure_dir, "fig_03c_runlength.png"), plot=args.plot)
163+
fig_03c_octave(save_path=os.path.join(args.figure_dir, "fig_03c_octave.png"), plot=args.plot)
164+
165+
# TODO: Panel D: Spatial distribution of SGN sub-types.
166+
167+
168+
if __name__ == "__main__":
169+
main()

0 commit comments

Comments
 (0)