14
14
from scipy .misc import logsumexp
15
15
from scipy .stats import dirichlet
16
16
from scipy .stats .distributions import pareto
17
+ from scipy .optimize import minimize
17
18
18
19
from .backends import tracetab as ttab
19
20
@@ -340,7 +341,7 @@ def bpic(trace, model=None):
340
341
return 3 * mean_deviance - 2 * deviance_at_mean
341
342
342
343
343
- def compare (traces , models , ic = 'WAIC' , bootstrap = True , b_samples = 1000 ,
344
+ def compare (traces , models , ic = 'WAIC' , method = 'stacking' , b_samples = 1000 ,
344
345
alpha = 1 , seed = None ):
345
346
"""Compare models based on the widely available information criterion (WAIC)
346
347
or leave-one-out (LOO) cross-validation.
@@ -355,19 +356,28 @@ def compare(traces, models, ic='WAIC', bootstrap=True, b_samples=1000,
355
356
ic : string
356
357
Information Criterion (WAIC or LOO) used to compare models.
357
358
Default WAIC.
358
- bootstrap : boolean
359
- If True a Bayesian bootstrap will be used to compute the weights and
360
- the standard error of the IC estimate (SE).
359
+ method : str
360
+ Method used to estimate the weights for each model. Available options
361
+ are:
362
+ - 'stacking' : (default) stacking of predictive distributions.
363
+ - 'BB-pseudo-BMA' : pseudo-Bayesian Model averaging using Akaike-type
364
+ weighting. The weights are stabilized using the Bayesian bootstrap
365
+ - 'pseudo-BMA': pseudo-Bayesian Model averaging using Akaike-type
366
+ weighting, without Bootstrap stabilization (not recommended)
367
+
368
+ For more information read https://arxiv.org/abs/1704.02030
361
369
b_samples: int
362
- Number of samples taken by the Bayesian bootstrap estimation
370
+ Number of samples taken by the Bayesian bootstrap estimation. Only
371
+ useful when method = 'BB-pseudo-BMA'.
363
372
alpha : float
364
373
The shape parameter in the Dirichlet distribution used for the
365
- Bayesian bootstrap. When alpha=1 (default), the distribution is uniform
366
- on the simplex. A smaller alpha will keeps the final weights
367
- more away from 0 and 1.
374
+ Bayesian bootstrap. Only useful when method = 'BB-pseudo-BMA'. When
375
+ alpha=1 (default), the distribution is uniform on the simplex. A
376
+ smaller alpha will keeps the final weights more away from 0 and 1.
368
377
seed : int or np.random.RandomState instance
369
- If int or RandomState, use it for seeding Bayesian bootstrap.
370
- Default None the global np.random state is used.
378
+ If int or RandomState, use it for seeding Bayesian bootstrap. Only
379
+ useful when method = 'BB-pseudo-BMA'. Default None the global
380
+ np.random state is used.
371
381
372
382
Returns
373
383
-------
@@ -380,13 +390,13 @@ def compare(traces, models, ic='WAIC', bootstrap=True, b_samples=1000,
380
390
dIC : Relative difference between each IC (WAIC or LOO)
381
391
and the lowest IC (WAIC or LOO).
382
392
It's always 0 for the top-ranked model.
383
- weight: Akaike-like weights for each model.
393
+ weight: Relative weight for each model.
384
394
This can be loosely interpreted as the probability of each model
385
395
(among the compared model) given the data. By default the uncertainty
386
396
in the weights estimation is considered using Bayesian bootstrap.
387
397
SE : Standard error of the IC estimate.
388
- By default these values are estimated using Bayesian bootstrap (best
389
- option) or, if bootstrap=False, using a Gaussian approximation
398
+ If method = BB-pseudo-BMA these values are estimated using Bayesian
399
+ bootstrap.
390
400
dSE : Standard error of the difference in IC between each model and
391
401
the top-ranked model.
392
402
It's always 0 for the top-ranked model.
@@ -409,6 +419,14 @@ def compare(traces, models, ic='WAIC', bootstrap=True, b_samples=1000,
409
419
raise NotImplementedError (
410
420
'The information criterion {} is not supported.' .format (ic ))
411
421
422
+ if len (set ([len (m .observed_RVs ) for m in models ])) != 1 :
423
+ raise ValueError (
424
+ 'The Observed RVs should be the same across all models' )
425
+
426
+ if method not in ['stacking' , 'BB-pseudo-BMA' , 'pseudo-BMA' ]:
427
+ raise NotImplementedError (
428
+ 'The method to compute weights {} is not supported.' .format (method ))
429
+
412
430
warns = np .zeros (len (models ))
413
431
414
432
c = 0
@@ -425,45 +443,95 @@ def add_warns(*args):
425
443
426
444
ics .sort (key = lambda x : x [1 ][0 ])
427
445
428
- if bootstrap :
429
- N = len (ics [0 ][1 ][3 ])
430
-
431
- ic_i = np .zeros ((len (ics ), N ))
432
- for i in range (len (ics )):
433
- ic_i [i ] = ics [i ][1 ][3 ] * N
446
+ if method == 'stacking' :
447
+ N , K , ic_i = _ic_matrix (ics )
448
+ exp_ic_i = np .exp (- 0.5 * ic_i )
449
+ Km = K - 1
450
+
451
+ def w_fuller (w ):
452
+ return np .concatenate ((w , 1. - np .sum (w , keepdims = True )))
453
+
454
+ def log_score (w ):
455
+ w_full = w_fuller (w )
456
+ score = 0.
457
+ for i in range (N ):
458
+ score += np .log (np .dot (exp_ic_i [i ], w_full ))
459
+ return - score
460
+
461
+ def gradient (w ):
462
+ w_full = w_fuller (w )
463
+ grad = np .zeros (Km )
464
+ for k in range (Km ):
465
+ for i in range (N ):
466
+ grad [k ] += (exp_ic_i [i , k ] - exp_ic_i [i , Km ]) / \
467
+ np .dot (exp_ic_i [i ], w_full )
468
+ return - grad
469
+
470
+ theta = np .full (Km , 1. / K )
471
+ bounds = [(0. , 1. ) for i in range (Km )]
472
+ constraints = [{'type' : 'ineq' , 'fun' : lambda x : - np .sum (x ) + 1. },
473
+ {'type' : 'ineq' , 'fun' : lambda x : np .sum (x )}]
474
+
475
+ w = minimize (fun = log_score ,
476
+ x0 = theta ,
477
+ jac = gradient ,
478
+ bounds = bounds ,
479
+ constraints = constraints )
480
+
481
+ weights = w_fuller (w ['x' ])
482
+ ses = [res [1 ] for _ , res in ics ]
483
+
484
+ elif method == 'BB-pseudo-BMA' :
485
+ N , K , ic_i = _ic_matrix (ics )
486
+ ic_i = ic_i * N
434
487
435
488
b_weighting = dirichlet .rvs (alpha = [alpha ] * N , size = b_samples ,
436
489
random_state = seed )
437
- weights = np .zeros ((b_samples , len ( ics ) ))
438
- z_bs = np .zeros (( b_samples , len ( ics )) )
490
+ weights = np .zeros ((b_samples , K ))
491
+ z_bs = np .zeros_like ( weights )
439
492
for i in range (b_samples ):
440
- z_b = np .dot (ic_i , b_weighting [i ])
493
+ z_b = np .dot (b_weighting [i ], ic_i )
441
494
u_weights = np .exp (- 0.5 * (z_b - np .min (z_b )))
442
495
z_bs [i ] = z_b
443
496
weights [i ] = u_weights / np .sum (u_weights )
444
497
445
- weights_mean = weights .mean (0 )
446
- se = z_bs .std (0 )
447
- for i , (idx , res ) in enumerate (ics ):
448
- diff = res [3 ] - ics [0 ][1 ][3 ]
449
- d_ic = np .sum (diff )
450
- d_se = np .sqrt (len (diff ) * np .var (diff ))
451
- df_comp .at [idx ] = (res [0 ], res [2 ], d_ic , weights_mean [i ],
452
- se [i ], d_se , warns [idx ])
498
+ weights = weights .mean (0 )
499
+ ses = z_bs .std (0 )
453
500
454
- else :
501
+ elif method == 'pseudo-BMA' :
455
502
min_ic = ics [0 ][1 ][0 ]
456
503
Z = np .sum ([np .exp (- 0.5 * (x [1 ][0 ] - min_ic )) for x in ics ])
504
+ weights = []
505
+ ses = []
506
+ for _ , res in ics :
507
+ weights .append (np .exp (- 0.5 * (res [0 ] - min_ic )) / Z )
508
+ ses .append (res [1 ])
457
509
458
- for idx , res in ics :
510
+ if np .any (weights ):
511
+ for i , (idx , res ) in enumerate (ics ):
459
512
diff = res [3 ] - ics [0 ][1 ][3 ]
460
513
d_ic = np .sum (diff )
461
514
d_se = np .sqrt (len (diff ) * np .var (diff ))
462
- weight = np .exp (- 0.5 * (res [0 ] - min_ic )) / Z
463
- df_comp .at [idx ] = (res [0 ], res [2 ], d_ic , weight , res [1 ],
464
- d_se , warns [idx ])
515
+ se = ses [i ]
516
+ weight = weights [i ]
517
+ df_comp .at [idx ] = (res [0 ], res [2 ], d_ic , weight , se , d_se ,
518
+ warns [idx ])
519
+
520
+ return df_comp .sort_values (by = ic )
521
+
522
+
523
+ def _ic_matrix (ics ):
524
+ """Store the previously computed pointwise predictive accuracy values (ics)
525
+ in a 2D matrix array.
526
+ """
527
+ N = len (ics [0 ][1 ][3 ])
528
+ K = len (ics )
529
+
530
+ ic_i = np .zeros ((N , K ))
531
+ for i in range (K ):
532
+ ic_i [:, i ] = ics [i ][1 ][3 ]
465
533
466
- return df_comp . sort_values ( by = ic )
534
+ return N , K , ic_i
467
535
468
536
469
537
def make_indices (dimensions ):
0 commit comments