Skip to content

Commit 954b01d

Browse files
authored
Merge pull request #1478 from AustinRochford/bugfix-multinomial-vector
Bugfix multinomial vector
2 parents 540da5f + 1a21db4 commit 954b01d

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

pymc3/distributions/multivariate.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,15 +301,15 @@ def logp(self, x):
301301

302302
if x.ndim==2:
303303
x_sum = x.sum(axis=0)
304-
n_sum = n * x.shape[0]
304+
k = x.shape[0]
305305
else:
306306
x_sum = x
307-
n_sum = n
307+
k = 1
308308
return bound(
309-
factln(n_sum) + tt.sum(x_sum * tt.log(p) - factln(x_sum)),
309+
k * factln(n) - tt.sum(factln(x)) + tt.sum(x_sum * tt.log(p)),
310310
tt.all(x >= 0),
311311
tt.all(x <= n),
312-
tt.eq(tt.sum(x_sum), n_sum),
312+
tt.eq(tt.sum(x_sum), k * n),
313313
tt.all(p <= 1),
314314
tt.eq(p.sum(), 1),
315315
n >= 0)

pymc3/tests/test_distributions.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -568,13 +568,16 @@ def test_multinomial_vec(self):
568568
vals = np.array([[2,4,4], [3,3,4]])
569569
p = np.array([0.2, 0.3, 0.5])
570570
n = 10
571-
with Model() as model:
572-
Multinomial('m', n=10, p=p, shape=vals.shape)
573-
pt = {'m': vals}
574-
with Model() as model_sum:
575-
Multinomial('m_sum', n=2*n, p=p, shape=len(p))
576-
pt_sum = {'m_sum': vals.sum(0)}
577-
assert_almost_equal(model.fastlogp(pt), model_sum.fastlogp(pt_sum), decimal=4)
571+
572+
with Model() as model_single:
573+
Multinomial('m', n=n, p=p, shape=len(p))
574+
575+
with Model() as model_many:
576+
Multinomial('m', n=n, p=p, shape=vals.shape)
577+
578+
assert_almost_equal(sum([model_single.fastlogp({'m': val}) for val in vals]),
579+
model_many.fastlogp({'m': vals}),
580+
decimal=4)
578581

579582
def test_categorical(self):
580583
for n in [2, 3, 4]:

0 commit comments

Comments
 (0)