|
3 | 3 |
|
4 | 4 | from jax import lax, random |
5 | 5 | import jax.numpy as np |
6 | | -from jax.scipy.special import gammaln |
| 6 | +from jax.scipy.special import betaln, gammaln |
7 | 7 |
|
8 | 8 | from numpyro.distributions import constraints |
9 | 9 | from numpyro.distributions.continuous import Beta, Gamma |
|
12 | 12 | from numpyro.distributions.util import promote_shapes, validate_sample |
13 | 13 |
|
14 | 14 |
|
15 | | -def _log_beta(x, y): |
16 | | - return gammaln(x) + gammaln(y) - gammaln(x + y) |
17 | | - |
18 | | - |
19 | 15 | class BetaBinomial(Distribution): |
20 | 16 | r""" |
21 | 17 | Compound distribution comprising of a beta-binomial pair. The probability of |
@@ -52,8 +48,8 @@ def log_prob(self, value): |
52 | 48 | log_factorial_k = gammaln(value + 1) |
53 | 49 | log_factorial_nmk = gammaln(self.total_count - value + 1) |
54 | 50 | return (log_factorial_n - log_factorial_k - log_factorial_nmk + |
55 | | - _log_beta(value + self.concentration1, self.total_count - value + self.concentration0) - |
56 | | - _log_beta(self.concentration0, self.concentration1)) |
| 51 | + betaln(value + self.concentration1, self.total_count - value + self.concentration0) - |
| 52 | + betaln(self.concentration0, self.concentration1)) |
57 | 53 |
|
58 | 54 | @property |
59 | 55 | def mean(self): |
@@ -95,7 +91,7 @@ def sample(self, key, sample_shape=()): |
95 | 91 | @validate_sample |
96 | 92 | def log_prob(self, value): |
97 | 93 | post_value = self.concentration + value |
98 | | - return -_log_beta(self.concentration, value + 1) - np.log(post_value) + \ |
| 94 | + return -betaln(self.concentration, value + 1) - np.log(post_value) + \ |
99 | 95 | self.concentration * np.log(self.rate) - post_value * np.log1p(self.rate) |
100 | 96 |
|
101 | 97 | @property |
|
0 commit comments