|
7 | 7 | import warnings
|
8 | 8 | from .model import modelcontext
|
9 | 9 |
|
| 10 | +from scipy.misc import logsumexp |
10 | 11 | from scipy.stats.distributions import pareto
|
11 | 12 |
|
12 | 13 | import pymc3 as pm
|
@@ -101,7 +102,7 @@ def log_post_trace(trace, model):
|
101 | 102 | '''
|
102 | 103 | Calculate the elementwise log-posterior for the sampled trace.
|
103 | 104 | '''
|
104 |
| - return np.hstack([obs.logp_elemwise(pt) for pt in trace] for obs in model.observed_RVs) |
| 105 | + return np.vstack([obs.logp_elemwise(pt) for obs in model.observed_RVs] for pt in trace) |
105 | 106 |
|
106 | 107 |
|
107 | 108 | def waic(trace, model=None, n_eff=False):
|
@@ -132,7 +133,7 @@ def waic(trace, model=None, n_eff=False):
|
132 | 133 |
|
133 | 134 | log_py = log_post_trace(trace, model)
|
134 | 135 |
|
135 |
| - lppd_i = np.log(np.mean(np.exp(log_py), axis=0)) |
| 136 | + lppd_i = logsumexp(log_py, axis = 0, b = 1.0 / log_py.shape[0]) |
136 | 137 |
|
137 | 138 | vars_lpd = np.var(log_py, axis=0)
|
138 | 139 | if np.any(vars_lpd > 0.4):
|
@@ -183,8 +184,7 @@ def loo(trace, model=None, n_eff=False):
|
183 | 184 | log_py = log_post_trace(trace, model)
|
184 | 185 |
|
185 | 186 | # Importance ratios
|
186 |
| - py = np.exp(log_py) |
187 |
| - r = 1. / py |
| 187 | + r = np.exp(-log_py) |
188 | 188 | r_sorted = np.sort(r, axis=0)
|
189 | 189 |
|
190 | 190 | # Extract largest 20% of importance ratios and fit generalized Pareto to each
|
@@ -222,7 +222,7 @@ def loo(trace, model=None, n_eff=False):
|
222 | 222 | # Truncate weights to guarantee finite variance
|
223 | 223 | w = np.minimum(r_new, r_new.mean(axis=0) * S**0.75)
|
224 | 224 |
|
225 |
| - loo_lppd_i = -2 * np.log(np.sum(w * py, axis=0) / np.sum(w, axis=0)) |
| 225 | + loo_lppd_i = -2.0 * logsumexp(log_py, axis = 0, b = w / np.sum(w, axis = 0)) |
226 | 226 |
|
227 | 227 | loo_lppd_se = np.sqrt(len(loo_lppd_i) * np.var(loo_lppd_i))
|
228 | 228 |
|
|
0 commit comments