Skip to content

Commit 28b23e5

Browse files
committed
Fixed legend location for horizontal plots and added a bbox kwarg for user input
1 parent b1c78bc commit 28b23e5

11 files changed

+108
-100
lines changed

dabest/misc_tools.py

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def get_kwargs(plot_kwargs, ytick_color):
272272
# Legend kwargs.
273273
default_legend_kwargs = {
274274
"loc": "upper left",
275-
"frameon": False
275+
"frameon": False,
276276
}
277277
if plot_kwargs["legend_kwargs"] is None:
278278
legend_kwargs = default_legend_kwargs
@@ -1110,7 +1110,7 @@ def set_xaxis_ticks_and_lims(show_delta2, show_mini_meta, rawdata_axes, contrast
11101110
)
11111111

11121112

1113-
def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, float_contrast, show_pairs, horizontal, legend_kwargs):
1113+
def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, table_axes, float_contrast, show_pairs, horizontal, legend_kwargs, table_kwargs):
11141114
"""
11151115
Show the legend for the plotter function.
11161116
@@ -1124,6 +1124,8 @@ def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, floa
11241124
The raw data axes.
11251125
contrast_axes : object (Axes)
11261126
The contrast axes.
1127+
table_axes : object (Axes)
1128+
The table axes.
11271129
float_contrast : bool
11281130
A boolean flag to determine if the plot is GA or Cumming format.
11291131
show_pairs : bool
@@ -1140,41 +1142,39 @@ def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, floa
11401142
pd.Series(legend_handles, dtype="object").loc[unique_idx]
11411143
).tolist()
11421144

1143-
if len(legend_handles_unique) > 0:
1145+
# Location of the legend
1146+
if "bbox_to_anchor" not in legend_kwargs.keys():
11441147
if horizontal:
1145-
axes_with_legend = rawdata_axes
1146-
# bta = (0.8, 0.8)
1147-
leg = axes_with_legend.legend(
1148-
legend_handles_unique,
1149-
legend_labels_unique,
1150-
handlelength=0.5,
1151-
**legend_kwargs
1152-
)
1153-
if show_pairs:
1154-
for line in leg.get_lines():
1155-
line.set_linewidth(3)
1148+
bta = (1,1)
11561149
else:
11571150
if float_contrast:
1158-
axes_with_legend = contrast_axes
1159-
if show_pairs:
1160-
bta = (2.00, 1.02)
1161-
else:
1162-
bta = (1.5, 1.02)
1151+
bta = (2.00, 1.02) if show_pairs else (1.5, 1.02)
11631152
else:
1164-
axes_with_legend = rawdata_axes
1165-
if show_pairs:
1166-
bta = (1.02, 1.0)
1167-
else:
1168-
bta = (1.0, 1.0)
1169-
leg = axes_with_legend.legend(
1170-
legend_handles_unique,
1171-
legend_labels_unique,
1172-
bbox_to_anchor=bta,
1173-
**legend_kwargs
1174-
)
1175-
if show_pairs:
1176-
for line in leg.get_lines():
1177-
line.set_linewidth(3.0)
1153+
bta = (1.02, 1.0) if show_pairs else (1.0, 1.0)
1154+
legend_kwargs.update({'bbox_to_anchor': bta})
1155+
1156+
# Pick the ax to plot
1157+
if horizontal:
1158+
if table_kwargs['show']:
1159+
axes_with_legend = table_axes
1160+
else:
1161+
axes_with_legend = contrast_axes
1162+
elif float_contrast:
1163+
axes_with_legend = contrast_axes
1164+
else:
1165+
axes_with_legend = rawdata_axes
1166+
1167+
# Plot the legend
1168+
if len(legend_handles_unique) > 0:
1169+
leg = axes_with_legend.legend(
1170+
legend_handles_unique,
1171+
legend_labels_unique,
1172+
handlelength=0.5,
1173+
**legend_kwargs
1174+
)
1175+
if show_pairs:
1176+
for line in leg.get_lines():
1177+
line.set_linewidth(3.0)
11781178

11791179
def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar, yvar, current_control, current_group,
11801180
rawdata_axes, contrast_axes, results, current_effsize, is_paired, one_sankey,

dabest/plotter.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -408,23 +408,6 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
408408
proportional=proportional,
409409
horizontal=horizontal,
410410
)
411-
# Legend
412-
handles, labels = rawdata_axes.get_legend_handles_labels()
413-
legend_labels = [l for l in labels]
414-
legend_handles = [h for h in handles]
415-
416-
if bootstraps_color_by_group is False and not effectsize_df.delta2:
417-
rawdata_axes.legend().set_visible(False)
418-
show_legend(
419-
legend_labels=legend_labels,
420-
legend_handles=legend_handles,
421-
rawdata_axes=rawdata_axes,
422-
contrast_axes=contrast_axes,
423-
float_contrast=float_contrast,
424-
show_pairs=show_pairs,
425-
horizontal=horizontal,
426-
legend_kwargs=legend_kwargs
427-
)
428411

429412
# Plot aesthetic adjustments.
430413
if float_contrast and not horizontal:
@@ -589,6 +572,27 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
589572
effect_size = effect_size,
590573
gridkey_kwargs = gridkey_kwargs,
591574
)
575+
576+
577+
# Legend
578+
handles, labels = rawdata_axes.get_legend_handles_labels()
579+
legend_labels = [l for l in labels]
580+
legend_handles = [h for h in handles]
581+
582+
if bootstraps_color_by_group is False and not effectsize_df.delta2:
583+
rawdata_axes.legend().set_visible(False)
584+
show_legend(
585+
legend_labels=legend_labels,
586+
legend_handles=legend_handles,
587+
rawdata_axes=rawdata_axes,
588+
contrast_axes=contrast_axes,
589+
table_axes=table_axes,
590+
float_contrast=float_contrast,
591+
show_pairs=show_pairs,
592+
horizontal=horizontal,
593+
legend_kwargs=legend_kwargs,
594+
table_kwargs=table_kwargs
595+
)
592596

593597
# Make sure no stray ticks appear!
594598
rawdata_axes.xaxis.set_ticks_position("bottom")

nbs/API/misc_tools.ipynb

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@
325325
" # Legend kwargs.\n",
326326
" default_legend_kwargs = {\n",
327327
" \"loc\": \"upper left\", \n",
328-
" \"frameon\": False\n",
328+
" \"frameon\": False,\n",
329329
" }\n",
330330
" if plot_kwargs[\"legend_kwargs\"] is None:\n",
331331
" legend_kwargs = default_legend_kwargs\n",
@@ -1163,7 +1163,7 @@
11631163
" )\n",
11641164
"\n",
11651165
"\n",
1166-
"def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, float_contrast, show_pairs, horizontal, legend_kwargs):\n",
1166+
"def show_legend(legend_labels, legend_handles, rawdata_axes, contrast_axes, table_axes, float_contrast, show_pairs, horizontal, legend_kwargs, table_kwargs):\n",
11671167
" \"\"\"\n",
11681168
" Show the legend for the plotter function.\n",
11691169
"\n",
@@ -1177,6 +1177,8 @@
11771177
" The raw data axes.\n",
11781178
" contrast_axes : object (Axes)\n",
11791179
" The contrast axes.\n",
1180+
" table_axes : object (Axes)\n",
1181+
" The table axes.\n",
11801182
" float_contrast : bool\n",
11811183
" A boolean flag to determine if the plot is GA or Cumming format.\n",
11821184
" show_pairs : bool\n",
@@ -1193,41 +1195,39 @@
11931195
" pd.Series(legend_handles, dtype=\"object\").loc[unique_idx]\n",
11941196
" ).tolist()\n",
11951197
"\n",
1196-
" if len(legend_handles_unique) > 0:\n",
1198+
" # Location of the legend\n",
1199+
" if \"bbox_to_anchor\" not in legend_kwargs.keys():\n",
11971200
" if horizontal:\n",
1198-
" axes_with_legend = rawdata_axes\n",
1199-
" # bta = (0.8, 0.8)\n",
1200-
" leg = axes_with_legend.legend(\n",
1201-
" legend_handles_unique,\n",
1202-
" legend_labels_unique,\n",
1203-
" handlelength=0.5,\n",
1204-
" **legend_kwargs\n",
1205-
" )\n",
1206-
" if show_pairs:\n",
1207-
" for line in leg.get_lines():\n",
1208-
" line.set_linewidth(3)\n",
1201+
" bta = (1,1)\n",
12091202
" else:\n",
12101203
" if float_contrast:\n",
1211-
" axes_with_legend = contrast_axes\n",
1212-
" if show_pairs:\n",
1213-
" bta = (2.00, 1.02)\n",
1214-
" else:\n",
1215-
" bta = (1.5, 1.02)\n",
1204+
" bta = (2.00, 1.02) if show_pairs else (1.5, 1.02)\n",
12161205
" else:\n",
1217-
" axes_with_legend = rawdata_axes\n",
1218-
" if show_pairs:\n",
1219-
" bta = (1.02, 1.0)\n",
1220-
" else:\n",
1221-
" bta = (1.0, 1.0)\n",
1222-
" leg = axes_with_legend.legend(\n",
1223-
" legend_handles_unique,\n",
1224-
" legend_labels_unique,\n",
1225-
" bbox_to_anchor=bta,\n",
1226-
" **legend_kwargs\n",
1227-
" )\n",
1228-
" if show_pairs:\n",
1229-
" for line in leg.get_lines():\n",
1230-
" line.set_linewidth(3.0)\n",
1206+
" bta = (1.02, 1.0) if show_pairs else (1.0, 1.0)\n",
1207+
" legend_kwargs.update({'bbox_to_anchor': bta})\n",
1208+
"\n",
1209+
" # Pick the ax to plot\n",
1210+
" if horizontal:\n",
1211+
" if table_kwargs['show']:\n",
1212+
" axes_with_legend = table_axes\n",
1213+
" else:\n",
1214+
" axes_with_legend = contrast_axes\n",
1215+
" elif float_contrast:\n",
1216+
" axes_with_legend = contrast_axes\n",
1217+
" else:\n",
1218+
" axes_with_legend = rawdata_axes\n",
1219+
"\n",
1220+
" # Plot the legend\n",
1221+
" if len(legend_handles_unique) > 0:\n",
1222+
" leg = axes_with_legend.legend(\n",
1223+
" legend_handles_unique,\n",
1224+
" legend_labels_unique,\n",
1225+
" handlelength=0.5,\n",
1226+
" **legend_kwargs\n",
1227+
" )\n",
1228+
" if show_pairs:\n",
1229+
" for line in leg.get_lines():\n",
1230+
" line.set_linewidth(3.0)\n",
12311231
" \n",
12321232
"def Gardner_Altman_Plot_Aesthetic_Adjustments(effect_size_type, plot_data, xvar, yvar, current_control, current_group,\n",
12331233
" rawdata_axes, contrast_axes, results, current_effsize, is_paired, one_sankey,\n",

nbs/API/plotter.ipynb

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -465,23 +465,6 @@
465465
" proportional=proportional,\n",
466466
" horizontal=horizontal,\n",
467467
" )\n",
468-
" # Legend\n",
469-
" handles, labels = rawdata_axes.get_legend_handles_labels()\n",
470-
" legend_labels = [l for l in labels]\n",
471-
" legend_handles = [h for h in handles]\n",
472-
"\n",
473-
" if bootstraps_color_by_group is False and not effectsize_df.delta2:\n",
474-
" rawdata_axes.legend().set_visible(False)\n",
475-
" show_legend(\n",
476-
" legend_labels=legend_labels, \n",
477-
" legend_handles=legend_handles, \n",
478-
" rawdata_axes=rawdata_axes, \n",
479-
" contrast_axes=contrast_axes, \n",
480-
" float_contrast=float_contrast, \n",
481-
" show_pairs=show_pairs, \n",
482-
" horizontal=horizontal,\n",
483-
" legend_kwargs=legend_kwargs\n",
484-
" )\n",
485468
"\n",
486469
" # Plot aesthetic adjustments.\n",
487470
" if float_contrast and not horizontal:\n",
@@ -646,6 +629,27 @@
646629
" effect_size = effect_size,\n",
647630
" gridkey_kwargs = gridkey_kwargs,\n",
648631
" )\n",
632+
" \n",
633+
"\n",
634+
" # Legend\n",
635+
" handles, labels = rawdata_axes.get_legend_handles_labels()\n",
636+
" legend_labels = [l for l in labels]\n",
637+
" legend_handles = [h for h in handles]\n",
638+
"\n",
639+
" if bootstraps_color_by_group is False and not effectsize_df.delta2:\n",
640+
" rawdata_axes.legend().set_visible(False)\n",
641+
" show_legend(\n",
642+
" legend_labels=legend_labels, \n",
643+
" legend_handles=legend_handles, \n",
644+
" rawdata_axes=rawdata_axes, \n",
645+
" contrast_axes=contrast_axes, \n",
646+
" table_axes=table_axes,\n",
647+
" float_contrast=float_contrast, \n",
648+
" show_pairs=show_pairs, \n",
649+
" horizontal=horizontal,\n",
650+
" legend_kwargs=legend_kwargs,\n",
651+
" table_kwargs=table_kwargs\n",
652+
" )\n",
649653
"\n",
650654
" # Make sure no stray ticks appear!\n",
651655
" rawdata_axes.xaxis.set_ticks_position(\"bottom\")\n",
-81 Bytes
Loading
795 Bytes
Loading
409 Bytes
Loading
-901 Bytes
Loading
32 Bytes
Loading
-966 Bytes
Loading

0 commit comments

Comments
 (0)