Skip to content

Commit 1a58aa3

Browse files
authored
add arguments to compareplot (#2590)
* add arguments to compareplot * change insample to insample_dev * add compareplot, energyplot and kdeplot to docs
1 parent e3df907 commit 1a58aa3

File tree

2 files changed

+65
-17
lines changed

2 files changed

+65
-17
lines changed

docs/source/api/plots.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ Plots
55
.. currentmodule:: pymc3.plots
66

77
.. automodule:: pymc3.plots
8-
:members: traceplot, plot_posterior, forestplot, compare_plot, autocorrplot
8+
:members: traceplot, plot_posterior, forestplot, compareplot, autocorrplot,
9+
energyplot, kdeplot

pymc3/plots/compareplot.py

Lines changed: 63 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import numpy as np
33

44

5-
def compareplot(comp_df, ax=None):
5+
def compareplot(comp_df, insample_dev=True, se=True, dse=True, ax=None,
6+
plot_kwargs=None):
67
"""
78
Model comparison summary plot in the style of the one used in the book
89
Statistical Rethinking by Richard McElreath.
@@ -11,9 +12,22 @@ def compareplot(comp_df, ax=None):
1112
----------
1213
1314
comp_df: DataFrame
14-
The result of the pm.compare() function
15+
the result of the `pm.compare()` function
16+
insample_dev : bool
17+
plot the in-sample deviance, that is the value of the IC without the
18+
penalization given by the effective number of parameters (pIC).
19+
Defaults to True
20+
se : bool
21+
plot the standard error of the IC estimate. Defaults to True
22+
dse : bool
23+
plot standard error of the difference in IC between each model and the
24+
top-ranked model. Defaults to True
25+
plot_kwargs : dict
26+
Optional arguments for plot elements. Currently accepts 'color_ic',
27+
'marker_ic', 'color_insample_dev', 'marker_insample_dev', 'color_dse',
28+
'marker_dse', 'ls_min_ic' 'color_ls_min_ic', 'fontsize'
1529
ax : axes
16-
Matplotlib axes. Defaults to None.
30+
Matplotlib axes. Defaults to None
1731
1832
Returns
1933
-------
@@ -24,26 +38,59 @@ def compareplot(comp_df, ax=None):
2438
if ax is None:
2539
_, ax = plt.subplots()
2640

27-
yticks_pos, step = np.linspace(0, -1, (comp_df.shape[0] * 2) - 1, retstep=True)
41+
if plot_kwargs is None:
42+
plot_kwargs = {}
43+
44+
yticks_pos, step = np.linspace(0, -1, (comp_df.shape[0] * 2) - 1,
45+
retstep=True)
2846
yticks_pos[1::2] = yticks_pos[1::2] + step / 2
2947

3048
yticks_labels = [''] * len(yticks_pos)
31-
yticks_labels[0] = comp_df.index[0]
32-
yticks_labels[1::2] = comp_df.index[1:]
3349

34-
data = comp_df.values
35-
min_ic = data[0, 0]
50+
if dse:
51+
yticks_labels[0] = comp_df.index[0]
52+
yticks_labels[2::2] = comp_df.index[1:]
53+
ax.set_yticks(yticks_pos)
54+
ax.errorbar(x=comp_df.WAIC[1:],
55+
y=yticks_pos[1::2],
56+
xerr=comp_df.dSE[1:],
57+
color=plot_kwargs.get('color_dse', 'grey'),
58+
fmt=plot_kwargs.get('marker_dse', '^'))
59+
60+
else:
61+
yticks_labels = comp_df.index
62+
ax.set_yticks(yticks_pos[::2])
63+
64+
if se:
65+
ax.errorbar(x=comp_df.WAIC,
66+
y=yticks_pos[::2],
67+
xerr=comp_df.SE,
68+
color=plot_kwargs.get('color_ic', 'k'),
69+
fmt=plot_kwargs.get('marker_ic', 'o'),
70+
mfc='None',
71+
mew=1)
72+
else:
73+
ax.plot(comp_df.WAIC,
74+
yticks_pos[::2],
75+
color=plot_kwargs.get('color_ic', 'k'),
76+
marker=plot_kwargs.get('marker_ic', 'o'),
77+
mfc='None',
78+
mew=1,
79+
lw=0)
3680

37-
ax.errorbar(x=data[:, 0], y=yticks_pos[::2], xerr=data[:, 4],
38-
fmt='ko', mfc='None', mew=1)
39-
ax.errorbar(x=data[1:, 0], y=yticks_pos[1::2],
40-
xerr=data[1:, 5], fmt='^', color='grey')
81+
if insample_dev:
82+
ax.plot(comp_df.WAIC - (2 * comp_df.pWAIC),
83+
yticks_pos[::2],
84+
color=plot_kwargs.get('color_insample_dev', 'k'),
85+
marker=plot_kwargs.get('marker_insample_dev', 'o'),
86+
lw=0)
4187

42-
ax.plot(data[:, 0] - (2 * data[:, 1]), yticks_pos[::2], 'ko')
43-
ax.axvline(min_ic, ls='--', color='grey')
88+
ax.axvline(comp_df.WAIC[0],
89+
ls=plot_kwargs.get('ls_min_ic', '--'),
90+
color=plot_kwargs.get('color_ls_min_ic', 'grey'))
4491

45-
ax.set_yticks(yticks_pos)
92+
ax.set_xlabel('Deviance', fontsize=plot_kwargs.get('fontsize', 14))
4693
ax.set_yticklabels(yticks_labels)
47-
ax.set_xlabel('Deviance')
94+
ax.set_ylim(-1 + step, 0 - step)
4895

4996
return ax

0 commit comments

Comments
 (0)