12
12
from pymc3 .theanof import floatX
13
13
14
14
from scipy .misc import logsumexp
15
+ from scipy .stats import dirichlet
15
16
from scipy .stats .distributions import pareto
16
17
17
18
from .backends import tracetab as ttab
@@ -143,13 +144,15 @@ def _log_post_trace(trace, model, progressbar=False):
143
144
logp : array of shape (n_samples, n_observations)
144
145
The contribution of the observations to the logp of the whole model.
145
146
"""
147
+ cached = [(var , var .logp_elemwise ) for var in model .observed_RVs ]
148
+
146
149
def logp_vals_point (pt ):
147
150
if len (model .observed_RVs ) == 0 :
148
151
return floatX (np .array ([], dtype = 'd' ))
149
152
150
153
logp_vals = []
151
- for var in model . observed_RVs :
152
- logp = var . logp_elemwise (pt )
154
+ for var , logp in cached :
155
+ logp = logp (pt )
153
156
if var .missing_values :
154
157
logp = logp [~ var .observations .mask ]
155
158
logp_vals .append (logp .ravel ())
@@ -335,7 +338,8 @@ def bpic(trace, model=None):
335
338
return 3 * mean_deviance - 2 * deviance_at_mean
336
339
337
340
338
- def compare (traces , models , ic = 'WAIC' ):
341
+ def compare (traces , models , ic = 'WAIC' , bootstrap = True , b_samples = 1000 ,
342
+ alpha = 1 , seed = None ):
339
343
"""Compare models based on the widely applicable information criterion (WAIC)
340
344
or leave-one-out (LOO) cross-validation.
341
345
Read more theory here - in a paper by some of the leading authorities on
@@ -349,6 +353,19 @@ def compare(traces, models, ic='WAIC'):
349
353
ic : string
350
354
Information Criterion (WAIC or LOO) used to compare models.
351
355
Default WAIC.
356
+ bootstrap : boolean
357
+ If True a Bayesian bootstrap will be used to compute the weights and
358
+ the standard error of the IC estimate (SE).
359
+ b_samples: int
360
+ Number of samples taken by the Bayesian bootstrap estimation
361
+ alpha : float
362
+ The shape parameter in the Dirichlet distribution used for the
363
+ Bayesian bootstrap. When alpha=1 (default), the distribution is uniform
364
+ on the simplex. A smaller alpha will keeps the final weights
365
+ further away from 0 and 1.
366
+ seed : int or np.random.RandomState instance
367
+ If int or RandomState, use it for seeding Bayesian bootstrap.
368
+ Default None, in which case the global np.random state is used.
352
369
353
370
Returns
354
371
-------
@@ -361,13 +378,13 @@ def compare(traces, models, ic='WAIC'):
361
378
dIC : Relative difference between each IC (WAIC or LOO)
362
379
and the lowest IC (WAIC or LOO).
363
380
It's always 0 for the top-ranked model.
364
- weight: Akaike weights for each model.
381
+ weight: Akaike-like weights for each model.
365
382
This can be loosely interpreted as the probability of each model
366
- (among the compared model) given the data. Be careful that these
367
- weights are based on point estimates of the IC (uncertainty is ignored) .
383
+ (among the compared models) given the data. By default the uncertainty
384
+ in the weights estimation is considered using Bayesian bootstrap .
368
385
SE : Standard error of the IC estimate.
369
- For a "large enough" sample size this is an estimate of the uncertainty
370
- in the computation of the IC.
386
+ By default these values are estimated using Bayesian bootstrap (best
387
+ option) or, if bootstrap=False, using a Gaussian approximation
371
388
dSE : Standard error of the difference in IC between each model and
372
389
the top-ranked model.
373
390
It's always 0 for the top-ranked model.
@@ -378,20 +395,21 @@ def compare(traces, models, ic='WAIC'):
378
395
ic_func = waic
379
396
df_comp = pd .DataFrame (index = np .arange (len (models )),
380
397
columns = ['WAIC' , 'pWAIC' , 'dWAIC' , 'weight' ,
381
- 'SE' , 'dSE' , 'warning' ])
398
+ 'SE' , 'dSE' , 'warning' ])
399
+
382
400
elif ic == 'LOO' :
383
401
ic_func = loo
384
402
df_comp = pd .DataFrame (index = np .arange (len (models )),
385
403
columns = ['LOO' , 'pLOO' , 'dLOO' , 'weight' ,
386
- 'SE' , 'dSE' , 'warning' ])
404
+ 'SE' , 'dSE' , 'warning' ])
405
+
387
406
else :
388
407
raise NotImplementedError (
389
408
'The information criterion {} is not supported.' .format (ic ))
390
409
391
410
warns = np .zeros (len (models ))
392
411
393
412
c = 0
394
-
395
413
def add_warns (* args ):
396
414
warns [c ] = 1
397
415
@@ -405,16 +423,43 @@ def add_warns(*args):
405
423
406
424
ics .sort (key = lambda x : x [1 ][0 ])
407
425
408
- min_ic = ics [0 ][1 ][0 ]
409
- Z = np .sum ([np .exp (- 0.5 * (x [1 ][0 ] - min_ic )) for x in ics ])
426
+ if bootstrap :
427
+ N = len (ics [0 ][1 ][3 ])
428
+
429
+ ic_i = np .zeros ((len (ics ), N ))
430
+ for i in range (len (ics )):
431
+ ic_i [i ] = ics [i ][1 ][3 ] * N
432
+
433
+ b_weighting = dirichlet .rvs (alpha = [alpha ] * N , size = b_samples ,
434
+ random_state = seed )
435
+ weights = np .zeros ((b_samples , len (ics )))
436
+ z_bs = np .zeros ((b_samples , len (ics )))
437
+ for i in range (b_samples ):
438
+ z_b = np .dot (ic_i , b_weighting [i ])
439
+ u_weights = np .exp (- 0.5 * (z_b - np .min (z_b )))
440
+ z_bs [i ] = z_b
441
+ weights [i ] = u_weights / np .sum (u_weights )
442
+
443
+ weights_mean = weights .mean (0 )
444
+ se = z_bs .std (0 )
445
+ for i , (idx , res ) in enumerate (ics ):
446
+ diff = res [3 ] - ics [0 ][1 ][3 ]
447
+ d_ic = np .sum (diff )
448
+ d_se = np .sqrt (len (diff ) * np .var (diff ))
449
+ df_comp .at [idx ] = (res [0 ], res [2 ], d_ic , weights_mean [i ],
450
+ se [i ], d_se , warns [idx ])
410
451
411
- for idx , res in ics :
412
- diff = ics [0 ][1 ][3 ] - res [3 ]
413
- d_ic = np .sum (diff )
414
- d_se = np .sqrt (len (diff ) * np .var (diff ))
415
- weight = np .exp (- 0.5 * (res [0 ] - min_ic )) / Z
416
- df_comp .at [idx ] = (res [0 ], res [2 ], abs (d_ic ), weight , res [1 ],
417
- d_se , warns [idx ])
452
+ else :
453
+ min_ic = ics [0 ][1 ][0 ]
454
+ Z = np .sum ([np .exp (- 0.5 * (x [1 ][0 ] - min_ic )) for x in ics ])
455
+
456
+ for idx , res in ics :
457
+ diff = res [3 ] - ics [0 ][1 ][3 ]
458
+ d_ic = np .sum (diff )
459
+ d_se = np .sqrt (len (diff ) * np .var (diff ))
460
+ weight = np .exp (- 0.5 * (res [0 ] - min_ic )) / Z
461
+ df_comp .at [idx ] = (res [0 ], res [2 ], d_ic , weight , res [1 ],
462
+ d_se , warns [idx ])
418
463
419
464
return df_comp .sort_values (by = ic )
420
465
0 commit comments