Skip to content

Commit d9e57e8 — Merge pull request #1596 from pymc-devs/fix_bounds2: "Make bound broadcast again."
2 parents: 1a1a36e + 6c7f127

File tree

5 files changed: +132 additions, −29 deletions

pymc3/distributions/continuous.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def logp(self, value):
147147
lower = self.lower
148148
upper = self.upper
149149
return bound(-tt.log(upper - lower),
150-
value >= lower, value <= upper)
150+
value >= lower, value <= upper)
151151

152152

153153
class Flat(Continuous):

pymc3/distributions/discrete.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ class ZeroInflatedPoisson(Discrete):
454454
Often used to model the number of events occurring in a fixed period
455455
of time when the times at which events occur are independent.
456456
457-
.. math::
457+
.. math::
458458
459459
f(x \mid \theta, \psi) = \left\{ \begin{array}{l}
460460
(1-\psi) + \psi e^{-\theta}, \text{if } x = 0 \\
@@ -503,7 +503,7 @@ class ZeroInflatedNegativeBinomial(Discrete):
503503
504504
The Zero-inflated version of the Negative Binomial (NB).
505505
The NB distribution describes a Poisson random variable
506-
whose rate parameter is gamma distributed.
506+
whose rate parameter is gamma distributed.
507507
508508
.. math::
509509

pymc3/distributions/dist_math.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,43 @@
1010

1111
from .special import gammaln, multigammaln
1212

13-
14-
def bound(logp, *conditions):
13+
def bound(logp, *conditions, **kwargs):
1514
"""
16-
Bounds a log probability density with several conditions
15+
Bounds a log probability density with several conditions.
1716
1817
Parameters
1918
----------
2019
logp : float
2120
*conditions : booleans
21+
broadcast_conditions : bool (optional, default=True)
22+
If True, broadcasts logp to match the largest shape of the conditions.
23+
This is used e.g. in DiscreteUniform where logp is a scalar constant and the shape
24+
is specified via the conditions.
25+
If False, will return the same shape as logp.
26+
This is used e.g. in Multinomial where broadcasting can lead to differences in the logp.
2227
2328
Returns
2429
-------
25-
logp if all conditions are true
26-
-inf if some are false
30+
logp with elements set to -inf where any condition is False
2731
"""
32+
broadcast_conditions = kwargs.get('broadcast_conditions', True)
33+
34+
if broadcast_conditions:
35+
alltrue = alltrue_elemwise
36+
else:
37+
alltrue = alltrue_scalar
38+
2839
return tt.switch(alltrue(conditions), logp, -np.inf)
2940

3041

31-
def alltrue(vals):
42+
def alltrue_elemwise(vals):
43+
ret = 1
44+
for c in vals:
45+
ret = ret * (1 * c)
46+
return ret
47+
48+
49+
def alltrue_scalar(vals):
3250
return tt.all([tt.all(1 * val) for val in vals])
3351

3452

pymc3/distributions/multivariate.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ def logp(self, value):
245245
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
246246
+ gammaln(tt.sum(a, axis=-1)),
247247
tt.all(value >= 0), tt.all(value <= 1),
248-
k > 1, tt.all(a > 0))
248+
k > 1, tt.all(a > 0),
249+
broadcast_conditions=False)
249250

250251

251252
class Multinomial(Discrete):
@@ -323,7 +324,9 @@ def logp(self, x):
323324
tt.all(tt.eq(tt.sum(x, axis=-1, keepdims=True), n)),
324325
tt.all(p <= 1),
325326
tt.all(tt.eq(tt.sum(p, axis=-1), 1)),
326-
tt.all(tt.ge(n, 0)))
327+
tt.all(tt.ge(n, 0)),
328+
broadcast_conditions=False
329+
)
327330

328331

329332
def posdef(AA):
@@ -443,7 +446,9 @@ def logp(self, X):
443446
- 2 * multigammaln(n / 2., p)) / 2,
444447
matrix_pos_def(X),
445448
tt.eq(X, X.T),
446-
n > (p - 1))
449+
n > (p - 1),
450+
broadcast_conditions=False
451+
)
447452

448453

449454
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testval=None):
@@ -605,4 +610,6 @@ def logp(self, x):
605610
return bound(result,
606611
tt.all(X <= 1), tt.all(X >= -1),
607612
matrix_pos_def(X),
608-
n > 0)
613+
n > 0,
614+
broadcast_conditions=False
615+
)

pymc3/tests/test_dist_math.py

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,114 @@
11
import numpy as np
22
import theano.tensor as tt
3+
import pymc3 as pm
34

4-
from ..distributions.dist_math import alltrue
5+
from ..distributions import Discrete
6+
from ..distributions.dist_math import bound, factln, alltrue_elemwise, alltrue_scalar
57

68

