Skip to content

Commit 6c9a0f2

Browse files
committed
Add Prism style
1 parent ed8c69e commit 6c9a0f2

File tree

5 files changed

+179
-60
lines changed

5 files changed

+179
-60
lines changed

scripts/figures/plot_fig2.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from skimage.segmentation import find_boundaries
1010

1111
from util import literature_reference_values, SYNAPSE_DIR_ROOT
12+
from util import prism_style, prism_cleanup_axes
1213

1314
png_dpi = 300
1415
FILE_EXTENSION = "png"
@@ -122,6 +123,7 @@ def fig_02c(save_path, plot=False, all_versions=False):
122123
"""Scatter plot showing the precision, recall, and F1-score of SGN (distance U-Net, manual),
123124
IHC (distance U-Net, manual), and synapse detection (U-Net).
124125
"""
126+
prism_style()
125127
# precision, recall, f1-score
126128
sgn_unet = [0.887, 0.88, 0.884]
127129
sgn_annotator = [0.95, 0.849, 0.9]
@@ -172,7 +174,7 @@ def fig_02c(save_path, plot=False, all_versions=False):
172174
offset = 0.08 # horizontal shift for scatter separation
173175

174176
# Plot
175-
plt.figure(figsize=(8, 5))
177+
fig, ax = plt.subplots(figsize=(8, 5))
176178

177179
main_label_size = 22
178180
sub_label_size = 16
@@ -192,11 +194,12 @@ def fig_02c(save_path, plot=False, all_versions=False):
192194
plt.yticks(fontsize=main_tick_size)
193195
plt.ylabel("Value", fontsize=main_label_size)
194196
plt.ylim(0.76, 1)
195-
plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.11),
196-
ncol=3, fancybox=True, shadow=False, framealpha=0.8, fontsize=legendsize)
197+
plt.legend(loc="lower right",
198+
fontsize=legendsize)
197199
plt.grid(axis="y", linestyle="--", alpha=0.5)
198200

199201
plt.tight_layout()
202+
prism_cleanup_axes(ax)
200203

201204
if ".png" in save_path:
202205
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)

scripts/figures/plot_fig3.py

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import imageio.v3 as imageio
77
from glob import glob
8-
from pathlib import Path
98

109
import matplotlib
1110
import matplotlib.pyplot as plt
@@ -14,7 +13,7 @@
1413
from matplotlib import cm, colors
1514

1615
from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
17-
from util import sliding_runlength_sum, frequency_mapping, SYNAPSE_DIR_ROOT, to_alias
16+
from util import sliding_runlength_sum, frequency_mapping, prism_style, prism_cleanup_axes, SYNAPSE_DIR_ROOT
1817

1918
# INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3"
2019
INPUT_ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/frequency_mapping/M_LR_000227_R/scale3"
@@ -181,7 +180,8 @@ def fig_03c_rl(save_path, plot=False):
181180
plt.close()
182181

183182

184-
def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True):
183+
def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendlines=False):
184+
prism_style()
185185
ihc_version = "ihc_counts_v4c"
186186
tables = glob(os.path.join(SYNAPSE_DIR_ROOT, ihc_version, "ihc_count_M_LR*.tsv"))
187187
assert len(tables) == 4, len(tables)
@@ -207,16 +207,55 @@ def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True):
207207
result["x_pos"] = result["octave_band"].map(band_to_x)
208208

209209
fig, ax = plt.subplots(figsize=(8, 4))
210+
trend_dict = {}
210211
for name, grp in result.groupby("cochlea"):
211212
ax.scatter(grp["x_pos"], grp["value"], label=name, s=60, alpha=0.8)
212213

214+
if trendlines:
215+
x_positions = grp["x_pos"]
216+
sorted_idx = np.argsort(x_positions)
217+
x_sorted = np.array(x_positions)[sorted_idx]
218+
y_sorted = np.array(grp["value"])[sorted_idx]
219+
trend_dict[name] = {"x_sorted": x_sorted,
220+
"y_sorted": y_sorted,
221+
}
222+
223+
if trendlines:
224+
def get_trendline_values(trend_dict):
225+
x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys()][0]
226+
y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys()]
227+
y_sorted = []
228+
for num in range(len(x_sorted)):
229+
y_sorted.append(np.mean([y[num] for y in y_sorted_all]))
230+
return x_sorted, y_sorted
231+
232+
# Trendline left
233+
x_sorted, y_sorted = get_trendline_values(trend_dict)
234+
235+
trend, = ax.plot(
236+
x_sorted,
237+
y_sorted,
238+
linestyle="dotted",
239+
color="grey",
240+
alpha=0.7
241+
)
242+
243+
# trendline_legend = ax.legend(handles=[trend], loc='lower center')
244+
# trendline_legend = ax.legend(
245+
# handles=[trend],
246+
# labels=["Trendline"],
247+
# loc="upper left"
248+
# )
249+
# # Add the legend manually to the Axes.
250+
# ax.add_artist(trendline_legend)
251+
213252
ax.set_xticks(range(len(bin_labels)))
214253
ax.set_xticklabels(bin_labels)
215254
ax.set_xlabel("Octave band (kHz)")
216255

