Skip to content

Commit 72d9e6d

Browse files
committed
Changed arguments for compare to accept dict
1 parent 1cdd163 commit 72d9e6d

File tree

1 file changed

+14
-11
lines changed

1 file changed

+14
-11
lines changed

pymc3/stats.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,7 @@ def bpic(trace, model=None):
482482
return 3 * mean_deviance - 2 * deviance_at_mean
483483

484484

485-
def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
485+
def compare(model_dict, ic='WAIC', method='stacking', b_samples=1000,
486486
alpha=1, seed=None, round_to=2):
487487
R"""Compare models based on the widely available information criterion (WAIC)
488488
or leave-one-out (LOO) cross-validation.
@@ -491,9 +491,7 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
491491
492492
Parameters
493493
----------
494-
traces : list of PyMC3 traces
495-
models : list of PyMC3 models
496-
in the same order as traces.
494+
model_dict : dictionary of PyMC3 traces indexed by corresponding model
497495
ic : string
498496
Information Criterion (WAIC or LOO) used to compare models.
499497
Default WAIC.
@@ -546,31 +544,36 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
546544
warning : A value of 1 indicates that the computation of the IC may not be
547545
reliable see http://arxiv.org/abs/1507.04544 for details.
548546
"""
547+
548+
names = [model.name for model in model_dict if model.name]
549+
if not names:
550+
names = np.arange(len(model_dict))
551+
549552
if ic == 'WAIC':
550553
ic_func = waic
551-
df_comp = pd.DataFrame(index=np.arange(len(models)),
554+
df_comp = pd.DataFrame(index=names,
552555
columns=['WAIC', 'pWAIC', 'dWAIC', 'weight',
553556
'SE', 'dSE', 'warning'])
554557

555558
elif ic == 'LOO':
556559
ic_func = loo
557-
df_comp = pd.DataFrame(index=np.arange(len(models)),
560+
df_comp = pd.DataFrame(index=names,
558561
columns=['LOO', 'pLOO', 'dLOO', 'weight',
559562
'SE', 'dSE', 'warning'])
560563

561564
else:
562565
raise NotImplementedError(
563566
'The information criterion {} is not supported.'.format(ic))
564567

565-
if len(set([len(m.observed_RVs) for m in models])) != 1:
568+
if len(set([len(m.observed_RVs) for m in model_dict])) != 1:
566569
raise ValueError(
567570
'The number of observed RVs should be the same across all models')
568571

569572
if method not in ['stacking', 'BB-pseudo-BMA', 'pseudo-BMA']:
570573
raise ValueError('The method {}, to compute weights,'
571574
'is not supported.'.format(method))
572575

573-
warns = np.zeros(len(models))
576+
warns = np.zeros(len(model_dict))
574577

575578
c = 0
576579
def add_warns(*args):
@@ -581,8 +584,8 @@ def add_warns(*args):
581584
warnings.filterwarnings('always')
582585

583586
ics = []
584-
for c, (t, m) in enumerate(zip(traces, models)):
585-
ics.append((c, ic_func(t, m, pointwise=True)))
587+
for n, (m, t) in zip(names, model_dict.items()):
588+
ics.append((n, ic_func(t, m, pointwise=True)))
586589

587590
ics.sort(key=lambda x: x[1][0])
588591

@@ -663,7 +666,7 @@ def gradient(w):
663666
round(weight, round_to),
664667
round(se, round_to),
665668
round(d_se, round_to),
666-
warns[idx])
669+
warns[names.index(idx)])
667670

668671
return df_comp.sort_values(by=ic)
669672

0 commit comments

Comments
 (0)