Skip to content

Commit 313ad0d

Browse files
aloctavodiatwiecki
authored andcommitted
ENH Add warning to WAIC for high values of the log predictive densities variance (#1281)
1 parent 5d8a8c5 commit 313ad0d

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

pymc3/stats.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,38 @@ def waic(trace, model=None, n_eff=False):
9999
"""
100100
Calculate the widely available information criterion and the effective number of parameters of the samples in trace from model.
101101
Read more theory here - in a paper by some of the leading authorities on Model Selection - http://bit.ly/1W2YJ7c
102+
103+
Parameters
104+
----------
105+
trace : result of MCMC run
106+
model : PyMC Model
107+
Optional model. Default None, taken from context.
108+
n_eff: bool
109+
if True the effective number parameters will be returned.
110+
Default False
111+
112+
Returns
113+
-------
114+
waic: widely available information criterion
115+
p_waic: effective number parameters, only if n_eff True
116+
102117
"""
103118
model = modelcontext(model)
104119

105120
log_py = log_post_trace(trace, model)
106121

107122
lppd = np.sum(np.log(np.mean(np.exp(log_py), axis=0)))
108-
109-
p_waic = np.sum(np.var(log_py, axis=0))
123+
124+
vars_lpd = np.var(log_py, axis=0)
125+
if np.any(vars_lpd > 0.4):
126+
warnings.warn("""For one or more samples the posterior variance of the
127+
log predictive densities exceeds 0.4. This could be indication of
128+
WAIC starting to fail see http://arxiv.org/abs/1507.04544 for details
129+
""")
130+
p_waic = np.sum(vars_lpd)
110131

111132
if n_eff:
112-
return -2 * lppd + 2 * p_waic, p_waic
133+
return -2 * lppd + 2 * p_waic, p_waic
113134
else:
114135
return -2 * lppd + 2 * p_waic
115136

0 commit comments

Comments
 (0)