Skip to content

Commit b1c78bc

Browse files
committed
Add a sample-counts feature to proportion plots, along with tests and tutorial details for it
1 parent 36ac2b5 commit b1c78bc

19 files changed

+341
-30
lines changed

dabest/_effsize_objects.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,9 @@ def plot(
10371037

10381038
es_marker_kwargs=None,
10391039
es_errorbar_kwargs=None,
1040+
1041+
prop_sample_counts=False,
1042+
prop_sample_counts_kwargs=None
10401043
):
10411044
"""
10421045
Creates an estimation plot for the effect size of interest.
@@ -1126,6 +1129,9 @@ def plot(
11261129
Pass any keyword arguments accepted by the seaborn `swarmplot`
11271130
command here, as a dict. If None, the following keywords are
11281131
passed to sns.swarmplot : {'size':`raw_marker_size`}.
1132+
barplot_kwargs : dict, default None
1133+
By default, the keyword arguments passed are:
1134+
{"estimator": np.mean, "errorbar": plot_kwargs["ci"]}
11291135
violinplot_kwargs : dict, default None
11301136
Pass any keyword arguments accepted by the matplotlib `
11311137
pyplot.violinplot` command here, as a dict. If None, the following
@@ -1241,6 +1247,13 @@ def plot(
12411247
Pass relevant keyword arguments to the effectsize errorbar plotting. If none, the following keywords are passed:
12421248
{'color': 'black', 'lw': 2, 'linestyle': '-', 'alpha': 1,'zorder': 1,}
12431249
1250+
prop_sample_counts: bool, default False
1251+
Show the sample counts for each group in proportional plots
1252+
prop_sample_counts_kwargs: dict, default None
1253+
Pass relevant keyword arguments. If None, the following keywords are passed:
1254+
{'color': 'k', 'zorder': 5, 'ha': 'center', 'va': 'center'}
1255+
1256+
12441257
Returns
12451258
-------
12461259
A :class:`matplotlib.figure.Figure` with 2 Axes, if ``ax = None``.

dabest/_modidx.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@
106106
'dabest/plot_tools.py'),
107107
'dabest.plot_tools.SwarmPlot._swarm': ('API/plot_tools.html#swarmplot._swarm', 'dabest/plot_tools.py'),
108108
'dabest.plot_tools.SwarmPlot.plot': ('API/plot_tools.html#swarmplot.plot', 'dabest/plot_tools.py'),
109+
'dabest.plot_tools.add_counts_to_prop_plots': ( 'API/plot_tools.html#add_counts_to_prop_plots',
110+
'dabest/plot_tools.py'),
109111
'dabest.plot_tools.barplotter': ('API/plot_tools.html#barplotter', 'dabest/plot_tools.py'),
110112
'dabest.plot_tools.check_data_matches_labels': ( 'API/plot_tools.html#check_data_matches_labels',
111113
'dabest/plot_tools.py'),

dabest/misc_tools.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def get_kwargs(plot_kwargs, ytick_color):
200200
# Barplot kwargs
201201
default_barplot_kwargs = {
202202
"estimator": np.mean,
203-
"errorbar": plot_kwargs["ci"]
203+
"errorbar": plot_kwargs["ci"],
204204
}
205205
if plot_kwargs["barplot_kwargs"] is None:
206206
barplot_kwargs = default_barplot_kwargs
@@ -435,11 +435,23 @@ def get_kwargs(plot_kwargs, ytick_color):
435435
else:
436436
es_errorbar_kwargs = merge_two_dicts(default_es_errorbar_kwargs, plot_kwargs['es_errorbar_kwargs'])
437437

438+
# Prop sample counts kwargs
439+
default_prop_sample_counts_kwargs = {
440+
'color': 'k',
441+
'zorder': 5,
442+
'ha': 'center',
443+
'va': 'center'
444+
}
445+
if plot_kwargs['prop_sample_counts_kwargs'] is None:
446+
prop_sample_counts_kwargs = default_prop_sample_counts_kwargs
447+
else:
448+
prop_sample_counts_kwargs = merge_two_dicts(default_prop_sample_counts_kwargs, plot_kwargs['prop_sample_counts_kwargs'])
449+
438450
# Return the kwargs.
439451
return (swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs, slopegraph_kwargs,
440452
reflines_kwargs, legend_kwargs, group_summaries_kwargs, redraw_axes_kwargs, delta_dot_kwargs,
441453
delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs, table_kwargs, gridkey_kwargs,
442-
es_marker_kwargs, es_errorbar_kwargs)
454+
es_marker_kwargs, es_errorbar_kwargs, prop_sample_counts_kwargs)
443455

444456

445457
def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_groups):

dabest/plot_tools.py

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
__all__ = ['halfviolin', 'get_swarm_spans', 'error_bar', 'check_data_matches_labels', 'normalize_dict', 'width_determine',
1010
'single_sankey', 'sankeydiag', 'summary_bars_plotter', 'contrast_bars_plotter', 'swarm_bars_plotter',
1111
'delta_text_plotter', 'DeltaDotsPlotter', 'slopegraph_plotter', 'plot_minimeta_or_deltadelta_violins',
12-
'effect_size_curve_plotter', 'gridkey_plotter', 'barplotter', 'table_for_horizontal_plots', 'swarmplot',
13-
'SwarmPlot']
12+
'effect_size_curve_plotter', 'gridkey_plotter', 'barplotter', 'table_for_horizontal_plots',
13+
'add_counts_to_prop_plots', 'swarmplot', 'SwarmPlot']
1414

1515
# %% ../nbs/API/plot_tools.ipynb 4
1616
import math
@@ -1919,9 +1919,8 @@ def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color,
19191919
horizontal : bool
19201920
If the plot is horizontal.
19211921
"""
1922-
1923-
x_label, y_label = rawdata_axes.get_xlabel(), rawdata_axes.get_ylabel()
19241922