217-
ax.set_ylabel("Average Ribbon Synapse Count per IHC")
218-
ax.set_title("Ribbon synapse count per octave band")
256+
ax.set_ylabel("Average Ribbon Synapse Count per IHC", fontsize=10)
219257
plt.legend(title="Cochlea")
258+
prism_cleanup_axes(ax)
220259

221260
if ".png" in save_path:
222261
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
@@ -314,7 +353,7 @@ def main():
314353
# fig_03c_rl(save_path=os.path.join(args.figure_dir, f"fig_03c_runlength.{FILE_EXTENSION}"), plot=args.plot)
315354
fig_03c_octave(tonotopic_data=tonotopic_data,
316355
save_path=os.path.join(args.figure_dir, f"fig_03c_octave.{FILE_EXTENSION}"),
317-
plot=args.plot)
356+
plot=args.plot, trendlines=True)
318357

319358
# Panel D: Spatial distribution of SGN sub-types.
320359
# fig_03d_fraction(save_path=os.path.join(args.figure_dir, f"fig_03d_fraction.{FILE_EXTENSION}"), plot=args.plot)

scripts/figures/plot_fig4.py

Lines changed: 38 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
from flamingo_tools.s3_utils import BUCKET_NAME, create_s3_target
1010

11-
from util import frequency_mapping # , literature_reference_values
11+
from util import frequency_mapping, prism_style, prism_cleanup_axes # , literature_reference_values
1212

1313
# from statsmodels.nonparametric.smoothers_lowess import lowess
1414

@@ -117,6 +117,7 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr
117117
# cochlea = ["M_LR_000144_L", "M_LR_000145_L", "M_LR_000151_R"]
118118
# alias = ["c01", "c02", "c03"]
119119
# sgns = [7796, 6119, 9225]
120+
prism_style()
120121

121122
# TODO have central function for alias for all plots?
122123
if use_alias:
@@ -132,27 +133,25 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr
132133
x = np.arange(len(alias))
133134

134135
# Plot
135-
plt.figure(figsize=(8, 5))
136+
fig, ax = plt.subplots(figsize=(5, 5))
136137

137138
main_label_size = 20
138139
sub_label_size = 16
139140
main_tick_size = 16
140-
legendsize = 16
141+
legendsize = 12
141142

142143
if plot_by_side:
143-
plt.scatter(x, sgns_left, label="Left", marker="o", s=80)
144-
plt.scatter(x, sgns_right, label="Right", marker="x", s=80)
144+
plt.scatter(x, sgns_left, label="Injected", marker="o", s=80)
145+
plt.scatter(x, sgns_right, label="Non-Injected", marker="x", s=80)
145146
else:
146147
plt.scatter(x, sgns, label="SGN count", marker="o", s=80)
147148

148149
# Labels and formatting
149150
plt.xticks(x, alias, fontsize=sub_label_size)
150151
plt.yticks(fontsize=main_tick_size)
151152
plt.ylabel("SGN count per cochlea", fontsize=main_label_size)
152-
plt.ylim(4000, 13800)
153-
plt.legend(loc="best", fontsize=sub_label_size)
154-
plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.11),
155-
ncol=3, fancybox=True, shadow=False, framealpha=0.8, fontsize=legendsize)
153+
plt.ylim(4000, 15800)
154+
plt.legend(loc="upper right", fontsize=legendsize)
156155

157156
xmin = -0.5
158157
xmax = len(alias) - 0.5
@@ -172,12 +171,14 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr
172171
lower_y = sgn_value - 1.96 * sgn_std
173172

174173
plt.hlines([lower_y, upper_y], xmin, xmax, colors=["C1" for _ in range(2)])
175-
plt.text(1.5, upper_y + 100, "healthy cochleae (95% confidence interval)",
176-
color="C1", fontsize=main_tick_size, ha="center")
174+
plt.text(2, upper_y + 200, "untreated cochleae\n(95% confidence interval)",
175+
color="C1", fontsize=14, ha="center")
177176
plt.fill_between([xmin, xmax], lower_y, upper_y, color="C1", alpha=0.05, interpolate=True)
178177

179178
plt.tight_layout()
180179

