Skip to content

Commit 1043f21

Browse files
Chris Fonnesbecktwiecki
authored andcommitted
Generalized multinomial to accept multiple observations (#1390)
* Generalized multinomial to accept multiple observations * Added test for multinomial with multiple observations * Added shape to Multinomial call in test_multinomial_vec * Fix for test failur in multinomial * Fix for test failure in multinomial * Added correct shape to multinomial_vec test * Dialed down assert_equal precision in mutlinomial test
1 parent e9ee56a commit 1043f21

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

pymc3/distributions/multivariate.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,20 @@ def random(self, point=None, size=None):
257257
def logp(self, x):
258258
n = self.n
259259
p = self.p
260-
# only defined for sum(p) == 1
260+
261+
if x.ndim==2:
262+
x_sum = x.sum(axis=0)
263+
n_sum = n * x.shape[0]
264+
else:
265+
x_sum = x
266+
n_sum = n
267+
261268
return bound(
262-
factln(n) + tt.sum(x * tt.log(p) - factln(x)),
263-
tt.all(x >= 0), tt.all(x <= n), tt.eq(tt.sum(x), n),
269+
factln(n_sum) + tt.sum(x_sum * tt.log(p) - factln(x_sum)),
270+
tt.all(x >= 0),
271+
tt.all(x <= n),
272+
tt.eq(tt.sum(x_sum), n_sum),
273+
tt.isclose(p.sum(), 1),
264274
n >= 0)
265275

266276

pymc3/tests/test_distributions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,18 @@ def check_multinomial(self, n):
560560
self.pymc3_matches_scipy(Multinomial, Vector(Nat, n), {'p': Simplex(n), 'n': Nat},
561561
multinomial_logpdf)
562562

563+
def test_multinomial_vec(self):
564+
vals = np.array([[2,4,4], [3,3,4]])
565+
p = np.array([0.2, 0.3, 0.5])
566+
n = 10
567+
with Model() as model:
568+
Multinomial('m', n=10, p=p, shape=vals.shape)
569+
pt = {'m': vals}
570+
with Model() as model_sum:
571+
Multinomial('m_sum', n=2*n, p=p, shape=len(p))
572+
pt_sum = {'m_sum': vals.sum(0)}
573+
assert_almost_equal(model.fastlogp(pt), model_sum.fastlogp(pt_sum), decimal=4)
574+
563575
def test_categorical(self):
564576
for n in [2, 3, 4]:
565577
yield self.check_categorical, n

0 commit comments

Comments
 (0)