Skip to content

Commit 06951f8

Browse files
committed
Added standard deviation to trendlines
1 parent 703f2e2 commit 06951f8

File tree

1 file changed

+155
-89
lines changed

1 file changed

+155
-89
lines changed

scripts/figures/plot_fig4.py

Lines changed: 155 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,22 @@
3333
"G_EK_000049_R": {"alias": "G1R", "component": [1, 2]},
3434
}
3535

36-
COLORS_ANIMAL = {
36+
COLORS_ANIMAL_01 = {
3737
"M05" : "#DB3000",
3838
"M06" : "#DB0063",
3939
"M07" : "#8F00DB",
4040
"M08" : "#0004DB",
4141
"M09" : "#0093DB"
4242
}
4343

44+
COLORS_ANIMAL = {
45+
"M05" : "#9C5027",
46+
"M06" : "#279C52",
47+
"M07" : "#67279C",
48+
"M08" : "#27339C",
49+
"M09" : "#9C276F"
50+
}
51+
4452
COLORS_LEFT = {
4553
"M05R" : "#A600FF",
4654
"M06R" : "#8F00DB",
@@ -294,17 +302,17 @@ def fig_04c(chreef_data, save_path, plot=False, grouping="side_mono", use_alias=
294302

295303
elif grouping == "side_multi":
296304
for num, key in enumerate(COLORS_LEFT.keys()):
297-
plt.scatter(x_pos_inj[num], values_left[num], label="Injected",
305+
plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}L",
298306
color=COLORS_LEFT[key], marker=MARKER_LEFT, s=80, zorder=1)
299307
for num, key in enumerate(COLORS_RIGHT.keys()):
300-
plt.scatter(x_pos_non[num], values_right[num], label="Non-Injected",
308+
plt.scatter(x_pos_non[num], values_right[num], label=f"{alias[num]}R",
301309
color=COLORS_RIGHT[key], marker=MARKER_RIGHT, s=80, zorder=1)
302310

303311
elif grouping == "animal":
304312
for num, key in enumerate(COLORS_ANIMAL.keys()):
305-
plt.scatter(x_pos_inj[num], values_left[num], label="Injected",
313+
plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}",
306314
color=COLORS_ANIMAL[key], marker=MARKER_LEFT, s=80, zorder=1)
307-
plt.scatter(x_pos_non[num], values_right[num], label="Non-Injected",
315+
plt.scatter(x_pos_non[num], values_right[num],
308316
color=COLORS_ANIMAL[key], marker=MARKER_RIGHT, s=80, zorder=1)
309317

310318
else:
@@ -315,7 +323,6 @@ def fig_04c(chreef_data, save_path, plot=False, grouping="side_mono", use_alias=
315323
plt.yticks(y_ticks, fontsize=main_tick_size)
316324
plt.ylabel("SGN count per cochlea", fontsize=main_label_size)
317325
plt.ylim(5000, 14000)
318-
# plt.legend(loc="upper right", fontsize=legendsize)
319326

320327
xmin = 0.5
321328
xmax = 2.5
@@ -401,17 +408,17 @@ def fig_04d(chreef_data, save_path, plot=False, grouping="animal", intensity=Fal
401408

402409
elif grouping == "side_multi":
403410
for num, key in enumerate(COLORS_LEFT.keys()):
404-
plt.scatter(x_pos_inj[num], values_left[num], label="Injected",
411+
plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}L",
405412
color=COLORS_LEFT[key], marker=MARKER_LEFT, s=80, zorder=1)
406413
for num, key in enumerate(COLORS_RIGHT.keys()):
407-
plt.scatter(x_pos_non[num], values_right[num], label="Non-Injected",
414+
plt.scatter(x_pos_non[num], values_right[num], label=f"{alias[num]}R",
408415
color=COLORS_RIGHT[key], marker=MARKER_RIGHT, s=80, zorder=1)
409416

410417
elif grouping == "animal":
411418
for num, key in enumerate(COLORS_ANIMAL.keys()):
412-
plt.scatter(x_pos_inj[num], values_left[num], label="Injected",
419+
plt.scatter(x_pos_inj[num], values_left[num], label=f"{alias[num]}",
413420
color=COLORS_ANIMAL[key], marker=MARKER_LEFT, s=80, zorder=1)
414-
plt.scatter(x_pos_non[num], values_right[num], label="Non-Injected",
421+
plt.scatter(x_pos_non[num], values_right[num],
415422
color=COLORS_ANIMAL[key], marker=MARKER_RIGHT, s=80, zorder=1)
416423

417424
else:
@@ -509,6 +516,7 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False,
509516
prism_style()
510517

511518
result = {"cochlea": [], "octave_band": [], "value": []}
519+
aliases = []
512520
for name, values in chreef_data.items():
513521
if use_alias:
514522
alias = COCHLEAE_DICT[name]["alias"]
@@ -526,13 +534,32 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False,
526534
result["cochlea"].extend([alias] * len(octave_binned))
527535
result["octave_band"].extend(octave_binned.axes[0].values.tolist())
528536
result["value"].extend(octave_binned.values.tolist())
537+
aliases.append(alias)
538+
539+
if gerbil:
540+
values = []
541+
for vals in chreef_data.values():
542+
if intensity:
543+
intensities = vals["median"].values
544+
values.append(intensities.mean())
545+
else:
546+
# marker labels
547+
# 0: unlabeled - no median intensity in object_measures table
548+
# 1: positive
549+
# 2: negative
550+
marker_labels = vals["marker_labels"].values
551+
n_pos = (marker_labels == 1).sum()
552+
n_neg = (marker_labels == 2).sum()
553+
eff = float(n_pos) / (n_pos + n_neg)
554+
values.append(eff)
555+
alias, values_left, values_right = group_lr(aliases, values)
529556

530557
result = pd.DataFrame(result)
531558
bin_labels = pd.unique(result["octave_band"])
532559
band_to_x = {band: i for i, band in enumerate(bin_labels)}
533560
result["x_pos"] = result["octave_band"].map(band_to_x)
534561

535-
fig, ax = plt.subplots(figsize=(8, 4))
562+
fig, ax = plt.subplots(figsize=(8, 5))
536563

537564
sub_tick_label_size = 8
538565
tick_label_size = 12
@@ -623,92 +650,127 @@ def fig_04e(chreef_data, save_path, plot, intensity=False, gerbil=False,
623650
}
624651

625652
if trendlines:
626-
def get_trendline_values(trend_dict, side):
627-
x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side][0]
628-
y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side]
629-
y_sorted = []
630-
for num in range(len(x_sorted)):
631-
y_sorted.append(np.mean([y[num] for y in y_sorted_all]))
632-
return x_sorted, y_sorted
633-
634-
# Trendline Injected (Left)
635-
x_sorted, y_sorted = get_trendline_values(trend_dict, "L")
636-
x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict, "L")
637-
638-
# central line
639-
trend_l, = ax.plot(
640-
x_sorted,
641-
y_sorted,
642-
linestyle="dotted",
643-
color=COLOR_LEFT,
644-
alpha=0.7,
645-
zorder=0
646-
)
647-
648-
if trendline_std:
649-
# upper and lower standard deviation
650-
trend_l_upper, = ax.plot(
653+
if not gerbil:
654+
def get_trendline_values(trend_dict, side):
655+
x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side][0]
656+
y_sorted_all = [trend_dict[k]["y_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == side]
657+
y_sorted = []
658+
for num in range(len(x_sorted)):
659+
y_sorted.append(np.mean([y[num] for y in y_sorted_all]))
660+
return x_sorted, y_sorted
661+
662+
# Trendline Injected (Left)
663+
x_sorted, y_sorted = get_trendline_values(trend_dict, "L")
664+
x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict, "L")
665+
666+
if grouping == "animal":
667+
color_trend_l = "gray"
668+
color_trend_r = "gray"
669+
else:
670+
color_trend_l = COLOR_LEFT
671+
color_trend_r = COLOR_RIGHT
672+
673+
# central line
674+
trend_l, = ax.plot(
651675
x_sorted,
652-
y_sorted_upper,
653-
linestyle="solid",
654-
color=COLOR_LEFT,
655-
alpha=0.08,
676+
y_sorted,
677+
linestyle="dotted",
678+
color=color_trend_l,
679+
alpha=0.7,
656680
zorder=0
657681
)
658-
trend_l_lower, = ax.plot(
682+
683+
if trendline_std:
684+
# upper and lower standard deviation
685+
trend_l_upper, = ax.plot(
686+
x_sorted,
687+
y_sorted_upper,
688+
linestyle="solid",
689+
color=color_trend_l,
690+
alpha=0.08,
691+
zorder=0
692+
)
693+
trend_l_lower, = ax.plot(
694+
x_sorted,
695+
y_sorted_lower,
696+
linestyle="solid",
697+
color=color_trend_l,
698+
alpha=0.08,
699+
zorder=0
700+
)
701+
plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, color=COLOR_LEFT, alpha=0.05, interpolate=True)
702+
703+
# Trendline Non-Injected (Right)
704+
x_sorted, y_sorted = get_trendline_values(trend_dict, "R")
705+
x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict, "R")
706+
# central line
707+
trend_r, = ax.plot(
659708
x_sorted,
660-
y_sorted_lower,
661-
linestyle="solid",
662-
color=COLOR_LEFT,
663-
alpha=0.08,
709+
y_sorted,
710+
linestyle="dashed",
711+
color=color_trend_r,
712+
alpha=0.7,
664713
zorder=0
665714
)
666-
plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, color=COLOR_LEFT, alpha=0.05, interpolate=True)
667-
668-
# Trendline Non-Injected (Right)
669-
x_sorted, y_sorted = get_trendline_values(trend_dict, "R")
670-
x_sorted, y_sorted, y_sorted_upper, y_sorted_lower = _get_trendline_params(trend_dict, "R")
671-
# central line
672-
trend_r, = ax.plot(
673-
x_sorted,
674-
y_sorted,
675-
linestyle="dashed",
676-
color=COLOR_RIGHT,
677-
alpha=0.7,
678-
zorder=0
679-
)
680715

681-
if trendline_std:
682-
# upper and lower standard deviation
683-
trend_r_upper, = ax.plot(
716+
if trendline_std:
717+
# upper and lower standard deviation
718+
trend_r_upper, = ax.plot(
719+
x_sorted,
720+
y_sorted_upper,
721+
linestyle="solid",
722+
color=color_trend_r,
723+
alpha=0.08,
724+
zorder=0
725+
)
726+
trend_r_lower, = ax.plot(
727+
x_sorted,
728+
y_sorted_lower,
729+
linestyle="solid",
730+
color=color_trend_r,
731+
alpha=0.08,
732+
zorder=0
733+
)
734+
plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, color=COLOR_RIGHT, alpha=0.05, interpolate=True)
735+
736+
# Trendline legend
737+
trendline_legend = ax.legend(handles=[trend_l, trend_r], loc='lower center')
738+
trendline_legend = ax.legend(
739+
handles=[trend_l, trend_r],
740+
labels=["Injected", "Non-Injected"],
741+
loc="lower left",
742+
fontsize=legend_size,
743+
title="Trendlines"
744+
)
745+
# Add the legend manually to the Axes.
746+
ax.add_artist(trendline_legend)
747+
else:
748+
x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == "L"][0]
749+
y_left = [values_left[0] for _ in x_sorted]
750+
y_right = [values_right[0] for _ in x_sorted]
751+
if grouping == "animal":
752+
color_trend_l = "gray"
753+
color_trend_r = "gray"
754+
else:
755+
color_trend_l = COLOR_LEFT
756+
color_trend_r = COLOR_RIGHT
757+
trend_l, = ax.plot(
684758
x_sorted,
685-
y_sorted_upper,
686-
linestyle="solid",
687-
color=COLOR_RIGHT,
688-
alpha=0.08,
759+
y_left,
760+
linestyle="dotted",
761+
color=color_trend_l,
762+
alpha=0.7,
689763
zorder=0
690764
)
691-
trend_r_lower, = ax.plot(
765+
x_sorted = [trend_dict[k]["x_sorted"] for k in trend_dict.keys() if trend_dict[k]["side"] == "R"][0]
766+
trend_r, = ax.plot(
692767
x_sorted,
693-
y_sorted_lower,
694-
linestyle="solid",
695-
color=COLOR_RIGHT,
696-
alpha=0.08,
768+
y_right,
769+
linestyle="dashed",
770+
color=color_trend_r,
771+
alpha=0.7,
697772
zorder=0
698773
)
699-
plt.fill_between(x_sorted, y_sorted_lower, y_sorted_upper, color=COLOR_RIGHT, alpha=0.05, interpolate=True)
700-
701-
# Trendline legend
702-
trendline_legend = ax.legend(handles=[trend_l, trend_r], loc='lower center')
703-
trendline_legend = ax.legend(
704-
handles=[trend_l, trend_r],
705-
labels=["Injected", "Non-Injected"],
706-
loc="lower center",
707-
fontsize=legend_size,
708-
title="Trendlines"
709-
)
710-
# Add the legend manually to the Axes.
711-
ax.add_artist(trendline_legend)
712774

713775
# Create combined tick positions & labels
714776
main_ticks = range(len(bin_labels))
@@ -778,20 +840,24 @@ def main():
778840
# C: The SGN count compared to reference values from literature and healthy
779841
# Maybe remove literature reference from plot?
780842
fig_04c(chreef_data,
781-
save_path=os.path.join(args.figure_dir, f"fig_04c.{FILE_EXTENSION}"),
843+
save_path=os.path.join(args.figure_dir, f"fig_04c_{grouping}.{FILE_EXTENSION}"),
782844
plot=args.plot, grouping=grouping, use_alias=use_alias)
783845

784846
# D: The transduction efficiency. We also plot GFP intensities.
785847
fig_04d(chreef_data,
786-
save_path=os.path.join(args.figure_dir, f"fig_04d_transduction.{FILE_EXTENSION}"),
848+
save_path=os.path.join(args.figure_dir, f"fig_04d_transduction_{grouping}.{FILE_EXTENSION}"),
787849
plot=args.plot, grouping=grouping, use_alias=use_alias)
788850
# fig_04d(chreef_data,
789851
# save_path=os.path.join(args.figure_dir, f"fig_04d_intensity.{FILE_EXTENSION}"),
790852
# plot=args.plot, plot_by_side=True, intensity=True, use_alias=use_alias)
791853

792854
fig_04e(chreef_data,
793-
save_path=os.path.join(args.figure_dir, f"fig_04e_transduction.{FILE_EXTENSION}"),
855+
save_path=os.path.join(args.figure_dir, f"fig_04e_transduction_{grouping}.{FILE_EXTENSION}"),
794856
plot=args.plot, grouping=grouping, use_alias=use_alias, trendlines=True)
857+
858+
fig_04e(chreef_data,
859+
save_path=os.path.join(args.figure_dir, f"fig_04e_transduction_std_{grouping}.{FILE_EXTENSION}"),
860+
plot=args.plot, grouping=grouping, use_alias=use_alias, trendlines=True, trendline_std=True)
795861
# fig_04e(chreef_data,
796862
# save_path=os.path.join(args.figure_dir, f"fig_04e_intensity.{FILE_EXTENSION}"),
797863
# plot=args.plot, intensity=True, use_alias=use_alias)
@@ -802,8 +868,8 @@ def main():
802868
plot=args.plot, grouping="side_mono", gerbil=True, use_alias=use_alias)
803869

804870
fig_04e(chreef_data_gerbil,
805-
save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_transduction.{FILE_EXTENSION}"),
806-
plot=args.plot, gerbil=True, use_alias=use_alias)
871+
save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_transduction_{grouping}.{FILE_EXTENSION}"),
872+
plot=args.plot, gerbil=True, use_alias=use_alias, trendlines=True)
807873

808874
# fig_04e(chreef_data_gerbil,
809875
# save_path=os.path.join(args.figure_dir, f"fig_04e_gerbil_intensity.{FILE_EXTENSION}"),

0 commit comments

Comments
 (0)