Skip to content

Commit 5d8a8c5

Browse files
aloctavodia authored and twiecki committed
Report loo results in terms of deviance. (#1279)
* Report loo results in terms of deviance. Optionally report the estimated effective number of parameters.
* Improved docstring.
1 parent 680a8eb commit 5d8a8c5

File tree

1 file changed

+23
-7
lines changed

1 file changed

+23
-7
lines changed

pymc3/stats.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,21 +112,34 @@ def waic(trace, model=None, n_eff=False):
112112
return -2 * lppd + 2 * p_waic, p_waic
113113
else:
114114
return -2 * lppd + 2 * p_waic
115-
116-
def loo(trace, model=None):
115+
116+
def loo(trace, model=None, n_eff=False):
117117
"""
118118
Calculates leave-one-out (LOO) cross-validation for out of sample predictive
119119
model fit, following Vehtari et al. (2015). Cross-validation is computed using
120120
Pareto-smoothed importance sampling (PSIS).
121121
122-
Returns log pointwise predictive density calculated via approximated LOO cross-validation.
122+
Parameters
123+
----------
124+
trace : result of MCMC run
125+
model : PyMC Model
126+
Optional model. Default None, taken from context.
127+
n_eff: bool
128+
if True the effective number parameters will be computed and returned.
129+
Default False
130+
131+
Returns
132+
-------
133+
elpd_loo: log pointwise predictive density calculated via approximated LOO cross-validation
134+
p_loo: effective number parameters, only if n_eff True
123135
"""
124136
model = modelcontext(model)
125137

126138
log_py = log_post_trace(trace, model)
127139

128140
# Importance ratios
129-
r = 1./np.exp(log_py)
141+
py = np.exp(log_py)
142+
r = 1./py
130143
r_sorted = np.sort(r, axis=0)
131144

132145
# Extract largest 20% of importance ratios and fit generalized Pareto to each
@@ -154,11 +167,14 @@ def loo(trace, model=None):
154167
# Truncate weights to guarantee finite variance
155168
w = np.minimum(r_new, r_new.mean(axis=0) * S**0.75)
156169

157-
loo_lppd = np.sum(np.log(np.sum(w * np.exp(log_py), axis=0) / np.sum(w, axis=0)))
170+
loo_lppd = np.sum(np.log(np.sum(w * py, axis=0) / np.sum(w, axis=0)))
158171

159-
return loo_lppd
172+
if n_eff:
173+
p_loo = np.sum(np.log(np.mean(py, axis=0))) - loo_lppd
174+
return -2 * loo_lppd, p_loo
175+
else:
176+
return -2 * loo_lppd
160177

161-
162178
def bpic(trace, model=None):
163179
"""
164180
Calculates Bayesian predictive information criterion n of the samples in trace from model

0 commit comments

Comments
 (0)