Skip to content

Commit 6dd0e81

Browse files
authored
Merge pull request #2764 from pymc-devs/compare_tweak
Changed arguments for compare to accept dict
2 parents 999661c + 2743564 commit 6dd0e81

File tree

5 files changed

+195
-204
lines changed

5 files changed

+195
-204
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
- Densityplot: add support for discrete variables
1919
- Fix the Binomial likelihood in `.glm.families.Binomial`, with the flexibility of specifying the `n`.
2020
- Add `offset` kwarg to `.glm`.
21+
- Changed the `compare` function to accept a dictionary of model-trace pairs instead of two separate lists of models and traces.
2122

2223
### Fixes
2324

docs/source/notebooks/model_averaging.ipynb

Lines changed: 54 additions & 85 deletions
Large diffs are not rendered by default.

docs/source/notebooks/model_comparison.ipynb

Lines changed: 114 additions & 102 deletions
Large diffs are not rendered by default.

pymc3/stats.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def _gpinv(p, k, sigma):
456456
return x
457457

458458

459-
def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
459+
def compare(model_dict, ic='WAIC', method='stacking', b_samples=1000,
460460
alpha=1, seed=None, round_to=2):
461461
R"""Compare models based on the widely applicable information criterion (WAIC)
462462
or leave-one-out (LOO) cross-validation.
@@ -465,9 +465,7 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
465465
466466
Parameters
467467
----------
468-
traces : list of PyMC3 traces
469-
models : list of PyMC3 models
470-
in the same order as traces.
468+
model_dict : dictionary mapping each PyMC3 model (key) to its corresponding trace (value)
471469
ic : string
472470
Information Criterion (WAIC or LOO) used to compare models.
473471
Default WAIC.
@@ -520,23 +518,28 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
520518
warning : A value of 1 indicates that the computation of the IC may not be
521519
reliable. For details, see the related warning messages in pm.waic and pm.loo
522520
"""
521+
522+
names = [model.name for model in model_dict if model.name]
523+
if not names:
524+
names = np.arange(len(model_dict))
525+
523526
if ic == 'WAIC':
524527
ic_func = waic
525-
df_comp = pd.DataFrame(index=np.arange(len(models)),
528+
df_comp = pd.DataFrame(index=names,
526529
columns=['WAIC', 'pWAIC', 'dWAIC', 'weight',
527530
'SE', 'dSE', 'var_warn'])
528531

529532
elif ic == 'LOO':
530533
ic_func = loo
531-
df_comp = pd.DataFrame(index=np.arange(len(models)),
534+
df_comp = pd.DataFrame(index=names,
532535
columns=['LOO', 'pLOO', 'dLOO', 'weight',
533536
'SE', 'dSE', 'shape_warn'])
534537

535538
else:
536539
raise NotImplementedError(
537540
'The information criterion {} is not supported.'.format(ic))
538541

539-
if len(set([len(m.observed_RVs) for m in models])) != 1:
542+
if len(set([len(m.observed_RVs) for m in model_dict])) != 1:
540543
raise ValueError(
541544
'The number of observed RVs should be the same across all models')
542545

@@ -545,8 +548,8 @@ def compare(traces, models, ic='WAIC', method='stacking', b_samples=1000,
545548
'is not supported.'.format(method))
546549

547550
ics = []
548-
for c, (t, m) in enumerate(zip(traces, models)):
549-
ics.append((c, ic_func(t, m, pointwise=True)))
551+
for n, (m, t) in zip(names, model_dict.items()):
552+
ics.append((n, ic_func(t, m, pointwise=True)))
550553

551554
ics.sort(key=lambda x: x[1][0])
552555

pymc3/tests/test_stats.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from numpy.random import random, normal
1414
from numpy.testing import assert_equal, assert_almost_equal, assert_array_almost_equal
1515
from scipy import stats as st
16+
import copy
1617

1718

1819
def test_log_post_trace():
@@ -67,12 +68,14 @@ def test_compare():
6768
x = pm.StudentT('x', nu=1, mu=mu, lam=1, observed=x_obs)
6869
trace2 = pm.sample(1000)
6970

70-
traces = [trace0] * 2
71-
models = [model0] * 2
71+
traces = [trace0, copy.copy(trace0)]
72+
models = [model0, copy.copy(model0)]
7273

73-
w_st = pm.compare(traces, models, method='stacking')['weight']
74-
w_bb_bma = pm.compare(traces, models, method='BB-pseudo-BMA')['weight']
75-
w_bma = pm.compare(traces, models, method='pseudo-BMA')['weight']
74+
model_dict = dict(zip(models, traces))
75+
76+
w_st = pm.compare(model_dict, method='stacking')['weight']
77+
w_bb_bma = pm.compare(model_dict, method='BB-pseudo-BMA')['weight']
78+
w_bma = pm.compare(model_dict, method='pseudo-BMA')['weight']
7679

7780
assert_almost_equal(w_st[0], w_st[1])
7881
assert_almost_equal(w_bb_bma[0], w_bb_bma[1])
@@ -84,9 +87,12 @@ def test_compare():
8487

8588
traces = [trace0, trace1, trace2]
8689
models = [model0, model1, model2]
87-
w_st = pm.compare(traces, models, method='stacking')['weight']
88-
w_bb_bma = pm.compare(traces, models, method='BB-pseudo-BMA')['weight']
89-
w_bma = pm.compare(traces, models, method='pseudo-BMA')['weight']
90+
91+
model_dict = dict(zip(models, traces))
92+
93+
w_st = pm.compare(model_dict, method='stacking')['weight']
94+
w_bb_bma = pm.compare(model_dict, method='BB-pseudo-BMA')['weight']
95+
w_bma = pm.compare(model_dict, method='pseudo-BMA')['weight']
9096

9197
assert(w_st[0] > w_st[1] > w_st[2])
9298
assert(w_bb_bma[0] > w_bb_bma[1] > w_bb_bma[2])

0 commit comments

Comments
 (0)