Skip to content

Commit 1b01b53

Browse files
committed
MAINT Replace the separate bound/bound_elemwise functions with a single bound() that takes a broadcast_conditions kwarg (default True); multivariate distributions pass broadcast_conditions=False.
1 parent ed89f10 commit 1b01b53

File tree

5 files changed

+102
-69
lines changed

5 files changed

+102
-69
lines changed

pymc3/distributions/continuous.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import warnings
1414

1515
from . import transforms
16-
from .dist_math import bound, bound_elemwise, logpow, gammaln, betaln, std_cdf, i0, i1
16+
from .dist_math import bound, bound, logpow, gammaln, betaln, std_cdf, i0, i1
1717
from .distribution import Continuous, draw_values, generate_samples
1818

1919
__all__ = ['Uniform', 'Flat', 'Normal', 'Beta', 'Exponential', 'Laplace',
@@ -146,7 +146,7 @@ def random(self, point=None, size=None, repeat=None):
146146
def logp(self, value):
147147
lower = self.lower
148148
upper = self.upper
149-
return bound_elemwise(-tt.log(upper - lower),
149+
return bound(-tt.log(upper - lower),
150150
value >= lower, value <= upper)
151151

152152

@@ -243,7 +243,7 @@ def logp(self, value):
243243
sd = self.sd
244244
tau = self.tau
245245
mu = self.mu
246-
return bound_elemwise((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,
246+
return bound((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,
247247
sd > 0)
248248

249249

@@ -289,7 +289,7 @@ def random(self, point=None, size=None, repeat=None):
289289
def logp(self, value):
290290
tau = self.tau
291291
sd = self.sd
292-
return bound_elemwise(-0.5 * tau * value**2 + 0.5 * tt.log(tau * 2. / np.pi),
292+
return bound(-0.5 * tau * value**2 + 0.5 * tt.log(tau * 2. / np.pi),
293293
value >= 0,
294294
tau > 0, sd > 0)
295295

@@ -402,7 +402,7 @@ def logp(self, value):
402402
lam = self.lam
403403
alpha = self.alpha
404404
# value *must* be iid. Otherwise this is wrong.
405-
return bound_elemwise(logpow(lam / (2. * np.pi), 0.5)
405+
return bound(logpow(lam / (2. * np.pi), 0.5)
406406
- logpow(value - alpha, 1.5)
407407
- (0.5 * lam / (value - alpha)
408408
* ((value - alpha - mu) / mu)**2),
@@ -492,7 +492,7 @@ def logp(self, value):
492492
alpha = self.alpha
493493
beta = self.beta
494494

495-
return bound_elemwise(logpow(value, alpha - 1) + logpow(1 - value, beta - 1)
495+
return bound(logpow(value, alpha - 1) + logpow(1 - value, beta - 1)
496496
- betaln(alpha, beta),
497497
value >= 0, value <= 1,
498498
alpha > 0, beta > 0)
@@ -537,7 +537,7 @@ def random(self, point=None, size=None, repeat=None):
537537

538538
def logp(self, value):
539539
lam = self.lam
540-
return bound_elemwise(tt.log(lam) - lam * value, value > 0, lam > 0)
540+
return bound(tt.log(lam) - lam * value, value > 0, lam > 0)
541541

542542

543543
class Laplace(Continuous):
@@ -641,7 +641,7 @@ def random(self, point=None, size=None, repeat=None):
641641
def logp(self, value):
642642
mu = self.mu
643643
tau = self.tau
644-
return bound_elemwise(-0.5 * tau * (tt.log(value) - mu)**2
644+
return bound(-0.5 * tau * (tt.log(value) - mu)**2
645645
+ 0.5 * tt.log(tau / (2. * np.pi))
646646
- tt.log(value),
647647
tau > 0)
@@ -702,7 +702,7 @@ def logp(self, value):
702702
lam = self.lam
703703
sd = self.sd
704704

705-
return bound_elemwise(gammaln((nu + 1.0) / 2.0)
705+
return bound(gammaln((nu + 1.0) / 2.0)
706706
+ .5 * tt.log(lam / (nu * np.pi))
707707
- gammaln(nu / 2.0)
708708
- (nu + 1.0) / 2.0 * tt.log1p(lam * (value - mu)**2 / nu),
@@ -765,7 +765,7 @@ def random(self, point=None, size=None, repeat=None):
765765
def logp(self, value):
766766
alpha = self.alpha
767767
m = self.m
768-
return bound_elemwise(tt.log(alpha) + logpow(m, alpha)
768+
return bound(tt.log(alpha) + logpow(m, alpha)
769769
- logpow(value, alpha + 1),
770770
value >= m, alpha > 0, m > 0)
771771

@@ -817,7 +817,7 @@ def random(self, point=None, size=None, repeat=None):
817817
def logp(self, value):
818818
alpha = self.alpha
819819
beta = self.beta
820-
return bound_elemwise(- tt.log(np.pi) - tt.log(beta)
820+
return bound(- tt.log(np.pi) - tt.log(beta)
821821
- tt.log1p(((value - alpha) / beta)**2),
822822
beta > 0)
823823

@@ -863,7 +863,7 @@ def random(self, point=None, size=None, repeat=None):
863863

864864
def logp(self, value):
865865
beta = self.beta
866-
return bound_elemwise(tt.log(2) - tt.log(np.pi) - tt.log(beta)
866+
return bound(tt.log(2) - tt.log(np.pi) - tt.log(beta)
867867
- tt.log1p((value / beta)**2),
868868
value >= 0, beta > 0)
869869

@@ -943,7 +943,7 @@ def random(self, point=None, size=None, repeat=None):
943943
def logp(self, value):
944944
alpha = self.alpha
945945
beta = self.beta
946-
return bound_elemwise(
946+
return bound(
947947
-gammaln(alpha) + logpow(
948948
beta, alpha) - beta * value + logpow(value, alpha - 1),
949949

@@ -1007,7 +1007,7 @@ def random(self, point=None, size=None, repeat=None):
10071007
def logp(self, value):
10081008
alpha = self.alpha
10091009
beta = self.beta
1010-
return bound_elemwise(logpow(beta, alpha) - gammaln(alpha) - beta / value
1010+
return bound(logpow(beta, alpha) - gammaln(alpha) - beta / value
10111011
+ logpow(value, -alpha - 1),
10121012
value > 0, alpha > 0, beta > 0)
10131013

@@ -1088,7 +1088,7 @@ def _random(a, b, size=None):
10881088
def logp(self, value):
10891089
alpha = self.alpha
10901090
beta = self.beta
1091-
return bound_elemwise(tt.log(alpha) - tt.log(beta)
1091+
return bound(tt.log(alpha) - tt.log(beta)
10921092
+ (alpha - 1) * tt.log(value / beta)
10931093
- (value / beta)**alpha,
10941094
value >= 0, alpha > 0, beta > 0)
@@ -1161,7 +1161,7 @@ def random(self, point=None, size=None, repeat=None):
11611161
size=size)
11621162

11631163
def logp(self, value):
1164-
return bound_elemwise(self.dist.logp(value),
1164+
return bound(self.dist.logp(value),
11651165
value >= self.lower, value <= self.upper)
11661166

11671167

@@ -1286,7 +1286,7 @@ def logp(self, value):
12861286
+ logpow(std_cdf((value - mu) / sigma - sigma / nu), 1.),
12871287
- tt.log(sigma * tt.sqrt(2 * np.pi))
12881288
- 0.5 * ((value - mu) / sigma)**2)
1289-
return bound_elemwise(lp, sigma > 0., nu > 0.)
1289+
return bound(lp, sigma > 0., nu > 0.)
12901290

12911291

12921292
class VonMises(Continuous):
@@ -1335,7 +1335,7 @@ def random(self, point=None, size=None, repeat=None):
13351335
def logp(self, value):
13361336
mu = self.mu
13371337
kappa = self.kappa
1338-
return bound_elemwise(kappa * tt.cos(mu - value) - tt.log(2 * np.pi * i0(kappa)), value >= -np.pi, value <= np.pi, kappa >= 0)
1338+
return bound(kappa * tt.cos(mu - value) - tt.log(2 * np.pi * i0(kappa)), value >= -np.pi, value <= np.pi, kappa >= 0)
13391339

13401340

13411341
class SkewNormal(Continuous):
@@ -1401,7 +1401,7 @@ def logp(self, value):
14011401
sd = self.sd
14021402
mu = self.mu
14031403
alpha = self.alpha
1404-
return bound_elemwise(
1404+
return bound(
14051405
tt.log(1 +
14061406
tt.erf(((value - mu) * tt.sqrt(tau) * alpha) / tt.sqrt(2)))
14071407
+ (-tau * (value - mu)**2

pymc3/distributions/discrete.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import theano.tensor as tt
66
from scipy import stats
77

8-
from .dist_math import bound, bound_elemwise, factln, binomln, betaln, logpow
8+
from .dist_math import bound, bound, factln, binomln, betaln, logpow
99
from .distribution import Discrete, draw_values, generate_samples
1010

1111
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'Poisson',
@@ -54,7 +54,7 @@ def logp(self, value):
5454
n = self.n
5555
p = self.p
5656

57-
return bound_elemwise(
57+
return bound(
5858
binomln(n, value) + logpow(p, value) + logpow(1 - p, n - value),
5959
0 <= value, value <= n,
6060
0 <= p, p <= 1)
@@ -118,7 +118,7 @@ def random(self, point=None, size=None, repeat=None):
118118
def logp(self, value):
119119
alpha = self.alpha
120120
beta = self.beta
121-
return bound_elemwise(binomln(self.n, value)
121+
return bound(binomln(self.n, value)
122122
+ betaln(value + alpha, self.n - value + beta)
123123
- betaln(alpha, beta),
124124
value >= 0, value <= self.n,
@@ -158,7 +158,7 @@ def random(self, point=None, size=None, repeat=None):
158158

159159
def logp(self, value):
160160
p = self.p
161-
return bound_elemwise(
161+
return bound(
162162
tt.switch(value, tt.log(p), tt.log(1 - p)),
163163
value >= 0, value <= 1,
164164
p >= 0, p <= 1)
@@ -204,7 +204,7 @@ def random(self, point=None, size=None, repeat=None):
204204

205205
def logp(self, value):
206206
mu = self.mu
207-
log_prob = bound_elemwise(
207+
log_prob = bound(
208208
logpow(mu, value) - factln(value) - mu,
209209
mu >= 0, value >= 0)
210210
# Return zero when mu and value are both zero
@@ -255,7 +255,7 @@ def random(self, point=None, size=None, repeat=None):
255255
def logp(self, value):
256256
mu = self.mu
257257
alpha = self.alpha
258-
negbinom = bound_elemwise(binomln(value + alpha - 1, value)
258+
negbinom = bound(binomln(value + alpha - 1, value)
259259
+ logpow(mu / (mu + alpha), value)
260260
+ logpow(alpha / (mu + alpha), alpha),
261261
value >= 0, mu > 0, alpha > 0)
@@ -300,7 +300,7 @@ def random(self, point=None, size=None, repeat=None):
300300

301301
def logp(self, value):
302302
p = self.p
303-
return bound_elemwise(tt.log(p) + logpow(1 - p, value - 1),
303+
return bound(tt.log(p) + logpow(1 - p, value - 1),
304304
0 <= p, p <= 1, value >= 1)
305305

306306

@@ -348,8 +348,8 @@ def random(self, point=None, size=None, repeat=None):
348348
def logp(self, value):
349349
upper = self.upper
350350
lower = self.lower
351-
return bound_elemwise(-tt.log(upper - lower + 1),
352-
lower <= value, value <= upper)
351+
return bound(-tt.log(upper - lower + 1),
352+
lower <= value, value <= upper)
353353

354354

355355
class Categorical(Discrete):
@@ -408,7 +408,7 @@ def logp(self, value):
408408
a = tt.log(p[tt.arange(p.shape[0]), value])
409409
else:
410410
a = tt.log(p[value])
411-
return bound_elemwise(a,
411+
return bound(a,
412412
value >= 0, value <= (k - 1),
413413
sumto1)
414414

@@ -439,7 +439,7 @@ def _random(c, dtype=dtype, size=None):
439439

440440
def logp(self, value):
441441
c = self.c
442-
return bound_elemwise(0, tt.eq(value, c))
442+
return bound(0, tt.eq(value, c))
443443

444444
def ConstantDist(*args, **kwargs):
445445
warnings.warn("ConstantDist has been deprecated. In future, use Constant instead.",

pymc3/distributions/dist_math.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,33 @@
1010

1111
from .special import gammaln, multigammaln
1212

13-
14-
def bound_elemwise(logp, *conditions):
13+
def bound(logp, *conditions, **kwargs):
1514
"""
1615
Bounds a log probability density with several conditions.
1716
18-
Respects shape of logp and performs broadcasting when
19-
conditions.shape > logp.shape.
20-
2117
Parameters
2218
----------
2319
logp : float
2420
*conditions : booleans
21+
broadcast_conditions : bool (optional, default=True)
22+
If True, broadcasts logp to match the largest shape of the conditions.
23+
This is used e.g. in DiscreteUniform where logp is a scalar constant and the shape
24+
is specified via the conditions.
25+
If False, will return the same shape as logp.
26+
This is used e.g. in Multinomial where broadcasting can lead to differences in the logp.
2527
2628
Returns
2729
-------
2830
logp with elements set to -inf where any condition is False (if broadcast_conditions is False, all of logp is set to -inf when any condition fails)
2931
"""
30-
return tt.switch(alltrue_elemwise(conditions), logp, -np.inf)
32+
broadcast_conditions = kwargs.get('broadcast_conditions', True)
33+
34+
if broadcast_conditions:
35+
alltrue = alltrue_elemwise
36+
else:
37+
alltrue = alltrue_scalar
38+
39+
return tt.switch(alltrue(conditions), logp, -np.inf)
3140

3241

3342
def alltrue_elemwise(vals):
@@ -37,23 +46,7 @@ def alltrue_elemwise(vals):
3746
return ret
3847

3948

40-
def bound(logp, *conditions):
41-
"""
42-
Bounds a log probability density with several conditions
43-
44-
Parameters
45-
----------
46-
logp : float
47-
*conditions : booleans
48-
49-
Returns
50-
-------
51-
logp if all conditions are true
52-
-inf if some are false
53-
"""
54-
return tt.switch(alltrue(conditions), logp, -np.inf)
55-
56-
def alltrue(vals):
49+
def alltrue_scalar(vals):
5750
return tt.all([tt.all(1 * val) for val in vals])
5851

5952

pymc3/distributions/multivariate.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ def logp(self, value):
245245
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
246246
+ gammaln(tt.sum(a, axis=-1)),
247247
tt.all(value >= 0), tt.all(value <= 1),
248-
k > 1, tt.all(a > 0))
248+
k > 1, tt.all(a > 0),
249+
broadcast_conditions=False)
249250

250251

251252
class Multinomial(Discrete):
@@ -323,7 +324,9 @@ def logp(self, x):
323324
tt.all(tt.eq(tt.sum(x, axis=-1, keepdims=True), n)),
324325
tt.all(p <= 1),
325326
tt.all(tt.eq(tt.sum(p, axis=-1), 1)),
326-
tt.all(tt.ge(n, 0)))
327+
tt.all(tt.ge(n, 0)),
328+
broadcast_conditions=False
329+
)
327330

328331

329332
def posdef(AA):
@@ -443,7 +446,9 @@ def logp(self, X):
443446
- 2 * multigammaln(n / 2., p)) / 2,
444447
matrix_pos_def(X),
445448
tt.eq(X, X.T),
446-
n > (p - 1))
449+
n > (p - 1),
450+
broadcast_conditions=False
451+
)
447452

448453

449454
def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, testval=None):
@@ -605,4 +610,6 @@ def logp(self, x):
605610
return bound(result,
606611
tt.all(X <= 1), tt.all(X >= -1),
607612
matrix_pos_def(X),
608-
n > 0)
613+
n > 0,
614+
broadcast_conditions=False
615+
)

0 commit comments

Comments
 (0)