Skip to content

Commit 9543afc

Browse files
ozankabaktwiecki
authored andcommitted
Fixed bug in WAIC and LOO computation, also switched to SciPy's (#1557)
logsumexp function for better numerical stability.
1 parent 50d0e02 commit 9543afc

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc3/stats.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import warnings
88
from .model import modelcontext
99

10+
from scipy.misc import logsumexp
1011
from scipy.stats.distributions import pareto
1112

1213
import pymc3 as pm
@@ -101,7 +102,7 @@ def log_post_trace(trace, model):
101102
'''
102103
Calculate the elementwise log-posterior for the sampled trace.
103104
'''
104-
return np.hstack([obs.logp_elemwise(pt) for pt in trace] for obs in model.observed_RVs)
105+
return np.vstack([obs.logp_elemwise(pt) for obs in model.observed_RVs] for pt in trace)
105106

106107

107108
def waic(trace, model=None, n_eff=False):
@@ -132,7 +133,7 @@ def waic(trace, model=None, n_eff=False):
132133

133134
log_py = log_post_trace(trace, model)
134135

135-
lppd_i = np.log(np.mean(np.exp(log_py), axis=0))
136+
lppd_i = logsumexp(log_py, axis = 0, b = 1.0 / log_py.shape[0])
136137

137138
vars_lpd = np.var(log_py, axis=0)
138139
if np.any(vars_lpd > 0.4):
@@ -183,8 +184,7 @@ def loo(trace, model=None, n_eff=False):
183184
log_py = log_post_trace(trace, model)
184185

185186
# Importance ratios
186-
py = np.exp(log_py)
187-
r = 1. / py
187+
r = np.exp(-log_py)
188188
r_sorted = np.sort(r, axis=0)
189189

190190
# Extract largest 20% of importance ratios and fit generalized Pareto to each
@@ -222,7 +222,7 @@ def loo(trace, model=None, n_eff=False):
222222
# Truncate weights to guarantee finite variance
223223
w = np.minimum(r_new, r_new.mean(axis=0) * S**0.75)
224224

225-
loo_lppd_i = -2 * np.log(np.sum(w * py, axis=0) / np.sum(w, axis=0))
225+
loo_lppd_i = -2.0 * logsumexp(log_py, axis = 0, b = w / np.sum(w, axis = 0))
226226

227227
loo_lppd_se = np.sqrt(len(loo_lppd_i) * np.var(loo_lppd_i))
228228

0 commit comments

Comments
 (0)