Skip to content

Commit 601b0f3

Browse files
Support 1D n and 2D p in Multinomial
1 parent 954b01d commit 601b0f3

File tree

2 files changed

+64
-18
lines changed

2 files changed

+64
-18
lines changed

pymc3/distributions/multivariate.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -269,19 +269,34 @@ class Multinomial(Discrete):
269269
270270
Parameters
271271
----------
272-
n : int
272+
n : int or array
273273
Number of trials (n > 0).
274-
p : array
274+
p : one- or two-dimensional array
275275
Probability of each one of the different outcomes. Elements must
276-
be non-negative and sum to 1. They will be automatically rescaled otherwise.
276+
be non-negative and sum to 1 along the last axis. They will be automatically
277+
rescaled otherwise.
277278
"""
278279

279280
def __init__(self, n, p, *args, **kwargs):
280281
super(Multinomial, self).__init__(*args, **kwargs)
281-
self.n = n
282-
self.p = p / tt.sum(p)
283-
self.mean = n * p
284-
self.mode = tt.cast(tt.round(n * p), 'int32')
282+
283+
p = p / tt.sum(p, axis=-1, keepdims=True)
284+
285+
if len(self.shape) == 2:
286+
try:
287+
assert n.shape == (self.shape[0],)
288+
except AttributeError:
289+
# this occurs when n is a scalar Python int or float
290+
n *= tt.ones(self.shape[0])
291+
292+
self.n = tt.shape_padright(n)
293+
self.p = p if p.ndim == 2 else tt.shape_padleft(p)
294+
else:
295+
self.n = n
296+
self.p = p
297+
298+
self.mean = self.n * self.p
299+
self.mode = tt.cast(tt.round(self.mean), 'int32')
285300

286301
def _random(self, n, p, size=None):
287302
if size == p.shape:
@@ -299,20 +314,13 @@ def logp(self, x):
299314
n = self.n
300315
p = self.p
301316

302-
if x.ndim==2:
303-
x_sum = x.sum(axis=0)
304-
k = x.shape[0]
305-
else:
306-
x_sum = x
307-
k = 1
308317
return bound(
309-
k * factln(n) - tt.sum(factln(x)) + tt.sum(x_sum * tt.log(p)),
318+
tt.sum(factln(n)) - tt.sum(factln(x)) + tt.sum(x * tt.log(p)),
310319
tt.all(x >= 0),
311-
tt.all(x <= n),
312-
tt.eq(tt.sum(x_sum), k * n),
320+
tt.all(tt.eq(tt.sum(x, axis=-1, keepdims=True), n)),
313321
tt.all(p <= 1),
314-
tt.eq(p.sum(), 1),
315-
n >= 0)
322+
tt.all(tt.eq(tt.sum(p, axis=-1), 1)),
323+
tt.all(tt.ge(n, 0)))
316324

317325

318326
def posdef(AA):

pymc3/tests/test_distributions.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,44 @@ def test_multinomial_vec(self):
579579
model_many.fastlogp({'m': vals}),
580580
decimal=4)
581581

582+
def test_multinomial_vec_2d_n(self):
583+
vals = np.array([[2,4,4], [4,3,4]])
584+
p = np.array([0.2, 0.3, 0.5])
585+
ns = np.array([10, 11])
586+
587+
with Model() as model:
588+
Multinomial('m', n=ns, p=p, shape=vals.shape)
589+
590+
assert_almost_equal(sum([multinomial_logpdf(val, n, p) for val, n in zip(vals, ns)]),
591+
model.fastlogp({'m': vals}),
592+
decimal=4)
593+
594+
def test_multinomial_vec_2d_n_2d_p(self):
595+
vals = np.array([[2,4,4], [4,3,4]])
596+
ps = np.array([[0.2, 0.3, 0.5],
597+
[0.9, 0.09, 0.01]])
598+
ns = np.array([10, 11])
599+
600+
with Model() as model:
601+
Multinomial('m', n=ns, p=ps, shape=vals.shape)
602+
603+
assert_almost_equal(sum([multinomial_logpdf(val, n, p) for val, n, p in zip(vals, ns, ps)]),
604+
model.fastlogp({'m': vals}),
605+
decimal=4)
606+
607+
def test_multinomial_vec_2d_p(self):
608+
vals = np.array([[2,4,4], [3,3,4]])
609+
ps = np.array([[0.2, 0.3, 0.5],
610+
[0.3, 0.3, 0.4]])
611+
n = 10
612+
613+
with Model() as model:
614+
Multinomial('m', n=n, p=ps, shape=vals.shape)
615+
616+
assert_almost_equal(sum([multinomial_logpdf(val, n, p) for val, p in zip(vals, ps)]),
617+
model.fastlogp({'m': vals}),
618+
decimal=4)
619+
582620
def test_categorical(self):
583621
for n in [2, 3, 4]:
584622
yield self.check_categorical, n

0 commit comments

Comments
 (0)