1923+
x_label, y_label = rawdata_axes.get_xlabel(), rawdata_axes.get_ylabel()
19251924
if horizontal:
19261925
x_var, y_var, orient = np.ones(len(all_plot_groups)), all_plot_groups, "h"
19271926
else:
@@ -1969,7 +1968,7 @@ def barplotter(xvar, yvar, all_plot_groups, rawdata_axes, plot_data, bar_color,
19691968
centre = x + width / 2.0
19701969
bar.set_x(centre - bar_width / 2.0)
19711970
bar.set_width(bar_width)
1972-
1971+
19731972
# reset the x and y labels
19741973
rawdata_axes.set_xlabel(x_label)
19751974
rawdata_axes.set_ylabel(y_label)
@@ -2047,6 +2046,61 @@ def table_for_horizontal_plots(effectsize_df, ax, contrast_axes, ticks_to_plot,
20472046
ax.set_xlabel(label, fontsize=fontsize_label) # Set the x-axis label - hardcoded for now
20482047
sns.despine(ax=ax, left=True, bottom=True)
20492048

2049+
2050+
def add_counts_to_prop_plots(plot_data, xvar, yvar, rawdata_axes, horizontal, is_paired, prop_sample_counts_kwargs):
    """
    Add counts to the proportion plots.

    For every group found in ``plot_data[xvar]`` two text labels are drawn on
    ``rawdata_axes``: the number of rows where ``yvar == 1`` (placed at 0.05
    on the value axis) and the number of rows where ``yvar == 0`` (placed at
    0.95). Groups are positioned at 0, 1, 2, ... along the group axis.

    Parameters
    ----------
    plot_data : object (Dataframe)
        Dataframe of the plot data.
    xvar : str
        Column name of the x variable (the grouping column).
    yvar : str
        Column name of the y variable (rows equal to 0 or 1 are counted).
    rawdata_axes : object
        Matplotlib axis object to plot on.
    horizontal : bool
        If the plot is horizontal. When True the group position is used as the
        y coordinate and the 0.05 / 0.95 offsets as the x coordinate.
    is_paired : object
        Truthy when the data is paired (the caller passes values such as
        ``"baseline"``); a truthy value shrinks the default font size by 2.
    prop_sample_counts_kwargs : dict or None
        Keyword arguments forwarded to ``rawdata_axes.text``. The dict is
        copied, never modified. If it has no ``"fontsize"`` entry, a default
        is chosen from ``horizontal`` and ``is_paired``.
    """

    # Group orders
    group_values = pd.unique(plot_data[xvar])
    if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype):
        sample_size_text_order = group_values.categories
    else:
        sample_size_text_order = group_values

    def _count_rows(value):
        # Count, per group, the rows whose yvar equals `value`.
        subset = plot_data[plot_data[yvar] == value]
        counts = subset.groupby(xvar, observed=False)[yvar].count()
        # fill_value=0: a group with no matching rows must be reported as 0.
        # Without it, reindex inserts NaN, the series becomes float, and the
        # labels render as "nan" / "2.0".
        return counts.reindex(index=sample_size_text_order, fill_value=0)

    # Get the sample size values
    sample_size_val1 = _count_rows(1)
    sample_size_val0 = _count_rows(0)

    # Work on a copy so the caller's dict is not mutated.
    text_kwargs = dict(prop_sample_counts_kwargs or {})
    if "fontsize" not in text_kwargs:
        fontsize = 8 if horizontal else 10
        if is_paired:
            fontsize -= 2
        text_kwargs["fontsize"] = fontsize

    for position, (n_zeros, n_ones) in enumerate(zip(sample_size_val0, sample_size_val1)):
        # Count of ones goes at 0.05, count of zeros at 0.95.
        for value_coord, count in ((0.05, n_ones), (0.95, n_zeros)):
            label = str(int(count))
            if horizontal:
                rawdata_axes.text(value_coord, position, label, **text_kwargs)
            else:
                rawdata_axes.text(position, value_coord, label, **text_kwargs)
2103+
20502104
# %% ../nbs/API/plot_tools.ipynb 6
20512105
def swarmplot(
20522106
data: pd.DataFrame,

dabest/plotter.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
6161
delta_text=True, delta_text_kwargs=None,
6262
delta_dot=True, delta_dot_kwargs=None,
6363
horizontal=False, horizontal_table_kwargs=None,
64-
es_marker_kwargs=None, es_errorbar_kwargs=None
64+
es_marker_kwargs=None, es_errorbar_kwargs=None,
65+
prop_sample_counts=False, prop_sample_counts_kwargs=None
6566
"""
6667
from .misc_tools import (
6768
get_params,
@@ -93,6 +94,7 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
9394
gridkey_plotter,
9495
barplotter,
9596
table_for_horizontal_plots,
97+
add_counts_to_prop_plots,
9698
)
9799

98100
warnings.filterwarnings(
@@ -121,13 +123,14 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
121123
plot_kwargs=plot_kwargs,
122124
)
123125

124-
(swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs,
125-
slopegraph_kwargs, reflines_kwargs, legend_kwargs, group_summaries_kwargs,
126-
redraw_axes_kwargs, delta_dot_kwargs, delta_text_kwargs, summary_bars_kwargs,
127-
swarm_bars_kwargs, contrast_bars_kwargs, table_kwargs, gridkey_kwargs, es_marker_kwargs, es_errorbar_kwargs) = get_kwargs(
128-
plot_kwargs=plot_kwargs,
129-
ytick_color=ytick_color
130-
)
126+
(swarmplot_kwargs, barplot_kwargs, sankey_kwargs,
127+
violinplot_kwargs, slopegraph_kwargs, reflines_kwargs,
128+
legend_kwargs, group_summaries_kwargs, redraw_axes_kwargs, delta_dot_kwargs,
129+
delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs,
130+
table_kwargs, gridkey_kwargs, es_marker_kwargs, es_errorbar_kwargs, prop_sample_counts_kwargs) = get_kwargs(
131+
plot_kwargs=plot_kwargs,
132+
ytick_color=ytick_color
133+
)
131134

132135
# We also need to extract the `sankey` and `flow` from the kwargs for plotter.py
133136
# to use for varying different kinds of paired proportional plots
@@ -318,6 +321,18 @@ def effectsize_df_plotter(effectsize_df, **plot_kwargs):
318321
horizontal=horizontal,
319322
)
320323

324+
# Add counts to prop plots
325+
if proportional and plot_kwargs['prop_sample_counts'] and sankey_kwargs["flow"]:
326+
add_counts_to_prop_plots(
327+
plot_data=plot_data,
328+
xvar=xvar,
329+
yvar=yvar,
330+
rawdata_axes=rawdata_axes,
331+
horizontal=horizontal,
332+
is_paired = is_paired,
333+
prop_sample_counts_kwargs=prop_sample_counts_kwargs,
334+
)
335+
321336
# Plot effect sizes and bootstraps.
322337
plot_groups = (temp_all_plot_groups if (is_paired == "baseline" and show_pairs and two_col_sankey)
323338
else temp_idx if two_col_sankey

nbs/API/effsize_objects.ipynb

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,9 @@
11961196
"\n",
11971197
" es_marker_kwargs=None,\n",
11981198
" es_errorbar_kwargs=None,\n",
1199+
"\n",
1200+
" prop_sample_counts=False,\n",
1201+
" prop_sample_counts_kwargs=None\n",
11991202
" ):\n",
12001203
" \"\"\"\n",
12011204
" Creates an estimation plot for the effect size of interest.\n",
@@ -1285,6 +1288,9 @@
12851288
" Pass any keyword arguments accepted by the seaborn `swarmplot`\n",
12861289
" command here, as a dict. If None, the following keywords are\n",
12871290
" passed to sns.swarmplot : {'size':`raw_marker_size`}.\n",
1291+
" barplot_kwargs : dict, default None\n",
1292+
" By default, the keyword arguments passed are:\n",
1293+
" {\"estimator\": np.mean, \"errorbar\": plot_kwargs[\"ci\"]}\n",
12881294
" violinplot_kwargs : dict, default None\n",
12891295
" Pass any keyword arguments accepted by the matplotlib `\n",
12901296
" pyplot.violinplot` command here, as a dict. If None, the following\n",
@@ -1400,6 +1406,13 @@
14001406
" Pass relevant keyword arguments to the effectsize errorbar plotting. If none, the following keywords are passed:\n",
14011407
" {'color': 'black', 'lw': 2, 'linestyle': '-', 'alpha': 1,'zorder': 1,}\n",
14021408
"\n",
1409+
" prop_sample_counts: bool, default False\n",
1410+
" Show the sample counts for each group in proportional plots\n",
1411+
" prop_sample_counts_kwargs: dict, default None\n",
1412+
" Pass relevant keyword arguments. If None, the following keywords are passed:\n",
1413+
" {'color': 'k', 'zorder': 5, 'ha': 'center', 'va': 'center'}\n",
1414+
" \n",
1415+
"\n",
14031416
" Returns\n",
14041417
" -------\n",
14051418
" A :class:`matplotlib.figure.Figure` with 2 Axes, if ``ax = None``.\n",

nbs/API/misc_tools.ipynb

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@
253253
" # Barplot kwargs\n",
254254
" default_barplot_kwargs = {\n",
255255
" \"estimator\": np.mean, \n",
256-
" \"errorbar\": plot_kwargs[\"ci\"]\n",
256+
" \"errorbar\": plot_kwargs[\"ci\"],\n",
257257
" }\n",
258258
" if plot_kwargs[\"barplot_kwargs\"] is None:\n",
259259
" barplot_kwargs = default_barplot_kwargs\n",
@@ -488,11 +488,23 @@
488488
" else:\n",
489489
" es_errorbar_kwargs = merge_two_dicts(default_es_errorbar_kwargs, plot_kwargs['es_errorbar_kwargs'])\n",
490490
"\n",
491+
" # Prop sample counts kwargs\n",
492+
" default_prop_sample_counts_kwargs = {\n",
493+
" 'color': 'k', \n",
494+
" 'zorder': 5, \n",
495+
" 'ha': 'center', \n",
496+
" 'va': 'center'\n",
497+
" }\n",
498+
" if plot_kwargs['prop_sample_counts_kwargs'] is None:\n",
499+
" prop_sample_counts_kwargs = default_prop_sample_counts_kwargs\n",
500+
" else:\n",
501+
" prop_sample_counts_kwargs = merge_two_dicts(default_prop_sample_counts_kwargs, plot_kwargs['prop_sample_counts_kwargs'])\n",
502+
"\n",
491503
" # Return the kwargs.\n",
492504
" return (swarmplot_kwargs, barplot_kwargs, sankey_kwargs, violinplot_kwargs, slopegraph_kwargs, \n",
493505
" reflines_kwargs, legend_kwargs, group_summaries_kwargs, redraw_axes_kwargs, delta_dot_kwargs,\n",
494506
" delta_text_kwargs, summary_bars_kwargs, swarm_bars_kwargs, contrast_bars_kwargs, table_kwargs, gridkey_kwargs,\n",
495-
" es_marker_kwargs, es_errorbar_kwargs)\n",
507+
" es_marker_kwargs, es_errorbar_kwargs, prop_sample_counts_kwargs)\n",
496508
"\n",
497509
"\n",
498510
"def get_color_palette(plot_kwargs, plot_data, xvar, show_pairs, idx, all_plot_groups):\n",

nbs/API/plot_tools.ipynb

Lines changed: 58 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,9 +1969,8 @@
19691969
" horizontal : bool\n",
19701970
" If the plot is horizontal.\n",
19711971
" \"\"\"\n",
1972-
" \n",
1973-
" x_label, y_label = rawdata_axes.get_xlabel(), rawdata_axes.get_ylabel()\n",
19741972
"\n",
1973+
" x_label, y_label = rawdata_axes.get_xlabel(), rawdata_axes.get_ylabel()\n",
19751974
" if horizontal:\n",
19761975
" x_var, y_var, orient = np.ones(len(all_plot_groups)), all_plot_groups, \"h\"\n",
19771976
" else:\n",
@@ -2019,7 +2018,7 @@
20192018
" centre = x + width / 2.0\n",
20202019
" bar.set_x(centre - bar_width / 2.0)\n",
20212020
" bar.set_width(bar_width)\n",
2022-
" \n",
2021+
"\n",
20232022
" # reset the x and y labels\n",
20242023
" rawdata_axes.set_xlabel(x_label)\n",
20252024
" rawdata_axes.set_ylabel(y_label)\n",
@@ -2095,7 +2094,62 @@
20952094
" ax.set_yticklabels([])\n",
20962095
" ax.tick_params(left=False, bottom=False)\n",
20972096
" ax.set_xlabel(label, fontsize=fontsize_label) # Set the x-axis label - hardcoded for now\n",
2098-
" sns.despine(ax=ax, left=True, bottom=True)"
2097+
" sns.despine(ax=ax, left=True, bottom=True)\n",
2098+
"\n",
2099+
"\n",
2100+
"def add_counts_to_prop_plots(plot_data, xvar, yvar, rawdata_axes, horizontal, is_paired, prop_sample_counts_kwargs):\n",
2101+
" \"\"\"\n",
2102+
" Add counts to the proportion plots.\n",
2103+
"\n",
2104+
" Parameters\n",
2105+
" ----------\n",
2106+
" plot_data : object (Dataframe)\n",
2107+
" Dataframe of the plot data.\n",
2108+
" xvar : str\n",
2109+
" Column name of the x variable.\n",
2110+
" yvar : str\n",
2111+
" Column name of the y variable.\n",
2112+
" rawdata_axes : object\n",
2113+
" Matplotlib axis object to plot on.\n",
2114+
" horizontal : bool\n",
2115+
" If the plot is horizontal.\n",
2116+
" is_paired : bool\n",
2117+
" Whether the data is paired.\n",
2118+
" prop_sample_counts_kwargs : dict\n",
2119+
" Keyword arguments for the sample counts.\n",
2120+
" \"\"\"\n",
2121+
"\n",
2122+
" # Group orders\n",
2123+
" if isinstance(plot_data[xvar].dtype, pd.CategoricalDtype):\n",
2124+
" sample_size_text_order = pd.unique(plot_data[xvar]).categories\n",
2125+
" else:\n",
2126+
" sample_size_text_order = pd.unique(plot_data[xvar])\n",
2127+
"\n",
2128+
" # Get the sample size values\n",
2129+
" ones, zeros = plot_data[plot_data[yvar] == 1], plot_data[plot_data[yvar] == 0]\n",
2130+
"\n",
2131+
" sample_size_val1 = ones.groupby(xvar, observed=False)[yvar].count().reindex(index=sample_size_text_order)\n",
2132+
" sample_size_val0 = zeros.groupby(xvar, observed=False)[yvar].count().reindex(index=sample_size_text_order)\n",
2133+
"\n",
2134+
" fontsize = 8 if horizontal else 10\n",
2135+
" fontsize -= 2 if is_paired else 0\n",
2136+
"\n",
2137+
" if \"fontsize\" not in prop_sample_counts_kwargs.keys():\n",
2138+
" fontsize = 8 if horizontal else 10\n",
2139+
" fontsize -= 2 if is_paired else 0\n",
2140+
" prop_sample_counts_kwargs.update({'fontsize': fontsize})\n",
2141+
"\n",
2142+
" for sample_text_x, sample_text_y0, sample_text_y1 in zip(\n",
2143+
" np.arange(0,len(sample_size_text_order)+1,1), \n",
2144+
" sample_size_val0,\n",
2145+
" sample_size_val1,\n",
2146+
" ):\n",
2147+
" if horizontal:\n",
2148+
" rawdata_axes.text(0.05, sample_text_x, sample_text_y1, **prop_sample_counts_kwargs)\n",
2149+
" rawdata_axes.text(0.95, sample_text_x, sample_text_y0, **prop_sample_counts_kwargs)\n",
2150+
" else:\n",
2151+
" rawdata_axes.text(sample_text_x, 0.05, sample_text_y1, **prop_sample_counts_kwargs)\n",
2152+
" rawdata_axes.text(sample_text_x, 0.95, sample_text_y0, **prop_sample_counts_kwargs)"
20992153
]
21002154
},
21012155
{

0 commit comments

Comments
 (0)