180+
prism_cleanup_axes(ax)
181+
181182
if ".png" in save_path:
182183
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
183184
else:
@@ -192,6 +193,7 @@ def fig_04c(chreef_data, save_path, plot=False, plot_by_side=False, use_alias=Tr
192193
def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=False, gerbil=False, use_alias=True):
193194
"""Transduction efficiency per cochlea.
194195
"""
196+
prism_style()
195197
if use_alias:
196198
alias = [COCHLEAE_DICT[k]["alias"] for k in chreef_data.keys()]
197199
else:
@@ -219,35 +221,34 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa
219221
x = np.arange(len(alias))
220222

221223
# Plot
222-
plt.figure(figsize=(8, 5))
224+
fig, ax = plt.subplots(figsize=(5, 5))
223225

224226
main_label_size = 20
225227
sub_label_size = 16
226228
main_tick_size = 16
227-
legendsize = 16
229+
legendsize = 12
228230

229231
label = "Intensity" if intensity else "Transduction efficiency"
230232

231233
if plot_by_side:
232-
plt.scatter(x, values_left, label="Left", marker="o", s=80)
233-
plt.scatter(x, values_right, label="Right", marker="x", s=80)
234+
plt.scatter(x, values_left, label="Injected", marker="o", s=80)
235+
plt.scatter(x, values_right, label="Non-Injected", marker="x", s=80)
234236
else:
235237
plt.scatter(x, values, label=label, marker="o", s=80)
236238

237239
# Labels and formatting
238240
plt.xticks(x, alias, fontsize=sub_label_size)
239241
plt.yticks(fontsize=main_tick_size)
240242
plt.ylabel(label, fontsize=main_label_size)
241-
plt.legend(loc="best", fontsize=sub_label_size)
242-
plt.legend(loc="upper center", bbox_to_anchor=(0.5, 1.11),
243-
ncol=3, fancybox=True, shadow=False, framealpha=0.8, fontsize=legendsize)
243+
plt.legend(loc="upper right", fontsize=legendsize)
244244
if not intensity:
245245
if gerbil:
246246
plt.ylim(0.3, 1.05)
247247
else:
248248
plt.ylim(0.5, 1.05)
249249

250250
plt.tight_layout()
251+
prism_cleanup_axes(ax)
251252

252253
if ".png" in save_path:
253254
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
@@ -261,6 +262,7 @@ def fig_04d(chreef_data, save_path, plot=False, plot_by_side=False, intensity=Fa
261262

262263

263264
def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_alias=True, trendlines=False):
265+
prism_style()
264266

265267
result = {"cochlea": [], "octave_band": [], "value": []}
266268
for name, values in chreef_data.items():
@@ -295,15 +297,14 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali
295297
if intensity:
296298
band_label_offset_y = 0.09
297299
else:
298-
band_label_offset_y = 0.07
300+
band_label_offset_y = 0.09
299301
if gerbil:
300302
ax.set_ylim(0.05, 1.05)
301303
else:
302304
ax.set_ylim(0.45, 1.05)
303305

304306
# Offsets within each octave band
305307
offset_map = {"L": -0.15, "R": 0.15}
306-
sublabels = {"L": "L", "R": "R"}
307308

308309
# Assign a color to each cochlea (ignoring side)
309310
cochleas = sorted({name_lr[:-1] for name_lr in result["cochlea"].unique()})
@@ -315,9 +316,6 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali
315316
# Track which cochlea names we have already added to the legend
316317
legend_added = set()
317318

318-
all_x_positions = []
319-
all_x_labels = []
320-
321319
trend_dict = {}
322320

323321
for name_lr, grp in result.groupby("cochlea"):
@@ -352,10 +350,6 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False, use_ali
352350
"side": side,
353351
}
354352