7-
def test_alltrue():
8-
assert alltrue([]).eval()
9-
assert alltrue([True]).eval()
10-
assert alltrue([tt.ones(10)]).eval()
11-
assert alltrue([tt.ones(10),
9+
def test_bound():
10+
logp = tt.ones((10, 10))
11+
cond = tt.ones((10, 10))
12+
assert np.all(bound(logp, cond).eval() == logp.eval())
13+
14+
logp = tt.ones((10, 10))
15+
cond = tt.zeros((10, 10))
16+
assert np.all(bound(logp, cond).eval() == (-np.inf * logp).eval())
17+
18+
logp = tt.ones((10, 10))
19+
cond = True
20+
assert np.all(bound(logp, cond).eval() == logp.eval())
21+
22+
logp = tt.ones(3)
23+
cond = np.array([1, 0, 1])
24+
assert not np.all(bound(logp, cond).eval() == 1)
25+
assert np.prod(bound(logp, cond).eval()) == -np.inf
26+
27+
logp = tt.ones((2, 3))
28+
cond = np.array([[1, 1, 1], [1, 0, 1]])
29+
assert not np.all(bound(logp, cond).eval() == 1)
30+
assert np.prod(bound(logp, cond).eval()) == -np.inf
31+
32+
def test_alltrue_scalar():
33+
assert alltrue_scalar([]).eval()
34+
assert alltrue_scalar([True]).eval()
35+
assert alltrue_scalar([tt.ones(10)]).eval()
36+
assert alltrue_scalar([tt.ones(10),
1237
5 * tt.ones(101)]).eval()
13-
assert alltrue([np.ones(10),
38+
assert alltrue_scalar([np.ones(10),
1439
5 * tt.ones(101)]).eval()
15-
assert alltrue([np.ones(10),
40+
assert alltrue_scalar([np.ones(10),
1641
True,
1742
5 * tt.ones(101)]).eval()
18-
assert alltrue([np.array([1, 2, 3]),
43+
assert alltrue_scalar([np.array([1, 2, 3]),
1944
True,
2045
5 * tt.ones(101)]).eval()
2146

22-
assert not alltrue([False]).eval()
23-
assert not alltrue([tt.zeros(10)]).eval()
24-
assert not alltrue([True,
47+
assert not alltrue_scalar([False]).eval()
48+
assert not alltrue_scalar([tt.zeros(10)]).eval()
49+
assert not alltrue_scalar([True,
2550
False]).eval()
26-
assert not alltrue([np.array([0, -1]),
51+
assert not alltrue_scalar([np.array([0, -1]),
2752
tt.ones(60)]).eval()
28-
assert not alltrue([np.ones(10),
53+
assert not alltrue_scalar([np.ones(10),
2954
False,
3055
5 * tt.ones(101)]).eval()
3156

32-
3357
def test_alltrue_shape():
3458
vals = [True, tt.ones(10), tt.zeros(5)]
3559

36-
assert alltrue(vals).eval().shape == ()
60+
assert alltrue_scalar(vals).eval().shape == ()
61+
62+
class MultinomialA(Discrete):
63+
def __init__(self, n, p, *args, **kwargs):
64+
super(MultinomialA, self).__init__(*args, **kwargs)
65+
66+
self.n = n
67+
self.p = p
68+
69+
def logp(self, value):
70+
n = self.n
71+
p = self.p
72+
73+
return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
74+
value >= 0,
75+
0 <= p, p <= 1,
76+
tt.isclose(p.sum(), 1),
77+
broadcast_conditions=False
78+
)
79+
80+
81+
class MultinomialB(Discrete):
82+
def __init__(self, n, p, *args, **kwargs):
83+
super(MultinomialB, self).__init__(*args, **kwargs)
84+
85+
self.n = n
86+
self.p = p
87+
88+
def logp(self, value):
89+
n = self.n
90+
p = self.p
91+
92+
return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
93+
tt.all(value >= 0),
94+
tt.all(0 <= p), tt.all(p <= 1),
95+
tt.isclose(p.sum(), 1),
96+
broadcast_conditions=False
97+
)
98+
99+
100+
def test_multinomial_bound():
101+
102+
x = np.array([1, 5])
103+
n = x.sum()
104+
105+
with pm.Model() as modelA:
106+
p_a = pm.Dirichlet('p', np.ones(2))
107+
x_obs_a = MultinomialA('x', n, p_a, observed=x)
108+
109+
with pm.Model() as modelB:
110+
p_b = pm.Dirichlet('p', np.ones(2))
111+
x_obs_b = MultinomialB('x', n, p_b, observed=x)
112+
113+
assert np.isclose(modelA.logp({'p_stickbreaking_': [0]}),
114+
modelB.logp({'p_stickbreaking_': [0]}))

Comments (0)