 import numpy as np
 import theano.tensor as tt
+import pymc3 as pm
 
-from ..distributions.dist_math import alltrue
+from ..distributions import Discrete
+from ..distributions.dist_math import bound, factln, alltrue_elemwise, alltrue_scalar
 
 
-def test_alltrue():
-    assert alltrue([]).eval()
-    assert alltrue([True]).eval()
-    assert alltrue([tt.ones(10)]).eval()
-    assert alltrue([tt.ones(10),
+def test_bound():
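+    # The default (broadcasting) bound keeps logp wherever the conditions hold and switches to -inf where they fail.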
+    logp = tt.ones((10, 10))
+    cond = tt.ones((10, 10))
+    assert np.all(bound(logp, cond).eval() == logp.eval())
+
+    logp = tt.ones((10, 10))
+    cond = tt.zeros((10, 10))
+    assert np.all(bound(logp, cond).eval() == (-np.inf * logp).eval())
+
+    logp = tt.ones((10, 10))
+    cond = True
+    assert np.all(bound(logp, cond).eval() == logp.eval())
+
+    logp = tt.ones(3)
+    cond = np.array([1, 0, 1])
+    assert not np.all(bound(logp, cond).eval() == 1)
+    assert np.prod(bound(logp, cond).eval()) == -np.inf
+
+    logp = tt.ones((2, 3))
+    cond = np.array([[1, 1, 1], [1, 0, 1]])
+    assert not np.all(bound(logp, cond).eval() == 1)
+    assert np.prod(bound(logp, cond).eval()) == -np.inf
+
+
+def test_alltrue_scalar():
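+    # alltrue_scalar collapses a list of conditions (booleans, tensors, or arrays) into a single scalar truth value.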
+    assert alltrue_scalar([]).eval()
+    assert alltrue_scalar([True]).eval()
+    assert alltrue_scalar([tt.ones(10)]).eval()
+    assert alltrue_scalar([tt.ones(10),
                            5 * tt.ones(101)]).eval()
-    assert alltrue([np.ones(10),
+    assert alltrue_scalar([np.ones(10),
                            5 * tt.ones(101)]).eval()
-    assert alltrue([np.ones(10),
+    assert alltrue_scalar([np.ones(10),
                            True,
                            5 * tt.ones(101)]).eval()
-    assert alltrue([np.array([1, 2, 3]),
+    assert alltrue_scalar([np.array([1, 2, 3]),
                            True,
                            5 * tt.ones(101)]).eval()
 
-    assert not alltrue([False]).eval()
-    assert not alltrue([tt.zeros(10)]).eval()
-    assert not alltrue([True,
+    assert not alltrue_scalar([False]).eval()
+    assert not alltrue_scalar([tt.zeros(10)]).eval()
+    assert not alltrue_scalar([True,
                                False]).eval()
-    assert not alltrue([np.array([0, -1]),
+    assert not alltrue_scalar([np.array([0, -1]),
                                tt.ones(60)]).eval()
-    assert not alltrue([np.ones(10),
+    assert not alltrue_scalar([np.ones(10),
                                False,
                                5 * tt.ones(101)]).eval()
 
-
 def test_alltrue_shape():
     vals = [True, tt.ones(10), tt.zeros(5)]
 
-    assert alltrue(vals).eval().shape == ()
+    assert alltrue_scalar(vals).eval().shape == ()
+
+
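+# A Multinomial whose logp passes its constraints to bound() elementwise (value >= 0, 0 <= p <= 1, p sums to 1).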
+class MultinomialA(Discrete):
+    def __init__(self, n, p, *args, **kwargs):
+        super(MultinomialA, self).__init__(*args, **kwargs)
+
+        self.n = n
+        self.p = p
+
+    def logp(self, value):
+        n = self.n
+        p = self.p
+
+        return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
+                     value >= 0,
+                     0 <= p, p <= 1,
+                     tt.isclose(p.sum(), 1),
+                     broadcast_conditions=False)
+
+
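+# The same Multinomial, but each constraint is reduced to a scalar with tt.all() before being passed to bound().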
+class MultinomialB(Discrete):
+    def __init__(self, n, p, *args, **kwargs):
+        super(MultinomialB, self).__init__(*args, **kwargs)
+
+        self.n = n
+        self.p = p
+
+    def logp(self, value):
+        n = self.n
+        p = self.p
+
+        return bound(factln(n) - factln(value).sum() + (value * tt.log(p)).sum(),
+                     tt.all(value >= 0),
+                     tt.all(0 <= p), tt.all(p <= 1),
+                     tt.isclose(p.sum(), 1),
+                     broadcast_conditions=False)
+
+
+def test_multinomial_bound():
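+    # With broadcast_conditions=False, the elementwise and tt.all() condition styles should give the same model logp.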
+    x = np.array([1, 5])
+    n = x.sum()
+
+    with pm.Model() as modelA:
+        p_a = pm.Dirichlet('p', np.ones(2))
+        x_obs_a = MultinomialA('x', n, p_a, observed=x)
+
+    with pm.Model() as modelB:
+        p_b = pm.Dirichlet('p', np.ones(2))
+        x_obs_b = MultinomialB('x', n, p_b, observed=x)
+
+    assert np.isclose(modelA.logp({'p_stickbreaking_': [0]}),
+                      modelB.logp({'p_stickbreaking_': [0]}))