Skip to content

Commit 59f68c3

Browse files
committed
Make all univariate distributions use bound_elemwise. Added tests.
1 parent 1ad526a commit 59f68c3

File tree

3 files changed

+106
-61
lines changed

3 files changed

+106
-61
lines changed

pymc3/distributions/continuous.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -243,7 +243,7 @@ def logp(self, value):
243243
sd = self.sd
244244
tau = self.tau
245245
mu = self.mu
246-
return bound((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,
246+
return bound_elemwise((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,
247247
sd > 0)
248248

249249

@@ -289,7 +289,7 @@ def random(self, point=None, size=None, repeat=None):
289289
def logp(self, value):
290290
tau = self.tau
291291
sd = self.sd
292-
return bound(-0.5 * tau * value**2 + 0.5 * tt.log(tau * 2. / np.pi),
292+
return bound_elemwise(-0.5 * tau * value**2 + 0.5 * tt.log(tau * 2. / np.pi),
293293
value >= 0,
294294
tau > 0, sd > 0)
295295

@@ -402,7 +402,7 @@ def logp(self, value):
402402
lam = self.lam
403403
alpha = self.alpha
404404
# value *must* be iid. Otherwise this is wrong.
405-
return bound(logpow(lam / (2. * np.pi), 0.5)
405+
return bound_elemwise(logpow(lam / (2. * np.pi), 0.5)
406406
- logpow(value - alpha, 1.5)
407407
- (0.5 * lam / (value - alpha)
408408
* ((value - alpha - mu) / mu)**2),
@@ -492,7 +492,7 @@ def logp(self, value):
492492
alpha = self.alpha
493493
beta = self.beta
494494

495-
return bound(logpow(value, alpha - 1) + logpow(1 - value, beta - 1)
495+
return bound_elemwise(logpow(value, alpha - 1) + logpow(1 - value, beta - 1)
496496
- betaln(alpha, beta),
497497
value >= 0, value <= 1,
498498
alpha > 0, beta > 0)
@@ -537,7 +537,7 @@ def random(self, point=None, size=None, repeat=None):
537537

538538
def logp(self, value):
539539
lam = self.lam
540-
return bound(tt.log(lam) - lam * value, value > 0, lam > 0)
540+
return bound_elemwise(tt.log(lam) - lam * value, value > 0, lam > 0)
541541

542542

543543
class Laplace(Continuous):
@@ -641,7 +641,7 @@ def random(self, point=None, size=None, repeat=None):
641641
def logp(self, value):
642642
mu = self.mu
643643
tau = self.tau
644-
return bound(-0.5 * tau * (tt.log(value) - mu)**2
644+
return bound_elemwise(-0.5 * tau * (tt.log(value) - mu)**2
645645
+ 0.5 * tt.log(tau / (2. * np.pi))
646646
- tt.log(value),
647647
tau > 0)
@@ -702,7 +702,7 @@ def logp(self, value):
702702
lam = self.lam
703703
sd = self.sd
704704

705-
return bound(gammaln((nu + 1.0) / 2.0)
705+
return bound_elemwise(gammaln((nu + 1.0) / 2.0)
706706
+ .5 * tt.log(lam / (nu * np.pi))
707707
- gammaln(nu / 2.0)
708708
- (nu + 1.0) / 2.0 * tt.log1p(lam * (value - mu)**2 / nu),
@@ -765,7 +765,7 @@ def random(self, point=None, size=None, repeat=None):
765765
def logp(self, value):
766766
alpha = self.alpha
767767
m = self.m
768-
return bound(tt.log(alpha) + logpow(m, alpha)
768+
return bound_elemwise(tt.log(alpha) + logpow(m, alpha)
769769
- logpow(value, alpha + 1),
770770
value >= m, alpha > 0, m > 0)
771771

@@ -817,7 +817,7 @@ def random(self, point=None, size=None, repeat=None):
817817
def logp(self, value):
818818
alpha = self.alpha
819819
beta = self.beta
820-
return bound(- tt.log(np.pi) - tt.log(beta)
820+
return bound_elemwise(- tt.log(np.pi) - tt.log(beta)
821821
- tt.log1p(((value - alpha) / beta)**2),
822822
beta > 0)
823823

@@ -863,7 +863,7 @@ def random(self, point=None, size=None, repeat=None):
863863

864864
def logp(self, value):
865865
beta = self.beta
866-
return bound(tt.log(2) - tt.log(np.pi) - tt.log(beta)
866+
return bound_elemwise(tt.log(2) - tt.log(np.pi) - tt.log(beta)
867867
- tt.log1p((value / beta)**2),
868868
value >= 0, beta > 0)
869869

@@ -943,7 +943,7 @@ def random(self, point=None, size=None, repeat=None):
943943
def logp(self, value):
944944
alpha = self.alpha
945945
beta = self.beta
946-
return bound(
946+
return bound_elemwise(
947947
-gammaln(alpha) + logpow(
948948
beta, alpha) - beta * value + logpow(value, alpha - 1),
949949

@@ -1007,7 +1007,7 @@ def random(self, point=None, size=None, repeat=None):
10071007
def logp(self, value):
10081008
alpha = self.alpha
10091009
beta = self.beta
1010-
return bound(logpow(beta, alpha) - gammaln(alpha) - beta / value
1010+
return bound_elemwise(logpow(beta, alpha) - gammaln(alpha) - beta / value
10111011
+ logpow(value, -alpha - 1),
10121012
value > 0, alpha > 0, beta > 0)
10131013

@@ -1088,7 +1088,7 @@ def _random(a, b, size=None):
10881088
def logp(self, value):
10891089
alpha = self.alpha
10901090
beta = self.beta
1091-
return bound(tt.log(alpha) - tt.log(beta)
1091+
return bound_elemwise(tt.log(alpha) - tt.log(beta)
10921092
+ (alpha - 1) * tt.log(value / beta)
10931093
- (value / beta)**alpha,
10941094
value >= 0, alpha > 0, beta > 0)
@@ -1131,12 +1131,12 @@ def __init__(self, distribution, lower, upper, transform='infer', *args, **kwarg
11311131
self.testval = 0.5 * (upper + lower)
11321132

11331133
if not np.isinf(lower) and np.isinf(upper):
1134-
self.transform = transforms.lowerbound(lower)
1134+
self.transform = transforms.lowerbound_elemwise(lower)
11351135
if default <= lower:
11361136
self.testval = lower + 1
11371137

11381138
if np.isinf(lower) and not np.isinf(upper):
1139-
self.transform = transforms.upperbound(upper)
1139+
self.transform = transforms.upperbound_elemwise(upper)
11401140
if default >= upper:
11411141
self.testval = upper - 1
11421142

@@ -1161,7 +1161,7 @@ def random(self, point=None, size=None, repeat=None):
11611161
size=size)
11621162

11631163
def logp(self, value):
1164-
return bound(self.dist.logp(value),
1164+
return bound_elemwise(self.dist.logp(value),
11651165
value >= self.lower, value <= self.upper)
11661166

11671167

@@ -1286,7 +1286,7 @@ def logp(self, value):
12861286
+ logpow(std_cdf((value - mu) / sigma - sigma / nu), 1.),
12871287
- tt.log(sigma * tt.sqrt(2 * np.pi))
12881288
- 0.5 * ((value - mu) / sigma)**2)
1289-
return bound(lp, sigma > 0., nu > 0.)
1289+
return bound_elemwise(lp, sigma > 0., nu > 0.)
12901290

12911291

12921292
class VonMises(Continuous):
@@ -1335,7 +1335,7 @@ def random(self, point=None, size=None, repeat=None):
13351335
def logp(self, value):
13361336
mu = self.mu
13371337
kappa = self.kappa
1338-
return bound(kappa * tt.cos(mu - value) - tt.log(2 * np.pi * i0(kappa)), value >= -np.pi, value <= np.pi, kappa >= 0)
1338+
return bound_elemwise(kappa * tt.cos(mu - value) - tt.log(2 * np.pi * i0(kappa)), value >= -np.pi, value <= np.pi, kappa >= 0)
13391339

13401340

13411341
class SkewNormal(Continuous):
@@ -1401,7 +1401,7 @@ def logp(self, value):
14011401
sd = self.sd
14021402
mu = self.mu
14031403
alpha = self.alpha
1404-
return bound(
1404+
return bound_elemwise(
14051405
tt.log(1 +
14061406
tt.erf(((value - mu) * tt.sqrt(tau) * alpha) / tt.sqrt(2)))
14071407
+ (-tau * (value - mu)**2

pymc3/distributions/discrete.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ def logp(self, value):
5454
n = self.n
5555
p = self.p
5656

57-
return bound(
57+
return bound_elemwise(
5858
binomln(n, value) + logpow(p, value) + logpow(1 - p, n - value),
5959
0 <= value, value <= n,
6060
0 <= p, p <= 1)
@@ -118,7 +118,7 @@ def random(self, point=None, size=None, repeat=None):
118118
def logp(self, value):
119119
alpha = self.alpha
120120
beta = self.beta
121-
return bound(binomln(self.n, value)
121+
return bound_elemwise(binomln(self.n, value)
122122
+ betaln(value + alpha, self.n - value + beta)
123123
- betaln(alpha, beta),
124124
value >= 0, value <= self.n,
@@ -158,7 +158,7 @@ def random(self, point=None, size=None, repeat=None):
158158

159159
def logp(self, value):
160160
p = self.p
161-
return bound(
161+
return bound_elemwise(
162162
tt.switch(value, tt.log(p), tt.log(1 - p)),
163163
value >= 0, value <= 1,
164164
p >= 0, p <= 1)
@@ -204,7 +204,7 @@ def random(self, point=None, size=None, repeat=None):
204204

205205
def logp(self, value):
206206
mu = self.mu
207-
log_prob = bound(
207+
log_prob = bound_elemwise(
208208
logpow(mu, value) - factln(value) - mu,
209209
mu >= 0, value >= 0)
210210
# Return zero when mu and value are both zero
@@ -255,7 +255,7 @@ def random(self, point=None, size=None, repeat=None):
255255
def logp(self, value):
256256
mu = self.mu
257257
alpha = self.alpha
258-
negbinom = bound(binomln(value + alpha - 1, value)
258+
negbinom = bound_elemwise(binomln(value + alpha - 1, value)
259259
+ logpow(mu / (mu + alpha), value)
260260
+ logpow(alpha / (mu + alpha), alpha),
261261
value >= 0, mu > 0, alpha > 0)
@@ -300,7 +300,7 @@ def random(self, point=None, size=None, repeat=None):
300300

301301
def logp(self, value):
302302
p = self.p
303-
return bound(tt.log(p) + logpow(1 - p, value - 1),
303+
return bound_elemwise(tt.log(p) + logpow(1 - p, value - 1),
304304
0 <= p, p <= 1, value >= 1)
305305

306306

@@ -348,7 +348,7 @@ def random(self, point=None, size=None, repeat=None):
348348
def logp(self, value):
349349
upper = self.upper
350350
lower = self.lower
351-
return bound_elemwise(-tt.log(upper - lower + 1) * tt.ones_like(value),
351+
return bound_elemwise(-tt.log(upper - lower + 1),
352352
lower <= value, value <= upper)
353353

354354

@@ -408,7 +408,7 @@ def logp(self, value):
408408
a = tt.log(p[tt.arange(p.shape[0]), value])
409409
else:
410410
a = tt.log(p[value])
411-
return bound(a,
411+
return bound_elemwise(a,
412412
value >= 0, value <= (k - 1),
413413
sumto1)
414414

@@ -439,7 +439,7 @@ def _random(c, dtype=dtype, size=None):
439439

440440
def logp(self, value):
441441
c = self.c
442-
return bound(0, tt.eq(value, c))
442+
return bound_elemwise(0, tt.eq(value, c))
443443

444444
def ConstantDist(*args, **kwargs):
445445
warnings.warn("ConstantDist has been deprecated. In future, use Constant instead.",

pymc3/tests/test_dist_math.py

Lines changed: 78 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -1,36 +1,81 @@
11
import numpy as np
22
import theano.tensor as tt
3+
import pymc3 as pm
34

4-
from ..distributions.dist_math import alltrue
5-
6-
7-
def test_alltrue():
8-
assert alltrue([]).eval()
9-
assert alltrue([True]).eval()
10-
assert alltrue([tt.ones(10)]).eval()
11-
assert alltrue([tt.ones(10),
12-
5 * tt.ones(101)]).eval()
13-
assert alltrue([np.ones(10),
14-
5 * tt.ones(101)]).eval()
15-
assert alltrue([np.ones(10),
16-
True,
17-
5 * tt.ones(101)]).eval()
18-
assert alltrue([np.array([1, 2, 3]),
19-
True,
20-
5 * tt.ones(101)]).eval()
21-
22-
assert not alltrue([False]).eval()
23-
assert not alltrue([tt.zeros(10)]).eval()
24-
assert not alltrue([True,
25-
False]).eval()
26-
assert not alltrue([np.array([0, -1]),
27-
tt.ones(60)]).eval()
28-
assert not alltrue([np.ones(10),
29-
False,
30-
5 * tt.ones(101)]).eval()
31-
32-
33-
def test_alltrue_shape():
34-
vals = [True, tt.ones(10), tt.zeros(5)]
35-
36-
assert alltrue(vals).eval().shape == ()
5+
from ..distributions import Discrete
6+
from ..distributions.dist_math import bound_elemwise, bound, factln
7+
8+
9+
def test_bound_elemwise():
10+
logp = tt.ones((10, 10))
11+
cond = tt.ones((10, 10))
12+
assert np.all(bound_elemwise(logp, cond).eval() == logp.eval())
13+
14+
logp = tt.ones((10, 10))
15+
cond = tt.zeros((10, 10))
16+
assert np.all(bound_elemwise(logp, cond).eval() == (-np.inf * logp).eval())
17+
18+
logp = tt.ones((10, 10))
19+
cond = True
20+
assert np.all(bound_elemwise(logp, cond).eval() == logp.eval())
21+
22+
logp = tt.ones(3)
23+
cond = np.array([1, 0, 1])
24+
assert not np.all(bound_elemwise(logp, cond).eval() == 1)
25+
assert np.prod(bound_elemwise(logp, cond).eval()) == -np.inf
26+
27+
logp = tt.ones((2, 3))
28+
cond = np.array([[1, 1, 1], [1, 0, 1]])
29+
assert not np.all(bound_elemwise(logp, cond).eval() == 1)
30+
assert np.prod(bound_elemwise(logp, cond).eval()) == -np.inf
31+
32+
33+
class MultinomialA(Discrete):
34+
def __init__(self, n, p, *args, **kwargs):
35+
super(MultinomialA, self).__init__(*args, **kwargs)
36+
37+
self.n = n
38+
self.p = p
39+
40+
def logp(self, value):
41+
n = self.n
42+
p = self.p
43+
44+
return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
45+
value >= 0,
46+
0 <= p, p <= 1,
47+
tt.isclose(p.sum(), 1))
48+
49+
50+
class MultinomialB(Discrete):
51+
def __init__(self, n, p, *args, **kwargs):
52+
super(MultinomialB, self).__init__(*args, **kwargs)
53+
54+
self.n = n
55+
self.p = p
56+
57+
def logp(self, value):
58+
n = self.n
59+
p = self.p
60+
61+
return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
62+
tt.all(value >= 0),
63+
tt.all(0 <= p), tt.all(p <= 1),
64+
tt.isclose(p.sum(), 1))
65+
66+
67+
def test_multinomial_bound():
68+
69+
x = np.array([1, 5])
70+
n = x.sum()
71+
72+
with pm.Model() as modelA:
73+
p_a = pm.Dirichlet('p', np.ones(2))
74+
x_obs_a = MultinomialA('x', n, p_a, observed=x)
75+
76+
with pm.Model() as modelB:
77+
p_b = pm.Dirichlet('p', np.ones(2))
78+
x_obs_b = MultinomialB('x', n, p_b, observed=x)
79+
80+
assert np.isclose(modelA.logp({'p_stickbreaking_': [0]}),
81+
modelB.logp({'p_stickbreaking_': [0]}))

0 commit comments

Comments (0)