Skip to content

Commit 4dbd754

Browse files
ColCarrollJunpeng Lao
authored andcommitted
Fix sample_prior_predictive edge case (#3048)
1 parent 7c0571a commit 4dbd754

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

pymc3/distributions/distribution.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,14 @@ def generate_samples(generator, *args, **kwargs):
470470

471471
if broadcast_shape is None:
472472
inputs = args + tuple(kwargs.values())
473-
broadcast_shape = np.broadcast(*inputs).shape # size of generator(size=1)
473+
try:
474+
broadcast_shape = np.broadcast(*inputs).shape # size of generator(size=1)
475+
except ValueError: # sometimes happens if args have shape (500,) and (500, 4)
476+
max_dims = max(j.ndim for j in args + tuple(kwargs.values()))
477+
args = tuple([j.reshape(j.shape + (1,) * (max_dims - j.ndim)) for j in args])
478+
kwargs = {k: v.reshape(v.shape + (1,) * (max_dims - v.ndim)) for k, v in kwargs.items()}
479+
inputs = args + tuple(kwargs.values())
480+
broadcast_shape = np.broadcast(*inputs).shape # size of generator(size=1)
474481

475482
dist_shape = to_tuple(dist_shape)
476483
broadcast_shape = to_tuple(broadcast_shape)

pymc3/tests/test_sampling.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,11 @@ def test_density_dist(self):
432432
prior = pm.sample_prior_predictive()
433433

434434
npt.assert_almost_equal(prior['a'].mean(), 0, decimal=1)
435+
436+
def test_shape_edgecase(self):
437+
with pm.Model():
438+
mu = pm.Normal('mu', shape=5)
439+
sd = pm.Uniform('sd', lower=2, upper=3)
440+
x = pm.Normal('x', mu=mu, sd=sd, shape=5)
441+
prior = pm.sample_prior_predictive(10)
442+
assert prior['mu'].shape == (10, 5)

0 commit comments

Comments
 (0)