Skip to content

Commit cc3676f

Browse files
authored
Use the upstream betaln (#564)
1 parent 2da2c6d commit cc3676f

File tree

1 file changed

+4
-8
lines changed

1 file changed

+4
-8
lines changed

numpyro/distributions/conjugate.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from jax import lax, random
55
import jax.numpy as np
6-
from jax.scipy.special import gammaln
6+
from jax.scipy.special import betaln, gammaln
77

88
from numpyro.distributions import constraints
99
from numpyro.distributions.continuous import Beta, Gamma
@@ -12,10 +12,6 @@
1212
from numpyro.distributions.util import promote_shapes, validate_sample
1313

1414

15-
def _log_beta(x, y):
16-
return gammaln(x) + gammaln(y) - gammaln(x + y)
17-
18-
1915
class BetaBinomial(Distribution):
2016
r"""
2117
Compound distribution comprising of a beta-binomial pair. The probability of
@@ -52,8 +48,8 @@ def log_prob(self, value):
5248
log_factorial_k = gammaln(value + 1)
5349
log_factorial_nmk = gammaln(self.total_count - value + 1)
5450
return (log_factorial_n - log_factorial_k - log_factorial_nmk +
55-
_log_beta(value + self.concentration1, self.total_count - value + self.concentration0) -
56-
_log_beta(self.concentration0, self.concentration1))
51+
betaln(value + self.concentration1, self.total_count - value + self.concentration0) -
52+
betaln(self.concentration0, self.concentration1))
5753

5854
@property
5955
def mean(self):
@@ -95,7 +91,7 @@ def sample(self, key, sample_shape=()):
9591
@validate_sample
9692
def log_prob(self, value):
9793
post_value = self.concentration + value
98-
return -_log_beta(self.concentration, value + 1) - np.log(post_value) + \
94+
return -betaln(self.concentration, value + 1) - np.log(post_value) + \
9995
self.concentration * np.log(self.rate) - post_value * np.log1p(self.rate)
10096

10197
@property

0 commit comments

Comments
 (0)