Skip to content

Commit 22a4c6d

Browse files
committed
Fix shape issues in model comparison
1 parent 3d368ef commit 22a4c6d

File tree

4 files changed

+64
-6
lines changed

4 files changed

+64
-6
lines changed

pymc3/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -915,7 +915,7 @@ def as_tensor(data, name, model, distribution):
915915

916916
if hasattr(data, 'mask'):
917917
from .distributions import NoDistribution
918-
testval = distribution.default()
918+
testval = np.broadcast_to(distribution.default(), data.shape)[data.mask]
919919
fakedist = NoDistribution.dist(shape=data.mask.sum(), dtype=dtype,
920920
testval=testval, parent_dist=distribution)
921921
missing_values = FreeRV(name=name + '_missing', distribution=fakedist,

pymc3/stats.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from collections import namedtuple
99
from .model import modelcontext
1010
from .util import get_default_varnames
11+
from pymc3.theanof import floatX
1112

1213
from scipy.misc import logsumexp
1314
from scipy.stats.distributions import pareto
@@ -123,16 +124,35 @@ def dic(trace, model=None):
123124
return 2 * mean_deviance - deviance_at_mean
124125

125126

126-
def log_post_trace(trace, model):
127+
def _log_post_trace(trace, model):
    """Calculate the elementwise log-likelihood of the observations for a trace.

    Parameters
    ----------
    trace : result of MCMC run
        Iterable of sample points; every point is handed to the
        ``logp_elemwise`` method of each observed variable.
    model : PyMC Model
        Model whose ``observed_RVs`` are evaluated. It is required: this
        helper does not fall back to the model on the context stack, so the
        caller has to resolve the model first.

    Returns
    -------
    logp : array of shape (n_samples, n_observations)
        The contribution of the observations to the logp of the whole model.
        Entries that are masked in an observed variable's data are left out.
        If the model has no observed variables the second axis has length 0.
    """
    def logp_vals_point(pt):
        # Flattened log-probabilities of all observed variables at one point.
        logp_vals = []
        for var in model.observed_RVs:
            logp = var.logp_elemwise(pt)
            if var.missing_values:
                # Keep only the entries that are not masked in the data.
                logp = logp[~var.observations.mask]
            logp_vals.append(logp.ravel())

        if not logp_vals:
            # No observed variables: np.concatenate rejects an empty list,
            # so return an explicit empty row instead.
            return floatX(np.array([], dtype='d'))

        return np.concatenate(logp_vals)

    # np.stack requires a real sequence. Passing a generator was deprecated
    # in NumPy 1.16 and is rejected by later releases, so build a list.
    return np.stack([logp_vals_point(pt) for pt in trace])
136156

137157

138158
def waic(trace, model=None, pointwise=False):
@@ -160,7 +180,9 @@ def waic(trace, model=None, pointwise=False):
160180
"""
161181
model = modelcontext(model)
162182

163-
log_py = log_post_trace(trace, model)
183+
log_py = _log_post_trace(trace, model)
184+
if log_py.size == 0:
185+
raise ValueError('The model does not contain observed values.')
164186

165187
lppd_i = logsumexp(log_py, axis=0, b=1.0 / log_py.shape[0])
166188

@@ -210,7 +232,9 @@ def loo(trace, model=None, pointwise=False):
210232
"""
211233
model = modelcontext(model)
212234

213-
log_py = log_post_trace(trace, model)
235+
log_py = _log_post_trace(trace, model)
236+
if log_py.size == 0:
237+
raise ValueError('The model does not contain observed values.')
214238

215239
# Importance ratios
216240
r = np.exp(-log_py)

pymc3/tests/test_model_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def test_as_tensor(self):
9797
# Create a fake model and fake distribution to be used for the test
9898
fake_model = pm.Model()
9999
with fake_model:
100-
fake_distribution = pm.Normal('fake_dist', mu=0, sd=1)
100+
fake_distribution = pm.Normal.dist(mu=0, sd=1)
101101
# Create the testval attribute simply for the sake of model testing
102102
fake_distribution.testval = None
103103

pymc3/tests/test_stats.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,45 @@
88
from ..backends import ndarray
99
from ..stats import df_summary, autocorr, hpd, mc_error, quantiles, make_indices
1010
from ..theanof import floatX_array
11+
import pymc3.stats as pmstats
1112
from numpy.random import random, normal
1213
from numpy.testing import assert_equal, assert_almost_equal, assert_array_almost_equal
1314
from scipy import stats as st
1415

1516

17+
def test_log_post_trace():
    """Check the shape and values of ``pmstats._log_post_trace`` output.

    Three models are exercised in turn: one without observed variables, one
    with a fully observed variable, and one that mixes fully observed,
    partly missing and entirely missing data.
    """
    # Case 1: no observed variables, so there are no observation columns.
    with pm.Model() as model_without_obs:
        pm.Normal('y')
        trace = pm.sample()

    logp = pmstats._log_post_trace(trace, model_without_obs)
    assert logp.shape == (len(trace), 0)

    # Case 2: a single 2x3 observed array gives six columns.
    with pm.Model() as model_full_obs:
        pm.Normal('a')
        pm.Normal('y', observed=np.zeros((2, 3)))
        trace = pm.sample()

    logp = pmstats._log_post_trace(trace, model_full_obs)
    assert logp.shape == (len(trace), 6)
    npt.assert_allclose(logp, -0.5 * np.log(2 * np.pi), atol=1e-7)

    # Case 3: add a 3x4 frame with one NaN and a copy that is all NaN.
    with pm.Model() as model_with_missing:
        pm.Normal('a')
        pm.Normal('y', observed=np.zeros((2, 3)))
        data = pd.DataFrame(np.zeros((3, 4)))
        data.values[1, 1] = np.nan
        pm.Normal('y2', observed=data)
        data = data.copy()
        data.values[:] = np.nan
        pm.Normal('y3', observed=data)
        trace = pm.sample()

    logp = pmstats._log_post_trace(trace, model_with_missing)
    # Expected columns: 6 from 'y', 11 from 'y2', none from 'y3'.
    assert logp.shape == (len(trace), 17)
    npt.assert_allclose(logp, -0.5 * np.log(2 * np.pi), atol=1e-7)
48+
49+
1650
class TestStats(SeededTest):
1751
@classmethod
1852
def setup_class(cls):

0 commit comments

Comments
 (0)