Skip to content

Commit 40e28e8

Browse files
authored
Implement Skellam distribution (#260)
1 parent dd2a060 commit 40e28e8

File tree

4 files changed

+110
-2
lines changed

4 files changed

+110
-2
lines changed

docs/api_reference.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Distributions
3333
GeneralizedPoisson
3434
GenExtreme
3535
R2D2M2CP
36+
Skellam
3637
histogram_approximation
3738

3839

pymc_experimental/distributions/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020
from pymc_experimental.distributions.continuous import Chi, GenExtreme
21-
from pymc_experimental.distributions.discrete import GeneralizedPoisson
21+
from pymc_experimental.distributions.discrete import GeneralizedPoisson, Skellam
2222
from pymc_experimental.distributions.histogram_utils import histogram_approximation
2323
from pymc_experimental.distributions.multivariate import R2D2M2CP
2424
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain

pymc_experimental/distributions/discrete.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,3 +171,98 @@ def logp(value, mu, lam):
171171
(-mu / 4) <= lam,
172172
msg="0 < mu, max(-1, -mu/4)) <= lam <= 1",
173173
)
174+
175+
176+
class Skellam:
177+
R"""
178+
Skellam distribution.
179+
180+
The Skellam distribution is the distribution of the difference of two
181+
Poisson random variables.
182+
183+
The pmf of this distribution is
184+
185+
.. math::
186+
187+
f(x | \mu_1, \mu_2) = e^{{-(\mu _{1}\!+\!\mu _{2})}}\left({\frac {\mu _{1}}{\mu _{2}}}\right)^{{x/2}}\!\!I_{{x}}(2{\sqrt {\mu _{1}\mu _{2}}})
188+
189+
where :math:`I_{x}` is the modified Bessel function of the first kind of order :math:`x`.
190+
191+
Read more about the Skellam distribution at https://en.wikipedia.org/wiki/Skellam_distribution
192+
193+
.. plot::
194+
:context: close-figs
195+
196+
import matplotlib.pyplot as plt
197+
import numpy as np
198+
import scipy.stats as st
199+
import arviz as az
200+
plt.style.use('arviz-darkgrid')
201+
x = np.arange(-15, 15)
202+
params = [
203+
(1, 1),
204+
(5, 5),
205+
(5, 1),
206+
]
207+
for mu1, mu2 in params:
208+
pmf = st.skellam.pmf(x, mu1, mu2)
209+
plt.plot(x, pmf, "-o", label=r'$\mu_1$ = {}, $\mu_2$ = {}'.format(mu1, mu2))
210+
plt.xlabel('x', fontsize=12)
211+
plt.ylabel('f(x)', fontsize=12)
212+
plt.legend(loc=1)
213+
plt.show()
214+
215+
======== ======================================
216+
Support :math:`x \in \mathbb{Z}`
217+
Mean :math:`\mu_{1} - \mu_{2}`
218+
Variance :math:`\mu_{1} + \mu_{2}`
219+
======== ======================================
220+
221+
Parameters
222+
----------
223+
mu1 : tensor_like of float
224+
Mean parameter (mu1 >= 0).
225+
mu2 : tensor_like of float
226+
Mean parameter (mu2 >= 0).
227+
"""
228+
229+
@staticmethod
230+
def skellam_dist(mu1, mu2, size):
231+
return pm.Poisson.dist(mu=mu1, size=size) - pm.Poisson.dist(mu=mu2, size=size)
232+
233+
@staticmethod
234+
def skellam_logp(value, mu1, mu2):
235+
res = (
236+
-mu1
237+
- mu2
238+
+ 0.5 * value * (pt.log(mu1) - pt.log(mu2))
239+
+ pt.log(pt.iv(value, 2 * pt.sqrt(mu1 * mu2)))
240+
)
241+
return check_parameters(
242+
res,
243+
mu1 >= 0,
244+
mu2 >= 0,
245+
msg="mu1 >= 0, mu2 >= 0",
246+
)
247+
248+
def __new__(cls, name, mu1, mu2, **kwargs):
249+
return pm.CustomDist(
250+
name,
251+
mu1,
252+
mu2,
253+
dist=cls.skellam_dist,
254+
logp=cls.skellam_logp,
255+
class_name="Skellam",
256+
**kwargs,
257+
)
258+
259+
@classmethod
260+
def dist(cls, mu1, mu2, **kwargs):
261+
return pm.CustomDist.dist(
262+
mu1,
263+
mu2,
264+
dist=cls.skellam_dist,
265+
logp=cls.skellam_logp,
266+
class_name="Skellam",
267+
**kwargs,
268+
)

pymc_experimental/tests/distributions/test_discrete.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,15 @@
2121
from pymc.testing import (
2222
BaseTestDistributionRandom,
2323
Domain,
24+
I,
2425
Rplus,
2526
assert_moment_is_expected,
27+
check_logp,
2628
discrete_random_tester,
2729
)
2830
from pytensor import config
2931

30-
from pymc_experimental.distributions import GeneralizedPoisson
32+
from pymc_experimental.distributions import GeneralizedPoisson, Skellam
3133

3234

3335
class TestGeneralizedPoisson:
@@ -118,3 +120,13 @@ def test_moment(self, mu, lam, size, expected):
118120
with pm.Model() as model:
119121
GeneralizedPoisson("x", mu=mu, lam=lam, size=size)
120122
assert_moment_is_expected(model, expected)
123+
124+
125+
class TestSkellam:
126+
def test_logp(self):
127+
check_logp(
128+
Skellam,
129+
I,
130+
{"mu1": Rplus, "mu2": Rplus},
131+
lambda value, mu1, mu2: scipy.stats.skellam.logpmf(value, mu1, mu2),
132+
)

0 commit comments

Comments
 (0)