Skip to content

Commit 99dbf35

Browse files
nbudtwiecki
authored andcommitted
Extend Rice distribution (#3289)
* Fix Rice distribution and add new parametrization (#3286) * fix math error in the docstring introduced by prev commit * code format * Update RELEASE-NOTES.md * Add i1e and i0e.grad * elemwise i0e and i1e, does not work * elemwise i0e and i1e * Rice now accepts tensor parameters * update RELEASE-NOTES.md
1 parent 9906955 commit 99dbf35

File tree

5 files changed

+77
-15
lines changed

5 files changed

+77
-15
lines changed

RELEASE-NOTES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
context manager instance. If they do not, the conditional relations between
2121
the distribution's parameters could be broken, and `random` could return
2222
values drawn from an incorrect distribution.
23+
- `Rice` distribution is now defined with either the noncentrality parameter or the shape parameter (#3287).
2324

2425
### Maintenance
2526

@@ -30,6 +31,9 @@
3031
- Fix for #3210. Now `distribution.draw_values(params)`, will draw the `params` values from their joint probability distribution and not from combinations of their marginals (Refer to PR #3273).
3132
- Removed dependence on pandas-datareader for retrieving Yahoo Finance data in examples (#3262)
3233
- Rewrote `Multinomial._random` method to better handle shape broadcasting (#3271)
34+
- Fixed `Rice` distribution, which inconsistently mixed two parametrizations (#3286).
35+
- `Rice` distribution now accepts multiple parameters and observations and is usable with NUTS (#3289).
36+
3337

3438
### Deprecations
3539

pymc3/distributions/continuous.py

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from scipy.interpolate import InterpolatedUnivariateSpline
1313
import warnings
1414

15-
from theano.scalar import i1, i0
16-
1715
from pymc3.theanof import floatX
1816
from . import transforms
1917
from pymc3.util import get_variable_name
@@ -3505,28 +3503,57 @@ class Rice(PositiveContinuous):
35053503
======== ==============================================================
35063504
Support :math:`x \in (0, \infty)`
35073505
Mean :math:`\sigma {\sqrt {\pi /2}}\,\,L_{{1/2}}(-\nu ^{2}/2\sigma ^{2})`
3508-
Variance :math:`2\sigma ^{2}+\nu ^{2}-{\frac {\pi \sigma ^{2}}{2}}L_{{1/2}}^{2}
3509-
\left({\frac {-\nu ^{2}}{2\sigma ^{2}}}\right)`
3506+
Variance :math:`2\sigma ^{2}+\nu ^{2}-{\frac {\pi \sigma ^{2}}{2}}L_{{1/2}}^{2}\left({\frac {-\nu ^{2}}{2\sigma ^{2}}}\right)`
35103507
======== ==============================================================
35113508
35123509
35133510
Parameters
35143511
----------
35153512
nu : float
3516-
shape parameter.
3513+
noncentrality parameter.
35173514
sd : float
3518-
standard deviation.
3515+
scale parameter.
3516+
b : float
3517+
shape parameter (alternative to nu).
3518+
3519+
Notes
3520+
-----
3521+
The distribution :math:`\mathrm{Rice}\left(|\nu|,\sigma\right)` is the
3522+
distribution of :math:`R=\sqrt{X^2+Y^2}` where :math:`X\sim N(\nu \cos{\theta}, \sigma^2)`,
3523+
:math:`Y\sim N(\nu \sin{\theta}, \sigma^2)` are independent and for any
3524+
real :math:`\theta`.
3525+
3526+
The distribution is defined with either nu or b.
3527+
The link between the two parametrizations is given by
3528+
3529+
.. math::
3530+
3531+
b = \dfrac{\nu}{\sigma}
35193532
35203533
"""
35213534

3522-
def __init__(self, nu=None, sd=None, *args, **kwargs):
3535+
def __init__(self, nu=None, sd=None, b=None, *args, **kwargs):
35233536
super(Rice, self).__init__(*args, **kwargs)
3537+
nu, b, sd = self.get_nu_b(nu, b, sd)
35243538
self.nu = nu = tt.as_tensor_variable(nu)
35253539
self.sd = sd = tt.as_tensor_variable(sd)
3540+
self.b = b = tt.as_tensor_variable(b)
35263541
self.mean = sd * np.sqrt(np.pi / 2) * tt.exp((-nu**2 / (2 * sd**2)) / 2) * ((1 - (-nu**2 / (2 * sd**2)))
3527-
* i0(-(-nu**2 / (2 * sd**2)) / 2) - (-nu**2 / (2 * sd**2)) * i1(-(-nu**2 / (2 * sd**2)) / 2))
3542+
* tt.i0(-(-nu**2 / (2 * sd**2)) / 2) - (-nu**2 / (2 * sd**2)) * tt.i1(-(-nu**2 / (2 * sd**2)) / 2))
35283543
self.variance = 2 * sd**2 + nu**2 - (np.pi * sd**2 / 2) * (tt.exp((-nu**2 / (2 * sd**2)) / 2) * ((1 - (-nu**2 / (
3529-
2 * sd**2))) * i0(-(-nu**2 / (2 * sd**2)) / 2) - (-nu**2 / (2 * sd**2)) * i1(-(-nu**2 / (2 * sd**2)) / 2)))**2
3544+
2 * sd**2))) * tt.i0(-(-nu**2 / (2 * sd**2)) / 2) - (-nu**2 / (2 * sd**2)) * tt.i1(-(-nu**2 / (2 * sd**2)) / 2)))**2
3545+
3546+
def get_nu_b(self, nu, b, sd):
3547+
if sd is None:
3548+
sd = 1.
3549+
if nu is None and b is not None:
3550+
nu = b * sd
3551+
return nu, b, sd
3552+
elif nu is not None and b is None:
3553+
b = nu / sd
3554+
return nu, b, sd
3555+
raise ValueError('Rice distribution must specify either nu'
3556+
' or b.')
35303557

35313558
def random(self, point=None, size=None):
35323559
"""
@@ -3547,7 +3574,7 @@ def random(self, point=None, size=None):
35473574
"""
35483575
nu, sd = draw_values([self.nu, self.sd],
35493576
point=point, size=size)
3550-
return generate_samples(stats.rice.rvs, b=nu, scale=sd, loc=0,
3577+
return generate_samples(stats.rice.rvs, b=nu / sd, scale=sd, loc=0,
35513578
dist_shape=self.shape, size=size)
35523579

35533580
def logp(self, value):
@@ -3566,8 +3593,9 @@ def logp(self, value):
35663593
"""
35673594
nu = self.nu
35683595
sd = self.sd
3596+
b = self.b
35693597
x = value / sd
3570-
return bound(tt.log(x * tt.exp((-(x - nu) * (x - nu)) / 2) * i0e(x * nu) / sd),
3598+
return bound(tt.log(x * tt.exp((-(x - b) * (x - b)) / 2) * i0e(x * b) / sd),
35713599
sd >= 0,
35723600
nu >= 0,
35733601
value > 0,

pymc3/distributions/dist_math.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import scipy.linalg
1010
import theano.tensor as tt
1111
import theano
12-
from theano.scalar import UnaryScalarOp, upgrade_to_float
12+
from theano.scalar import UnaryScalarOp, upgrade_to_float_no_complex
1313
from theano.tensor.slinalg import Cholesky
1414
from theano.scan_module import until
1515
from theano import scan
@@ -270,6 +270,19 @@ def grad(self, inputs, grads):
270270
return [x_grad * self.grad_op(x)]
271271

272272

273+
class I1e(UnaryScalarOp):
274+
"""
275+
Modified Bessel function of the first kind of order 1, exponentially scaled.
276+
"""
277+
nfunc_spec = ('scipy.special.i1e', 1, 1)
278+
279+
def impl(self, x):
280+
return scipy.special.i1e(x)
281+
282+
283+
i1e_scalar = I1e(upgrade_to_float_no_complex, name="i1e")
284+
i1e = tt.Elemwise(i1e_scalar, name="Elemwise{i1e,no_inplace}")
285+
273286

274287
class I0e(UnaryScalarOp):
275288
"""
@@ -280,8 +293,14 @@ class I0e(UnaryScalarOp):
280293
def impl(self, x):
281294
return scipy.special.i0e(x)
282295

296+
def grad(self, inp, grads):
297+
x, = inp
298+
gz, = grads
299+
return gz * (i1e_scalar(x) - theano.scalar.sgn(x) * i0e_scalar(x)),
300+
283301

284-
i0e = I0e(upgrade_to_float, name='i0e')
302+
i0e_scalar = I0e(upgrade_to_float_no_complex, name="i0e")
303+
i0e = tt.Elemwise(i0e_scalar, name="Elemwise{i0e,no_inplace}")
285304

286305

287306
def random_choice(*args, **kwargs):

pymc3/tests/test_dist_math.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..theanof import floatX
1111
from ..distributions import Discrete
1212
from ..distributions.dist_math import (
13-
bound, factln, alltrue_scalar, MvNormalLogp, SplineWrapper)
13+
bound, factln, alltrue_scalar, MvNormalLogp, SplineWrapper, i0e)
1414

1515

1616
def test_bound():
@@ -193,3 +193,12 @@ def test_hessian(self):
193193
g_x, = tt.grad(spline(x_var), [x_var])
194194
with pytest.raises(NotImplementedError):
195195
tt.grad(g_x, [x_var])
196+
197+
198+
class TestI0e(object):
199+
@theano.configparser.change_flags(compute_test_value="ignore")
200+
def test_grad(self):
201+
utt.verify_grad(i0e, [0.5])
202+
utt.verify_grad(i0e, [-2.])
203+
utt.verify_grad(i0e, [[0.5, -2.]])
204+
utt.verify_grad(i0e, [[[0.5, -2.]]])

pymc3/tests/test_distributions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1178,7 +1178,9 @@ def test_multidimensional_beta_construction(self):
11781178

11791179
def test_rice(self):
11801180
self.pymc3_matches_scipy(Rice, Rplus, {'nu': Rplus, 'sd': Rplusbig},
1181-
lambda value, nu, sd: sp.rice.logpdf(value, b=nu, loc=0, scale=sd))
1181+
lambda value, nu, sd: sp.rice.logpdf(value, b=nu / sd, loc=0, scale=sd))
1182+
self.pymc3_matches_scipy(Rice, Rplus, {'b': Rplus, 'sd': Rplusbig},
1183+
lambda value, b, sd: sp.rice.logpdf(value, b=b, loc=0, scale=sd))
11821184

11831185
@pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
11841186
def test_interpolated(self):

0 commit comments

Comments
 (0)