355-
# Store for sublabel ticks
356-
all_x_positions.extend(x_positions)
357-
all_x_labels.extend([sublabels[side]] * len(x_positions))
358-
359353
if trendlines:
360354
def get_trendline_values(trend_dict, side):
361355
x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side][0]
@@ -387,7 +381,7 @@ def get_trendline_values(trend_dict, side):
387381
trendline_legend = ax.legend(handles=[trend_l, trend_r], loc='lower center')
388382
trendline_legend = ax.legend(
389383
handles=[trend_l, trend_r],
390-
labels=["Left", "Right"],
384+
labels=["Injected", "Non-Injected"],
391385
loc="lower center",
392386
fontsize=legend_size,
393387
title="Trendlines"
@@ -398,10 +392,9 @@ def get_trendline_values(trend_dict, side):
398392
# Create combined tick positions & labels
399393
main_ticks = range(len(bin_labels))
400394
# add a final tick for label '>64k'
401-
ax.set_xticks([pos + offset_map["L"] for pos in main_ticks[:-1]] +
402-
[pos + offset_map["R"] for pos in main_ticks[:-1]] +
403-
[pos for pos in main_ticks[-1:]])
404-
ax.set_xticklabels(["L"] * len(main_ticks[:-1]) + ["R"] * len(main_ticks[:-1]) + [""], fontsize=sub_tick_label_size)
395+
ax.set_xticks([pos + offset_map["L"] for pos in main_ticks] +
396+
[pos + offset_map["R"] for pos in main_ticks])
397+
ax.set_xticklabels(["I"] * len(main_ticks) + ["N"] * len(main_ticks), fontsize=sub_tick_label_size)
405398

406399
# Add main octave band labels above sublabels
407400
for i, label in enumerate(bin_labels):
@@ -416,11 +409,11 @@ def get_trendline_values(trend_dict, side):
416409
ax.set_title("Intensity per octave band (Left/Right)")
417410
else:
418411
ax.set_ylabel("Transduction Efficiency", fontsize=label_size)
419-
ax.set_title("Transduction efficiency per octave band (Left/Right)")
420412

421413
ax.legend(title="Cochlea", fontsize=legend_size)
422414

423415
plt.tight_layout()
416+
prism_cleanup_axes(ax)
424417

425418
if ".png" in save_path:
426419
plt.savefig(save_path, bbox_inches="tight", pad_inches=0.1, dpi=png_dpi)
@@ -462,33 +455,33 @@ def main():
462455
fig_04d(chreef_data,
463456
save_path=os.path.join(args.figure_dir, f"fig_04d_transduction.{FILE_EXTENSION}"),
464457
plot=args.plot, plot_by_side=True, use_alias=use_alias)
465-
fig_04d(chreef_data,
466-
save_path=os.path.join(args.figure_dir, f"fig_04d_intensity.{FILE_EXTENSION}"),
467-
plot=args.plot, plot_by_side=True, intensity=True, use_alias=use_alias)
458+
# fig_04d(chreef_data,
459+
# save_path=os.path.join(args.figure_dir, f"fig_04d_intensity.{FILE_EXTENSION}"),
460+
# plot=args.plot, plot_by_side=True, intensity=True, use_alias=use_alias)
468461

469462
fig_04e(chreef_data,
470463
save_path=os.path.join(args.figure_dir, f"fig_04e_transduction.{FILE_EXTENSION}"),
471464
plot=args.plot, use_alias=use_alias, trendlines=True)
472-
fig_04e(chreef_data,
473-
save_path=os.path.join(args.figure_dir, f"fig_04e_intensity.{FILE_EXTENSION}"),
474-
plot=args.plot, intensity=True, use_alias=use_alias)
465+
# fig_04e(chreef_data,
466+
# save_path=os.path.join(args.figure_dir, f"fig_04e_intensity.{FILE_EXTENSION}"),
467+
# plot=args.plot, intensity=True, use_alias=use_alias)
475468

476469
chreef_data_gerbil = get_chreef_data(animal="gerbil")
477470
fig_04d(chreef_data_gerbil,
478471
save_path=os.path.join(args.figure_dir, f"fig_04d_gerbil_transduction.{FILE_EXTENSION}"),
479472
plot=args.plot, plot_by_side=True, gerbil=True, use_alias=use_alias)
480473

481-
fig_04d(chreef_data_gerbil,
482-
save_path=os.path.join(args.figure_dir, f"fig_04d_gerbil_intensity.{FILE_EXTENSION}"),
483-
plot=args.plot, plot_by_side=True, intensity=True, use_alias=use_alias)
474+
# fig_04d(chreef_data_gerbil,
475+
# save_path=os.path.join(args.figure_dir, f"fig_04d_gerbil_intensity.{FILE_EXTENSION}"),
476+
# plot=args.plot, plot_by_side=True, intensity=True, use_alias=use_alias)
484477

485478
fig_04e(chreef_data_gerbil,
486479
save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_transduction.{FILE_EXTENSION}"),
487480
plot=args.plot, gerbil=True, use_alias=use_alias)
488481

489-
fig_04e(chreef_data_gerbil,
490-
save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_intensity.{FILE_EXTENSION}"),
491-
plot=args.plot, intensity=True, use_alias=use_alias)
482+
# fig_04e(chreef_data_gerbil,
483+
# save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_intensity.{FILE_EXTENSION}"),
484+
# plot=args.plot, intensity=True, use_alias=use_alias)
492485

493486

494487
if __name__ == "__main__":

0 commit comments

Comments
 (0)