Skip to content

Commit f0003ed

Browse files
lucianopazjunpenglao
authored andcommitted
Rebase (#3297)
1 parent b3a3e0b commit f0003ed

File tree

3 files changed

+19
-11
lines changed

3 files changed

+19
-11
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
- Rewrote `Multinomial._random` method to better handle shape broadcasting (#3271)
3434
- Fixed `Rice` distribution, which inconsistently mixed two parametrizations (#3286).
3535
- `Rice` distribution now accepts multiple parameters and observations and is usable with NUTS (#3289).
36+
- `sample_posterior_predictive` no longer calls `draw_values` to initialize the shape of the ppc trace. This called could lead to `ValueError`'s when sampling the ppc from a model with `Flat` or `HalfFlat` prior distributions (Fix issue #3294).
3637

3738

3839
### Deprecations

pymc3/sampling.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,17 +1123,9 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
11231123
if progressbar:
11241124
indices = tqdm(indices, total=samples)
11251125

1126-
varnames = [var.name for var in vars]
1127-
1128-
# draw once to inspect the shape
1129-
var_values = list(zip(varnames,
1130-
draw_values(vars, point=model.test_point, size=size)))
11311126
ppc_trace = defaultdict(list)
1132-
for varname, value in var_values:
1133-
ppc_trace[varname] = np.zeros((samples,) + value.shape, value.dtype)
1134-
11351127
try:
1136-
for slc, idx in enumerate(indices):
1128+
for idx in indices:
11371129
if nchain > 1:
11381130
chain_idx, point_idx = np.divmod(idx, len_trace)
11391131
param = trace._straces[chain_idx % nchain].point(point_idx)
@@ -1142,7 +1134,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
11421134

11431135
values = draw_values(vars, point=param, size=size)
11441136
for k, v in zip(vars, values):
1145-
ppc_trace[k.name][slc] = v
1137+
ppc_trace[k.name].append(v)
11461138

11471139
except KeyboardInterrupt:
11481140
pass
@@ -1151,7 +1143,7 @@ def sample_posterior_predictive(trace, samples=None, model=None, vars=None, size
11511143
if progressbar:
11521144
indices.close()
11531145

1154-
return ppc_trace
1146+
return {k: np.asarray(v) for k, v in ppc_trace.items()}
11551147

11561148

11571149
def sample_ppc(*args, **kwargs):

pymc3/tests/test_sampling.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,21 @@ def test_sum_normal(self):
289289
_, pval = stats.kstest(ppc['b'], stats.norm(scale=scale).cdf)
290290
assert pval > 0.001
291291

292+
def test_model_not_drawable_prior(self):
293+
data = np.random.poisson(lam=10, size=200)
294+
model = pm.Model()
295+
with model:
296+
mu = pm.HalfFlat('sigma')
297+
pm.Poisson('foo', mu=mu, observed=data)
298+
trace = pm.sample(tune=1000)
299+
300+
with model:
301+
with pytest.raises(ValueError) as excinfo:
302+
pm.sample_prior_predictive(50)
303+
assert "Cannot sample" in str(excinfo.value)
304+
samples = pm.sample_posterior_predictive(trace, 50)
305+
assert samples['foo'].shape == (50, 200)
306+
292307

293308
class TestSamplePPCW(SeededTest):
294309
def test_sample_posterior_predictive_w(self):

0 commit comments

Comments
 (0)