Skip to content

Commit ed288ed

Browse files
aloctavodia and Junpeng Lao
authored and committed
use Bayesian bootstrapping to compute weights (#2479)
* use Bayesian bootstrapping to compute weights
* speed up _log_post_trace
* replace bootstrapping with bootstrap
1 parent 6e04322 commit ed288ed

File tree

1 file changed

+65
-20
lines changed

1 file changed

+65
-20
lines changed

pymc3/stats.py

Lines changed: 65 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pymc3.theanof import floatX
1313

1414
from scipy.misc import logsumexp
15+
from scipy.stats import dirichlet
1516
from scipy.stats.distributions import pareto
1617

1718
from .backends import tracetab as ttab
@@ -143,13 +144,15 @@ def _log_post_trace(trace, model, progressbar=False):
143144
logp : array of shape (n_samples, n_observations)
144145
The contribution of the observations to the logp of the whole model.
145146
"""
147+
cached = [(var, var.logp_elemwise) for var in model.observed_RVs]
148+
146149
def logp_vals_point(pt):
147150
if len(model.observed_RVs) == 0:
148151
return floatX(np.array([], dtype='d'))
149152

150153
logp_vals = []
151-
for var in model.observed_RVs:
152-
logp = var.logp_elemwise(pt)
154+
for var, logp in cached:
155+
logp = logp(pt)
153156
if var.missing_values:
154157
logp = logp[~var.observations.mask]
155158
logp_vals.append(logp.ravel())
@@ -335,7 +338,8 @@ def bpic(trace, model=None):
335338
return 3 * mean_deviance - 2 * deviance_at_mean
336339

337340

338-
def compare(traces, models, ic='WAIC'):
341+
def compare(traces, models, ic='WAIC', bootstrap=True, b_samples=1000,
342+
alpha=1, seed=None):
339343
"""Compare models based on the widely available information criterion (WAIC)
340344
or leave-one-out (LOO) cross-validation.
341345
Read more theory here - in a paper by some of the leading authorities on
@@ -349,6 +353,19 @@ def compare(traces, models, ic='WAIC'):
349353
ic : string
350354
Information Criterion (WAIC or LOO) used to compare models.
351355
Default WAIC.
356+
bootstrap : boolean
357+
If True a Bayesian bootstrap will be used to compute the weights and
358+
the standard error of the IC estimate (SE).
359+
b_samples: int
360+
Number of samples taken by the Bayesian bootstrap estimation
361+
alpha : float
362+
The shape parameter in the Dirichlet distribution used for the
363+
Bayesian bootstrap. When alpha=1 (default), the distribution is uniform
364+
on the simplex. A smaller alpha will keep the final weights
365+
further away from 0 and 1.
366+
seed : int or np.random.RandomState instance
367+
If int or RandomState, use it for seeding Bayesian bootstrap.
368+
Default None; the global np.random state is used.
352369
353370
Returns
354371
-------
@@ -361,13 +378,13 @@ def compare(traces, models, ic='WAIC'):
361378
dIC : Relative difference between each IC (WAIC or LOO)
362379
and the lowest IC (WAIC or LOO).
363380
It's always 0 for the top-ranked model.
364-
weight: Akaike weights for each model.
381+
weight: Akaike-like weights for each model.
365382
This can be loosely interpreted as the probability of each model
366-
(among the compared model) given the data. Be careful that these
367-
weights are based on point estimates of the IC (uncertainty is ignored).
383+
(among the compared model) given the data. By default the uncertainty
384+
in the weights estimation is considered using Bayesian bootstrap.
368385
SE : Standard error of the IC estimate.
369-
For a "large enough" sample size this is an estimate of the uncertainty
370-
in the computation of the IC.
386+
By default these values are estimated using Bayesian bootstrap (best
387+
option) or, if bootstrap=False, using a Gaussian approximation
371388
dSE : Standard error of the difference in IC between each model and
372389
the top-ranked model.
373390
It's always 0 for the top-ranked model.
@@ -378,20 +395,21 @@ def compare(traces, models, ic='WAIC'):
378395
ic_func = waic
379396
df_comp = pd.DataFrame(index=np.arange(len(models)),
380397
columns=['WAIC', 'pWAIC', 'dWAIC', 'weight',
381-
'SE', 'dSE', 'warning'])
398+
'SE', 'dSE', 'warning'])
399+
382400
elif ic == 'LOO':
383401
ic_func = loo
384402
df_comp = pd.DataFrame(index=np.arange(len(models)),
385403
columns=['LOO', 'pLOO', 'dLOO', 'weight',
386-
'SE', 'dSE', 'warning'])
404+
'SE', 'dSE', 'warning'])
405+
387406
else:
388407
raise NotImplementedError(
389408
'The information criterion {} is not supported.'.format(ic))
390409

391410
warns = np.zeros(len(models))
392411

393412
c = 0
394-
395413
def add_warns(*args):
396414
warns[c] = 1
397415

@@ -405,16 +423,43 @@ def add_warns(*args):
405423

406424
ics.sort(key=lambda x: x[1][0])
407425

408-
min_ic = ics[0][1][0]
409-
Z = np.sum([np.exp(-0.5 * (x[1][0] - min_ic)) for x in ics])
426+
if bootstrap:
427+
N = len(ics[0][1][3])
428+
429+
ic_i = np.zeros((len(ics), N))
430+
for i in range(len(ics)):
431+
ic_i[i] = ics[i][1][3] * N
432+
433+
b_weighting = dirichlet.rvs(alpha=[alpha] * N, size=b_samples,
434+
random_state=seed)
435+
weights = np.zeros((b_samples, len(ics)))
436+
z_bs = np.zeros((b_samples, len(ics)))
437+
for i in range(b_samples):
438+
z_b = np.dot(ic_i, b_weighting[i])
439+
u_weights = np.exp(-0.5 * (z_b - np.min(z_b)))
440+
z_bs[i] = z_b
441+
weights[i] = u_weights / np.sum(u_weights)
442+
443+
weights_mean = weights.mean(0)
444+
se = z_bs.std(0)
445+
for i, (idx, res) in enumerate(ics):
446+
diff = res[3] - ics[0][1][3]
447+
d_ic = np.sum(diff)
448+
d_se = np.sqrt(len(diff) * np.var(diff))
449+
df_comp.at[idx] = (res[0], res[2], d_ic, weights_mean[i],
450+
se[i], d_se, warns[idx])
410451

411-
for idx, res in ics:
412-
diff = ics[0][1][3] - res[3]
413-
d_ic = np.sum(diff)
414-
d_se = np.sqrt(len(diff) * np.var(diff))
415-
weight = np.exp(-0.5 * (res[0] - min_ic)) / Z
416-
df_comp.at[idx] = (res[0], res[2], abs(d_ic), weight, res[1],
417-
d_se, warns[idx])
452+
else:
453+
min_ic = ics[0][1][0]
454+
Z = np.sum([np.exp(-0.5 * (x[1][0] - min_ic)) for x in ics])
455+
456+
for idx, res in ics:
457+
diff = res[3] - ics[0][1][3]
458+
d_ic = np.sum(diff)
459+
d_se = np.sqrt(len(diff) * np.var(diff))
460+
weight = np.exp(-0.5 * (res[0] - min_ic)) / Z
461+
df_comp.at[idx] = (res[0], res[2], d_ic, weight, res[1],
462+
d_se, warns[idx])
418463

419464
return df_comp.sort_values(by=ic)
420465

0 commit comments

Comments
 (0)