@@ -482,7 +482,7 @@ def bpic(trace, model=None):
482
482
return 3 * mean_deviance - 2 * deviance_at_mean
483
483
484
484
485
- def compare (traces , models , ic = 'WAIC' , method = 'stacking' , b_samples = 1000 ,
485
+ def compare (model_dict , ic = 'WAIC' , method = 'stacking' , b_samples = 1000 ,
486
486
alpha = 1 , seed = None , round_to = 2 ):
487
487
R"""Compare models based on the widely available information criterion (WAIC)
488
488
or leave-one-out (LOO) cross-validation.
@@ -491,9 +491,7 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
491
491
492
492
Parameters
493
493
----------
494
- traces : list of PyMC3 traces
495
- models : list of PyMC3 models
496
- in the same order as traces.
494
+ model_dict : dictionary of PyMC3 traces indexed by corresponding model
497
495
ic : string
498
496
Information Criterion (WAIC or LOO) used to compare models.
499
497
Default WAIC.
@@ -546,31 +544,36 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
546
544
warning : A value of 1 indicates that the computation of the IC may not be
547
545
reliable see http://arxiv.org/abs/1507.04544 for details.
548
546
"""
547
+
548
+ names = [model .name for model in model_dict if model .name ]
549
+ if not names :
550
+ names = np .arange (len (model_dict ))
551
+
549
552
if ic == 'WAIC' :
550
553
ic_func = waic
551
- df_comp = pd .DataFrame (index = np . arange ( len ( models )) ,
554
+ df_comp = pd .DataFrame (index = names ,
552
555
columns = ['WAIC' , 'pWAIC' , 'dWAIC' , 'weight' ,
553
556
'SE' , 'dSE' , 'warning' ])
554
557
555
558
elif ic == 'LOO' :
556
559
ic_func = loo
557
- df_comp = pd .DataFrame (index = np . arange ( len ( models )) ,
560
+ df_comp = pd .DataFrame (index = names ,
558
561
columns = ['LOO' , 'pLOO' , 'dLOO' , 'weight' ,
559
562
'SE' , 'dSE' , 'warning' ])
560
563
561
564
else :
562
565
raise NotImplementedError (
563
566
'The information criterion {} is not supported.' .format (ic ))
564
567
565
- if len (set ([len (m .observed_RVs ) for m in models ])) != 1 :
568
+ if len (set ([len (m .observed_RVs ) for m in model_dict ])) != 1 :
566
569
raise ValueError (
567
570
'The number of observed RVs should be the same across all models' )
568
571
569
572
if method not in ['stacking' , 'BB-pseudo-BMA' , 'pseudo-BMA' ]:
570
573
raise ValueError ('The method {}, to compute weights,'
571
574
'is not supported.' .format (method ))
572
575
573
- warns = np .zeros (len (models ))
576
+ warns = np .zeros (len (model_dict ))
574
577
575
578
c = 0
576
579
def add_warns (* args ):
@@ -581,8 +584,8 @@ def add_warns(*args):
581
584
warnings .filterwarnings ('always' )
582
585
583
586
ics = []
584
- for c , (t , m ) in enumerate ( zip (traces , models )):
585
- ics .append ((c , ic_func (t , m , pointwise = True )))
587
+ for n , (m , t ) in zip (names , model_dict . items ( )):
588
+ ics .append ((n , ic_func (t , m , pointwise = True )))
586
589
587
590
ics .sort (key = lambda x : x [1 ][0 ])
588
591
@@ -663,7 +666,7 @@ def gradient(w):
663
666
round (weight , round_to ),
664
667
round (se , round_to ),
665
668
round (d_se , round_to ),
666
- warns [idx ])
669
+ warns [names . index ( idx ) ])
667
670
668
671
return df_comp .sort_values (by = ic )
669
672
0 commit comments