Skip to content

Commit 8b1f64c

Browse files
authored
Improve ExGaussian logcdf and refactor logp (#4407)
* - Improve ExGaussian logcdf and refactor logp * - Remove unused std_cdf * - Add release note * - Remove unused import `theano`
1 parent 044c407 commit 8b1f64c

File tree

3 files changed

+34
-25
lines changed

3 files changed

+34
-25
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
3232
- Fix issue in `logp` method of `HyperGeometric`. It now returns `-inf` for invalid parameters (see [#4367](https://github.com/pymc-devs/pymc3/pull/4367))
3333
- Fixed `MatrixNormal` random method to work with parameters as random variables. (see [#4368](https://github.com/pymc-devs/pymc3/pull/4368))
3434
- Update the `logcdf` method of several continuous distributions to return -inf for invalid parameters and values, and raise an informative error when multiple values cannot be evaluated in a single call. (see [#4393](https://github.com/pymc-devs/pymc3/pull/4393))
35+
- Improve numerical stability in `logp` and `logcdf` methods of `ExGaussian` (see [#4407](https://github.com/pymc-devs/pymc3/pull/4407))
3536

3637
## PyMC3 3.10.0 (7 December 2020)
3738

pymc3/distributions/continuous.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import warnings
2121

2222
import numpy as np
23-
import theano
2423
import theano.tensor as tt
2524

2625
from scipy import stats
@@ -37,10 +36,10 @@
3736
gammaln,
3837
i0e,
3938
incomplete_beta,
39+
log_normal,
4040
logpow,
4141
normal_lccdf,
4242
normal_lcdf,
43-
std_cdf,
4443
zvalue,
4544
)
4645
from pymc3.distributions.distribution import Continuous, draw_values, generate_samples
@@ -3214,22 +3213,22 @@ def logp(self, value):
32143213
sigma = self.sigma
32153214
nu = self.nu
32163215

3217-
standardized_val = (value - mu) / sigma
3218-
cdf_val = std_cdf(standardized_val - sigma / nu)
3219-
cdf_val_safe = tt.switch(tt.eq(cdf_val, 0), np.finfo(theano.config.floatX).eps, cdf_val)
3220-
3221-
# This condition is suggested by exGAUS.R from gamlss
3222-
lp = tt.switch(
3223-
tt.gt(nu, 0.05 * sigma),
3224-
-tt.log(nu) + (mu - value) / nu + 0.5 * (sigma / nu) ** 2 + logpow(cdf_val_safe, 1.0),
3225-
-tt.log(sigma * tt.sqrt(2 * np.pi)) - 0.5 * standardized_val ** 2,
3216+
# Algorithm is adapted from dexGAUS.R from gamlss
3217+
return bound(
3218+
tt.switch(
3219+
tt.gt(nu, 0.05 * sigma),
3220+
(
3221+
-tt.log(nu)
3222+
+ (mu - value) / nu
3223+
+ 0.5 * (sigma / nu) ** 2
3224+
+ normal_lcdf(mu + (sigma ** 2) / nu, sigma, value)
3225+
),
3226+
log_normal(value, mean=mu, sigma=sigma),
3227+
),
3228+
0 < sigma,
3229+
0 < nu,
32263230
)
32273231

3228-
return bound(lp, sigma > 0.0, nu > 0.0)
3229-
3230-
def _distr_parameters_for_repr(self):
3231-
return ["mu", "sigma", "nu"]
3232-
32333232
def logcdf(self, value):
32343233
"""
32353234
Compute the log of the cumulative distribution function for ExGaussian distribution
@@ -3253,22 +3252,25 @@ def logcdf(self, value):
32533252
"""
32543253
mu = self.mu
32553254
sigma = self.sigma
3256-
sigma_2 = sigma ** 2
32573255
nu = self.nu
3258-
z = value - mu - sigma_2 / nu
3256+
3257+
# Algorithm is adapted from pexGAUS.R from gamlss
32593258
return tt.switch(
32603259
tt.gt(nu, 0.05 * sigma),
3261-
tt.log(
3262-
std_cdf((value - mu) / sigma)
3263-
- std_cdf(z / sigma)
3264-
* tt.exp(
3265-
((mu + (sigma_2 / nu)) ** 2 - (mu ** 2) - 2 * value * ((sigma_2) / nu))
3266-
/ (2 * sigma_2)
3267-
)
3260+
logdiffexp(
3261+
normal_lcdf(mu, sigma, value),
3262+
(
3263+
(mu - value) / nu
3264+
+ 0.5 * (sigma / nu) ** 2
3265+
+ normal_lcdf(mu + (sigma ** 2) / nu, sigma, value)
3266+
),
32683267
),
32693268
normal_lcdf(mu, sigma, value),
32703269
)
32713270

3271+
def _distr_parameters_for_repr(self):
3272+
return ["mu", "sigma", "nu"]
3273+
32723274

32733275
class VonMises(Continuous):
32743276
r"""

pymc3/tests/test_distributions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1825,6 +1825,9 @@ def test_get_tau_sigma(self):
18251825
(15.0, 5.000, 7.500, 7.500, -3.3093854),
18261826
(50.0, 50.000, 10.000, 10.000, -3.6436067),
18271827
(1000.0, 500.000, 10.000, 20.000, -27.8707323),
1828+
(-1.0, 1.0, 20.0, 0.9, -3.91967108), # Fails in scipy version
1829+
(0.01, 0.01, 100.0, 0.01, -5.5241087), # Fails in scipy version
1830+
(-1.0, 0.0, 0.1, 0.1, -51.022349), # Fails in previous pymc3 version
18281831
],
18291832
)
18301833
def test_ex_gaussian(self, value, mu, sigma, nu, logp):
@@ -1851,6 +1854,9 @@ def test_ex_gaussian(self, value, mu, sigma, nu, logp):
18511854
(15.0, 5.000, 7.500, 7.500, -0.4545255),
18521855
(50.0, 50.000, 10.000, 10.000, -1.433714),
18531856
(1000.0, 500.000, 10.000, 20.000, -1.573708e-11),
1857+
(0.01, 0.01, 100.0, 0.01, -0.69314718), # Fails in scipy version
1858+
(-0.43402407, 0.0, 0.1, 0.1, -13.59615423), # Previous 32-bit version failed here
1859+
(-0.72402009, 0.0, 0.1, 0.1, -31.26571842), # Previous 64-bit version failed here
18541860
],
18551861
)
18561862
def test_ex_gaussian_cdf(self, value, mu, sigma, nu, logcdf):

0 commit comments

Comments
 (0)