Skip to content

Commit a58ff5d

Browse files
aloctavodia authored and twiecki committed
ENH add ZINB distribution and test (#1310)
1 parent f5ea795 commit a58ff5d

File tree

4 files changed

+78
-8
lines changed

4 files changed

+78
-8
lines changed

pymc3/distributions/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from .discrete import NegativeBinomial
3131
from .discrete import ConstantDist
3232
from .discrete import ZeroInflatedPoisson
33+
from .discrete import ZeroInflatedNegativeBinomial
3334
from .discrete import DiscreteUniform
3435
from .discrete import Geometric
3536
from .discrete import Categorical
@@ -88,6 +89,7 @@
8889
'NegativeBinomial',
8990
'ConstantDist',
9091
'ZeroInflatedPoisson',
92+
'ZeroInflatedNegativeBinomial',
9193
'DiscreteUniform',
9294
'Geometric',
9395
'Categorical',
@@ -109,4 +111,4 @@
109111
'GARCH11'
110112
]
111113

112-
114+

pymc3/distributions/discrete.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99
from .distribution import Discrete, draw_values, generate_samples
1010

1111
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'Poisson',
12-
'NegativeBinomial', 'ConstantDist', 'ZeroInflatedPoisson',
13-
'DiscreteUniform', 'Geometric', 'Categorical']
12+
'NegativeBinomial', 'ConstantDist', 'ZeroInflatedPoisson',
13+
'ZeroInflatedNegativeBinomial', 'DiscreteUniform', 'Geometric',
14+
'Categorical']
1415

1516

1617
class Binomial(Discrete):
@@ -478,3 +479,55 @@ def logp(self, value):
478479
return tt.switch(value > 0,
479480
tt.log(self.psi) + self.pois.logp(value),
480481
tt.log((1. - self.psi) + self.psi * tt.exp(-self.theta)))
482+
483+
484+
class ZeroInflatedNegativeBinomial(Discrete):
485+
R"""
486+
Zero-Inflated Negative binomial log-likelihood.
487+
488+
The Zero-inflated version of the Negative Binomial (NB).
489+
The NB distribution describes a Poisson random variable
490+
whose rate parameter is gamma distributed.
491+
492+
.. math::
493+
494+
f(x \mid \mu, \alpha, \psi) = \left\{ \begin{array}{l}
495+
(1-\psi) + \psi \left (\frac{\alpha}{\alpha+\mu} \right) ^\alpha, \text{if } x = 0 \\
496+
\psi \frac{\Gamma(x+\alpha)}{x! \Gamma(\alpha)} \left (\frac{\alpha}{\mu+\alpha} \right)^\alpha \left( \frac{\mu}{\mu+\alpha} \right)^x, \text{if } x=1,2,3,\ldots
497+
\end{array} \right.
498+
499+
======== ==========================
500+
Support :math:`x \in \mathbb{N}_0`
501+
Mean :math:`\psi\mu`
502+
Var :math:`\psi\mu \left (1 + \frac{\mu}{\alpha} + (1-\psi)\mu \right)`
503+
======== ==========================
504+
505+
Parameters
506+
----------
507+
mu : float
508+
Poisson distribution parameter (mu > 0).
509+
alpha : float
510+
Gamma distribution parameter (alpha > 0).
511+
psi : float
512+
Expected proportion of NegativeBinomial variates (0 < psi < 1).
513+
"""
514+
def __init__(self, mu, alpha, psi, *args, **kwargs):
515+
super(ZeroInflatedNegativeBinomial, self).__init__(*args, **kwargs)
516+
self.mu = mu
517+
self.alpha = alpha
518+
self.psi = psi
519+
self.nb = NegativeBinomial.dist(mu, alpha)
520+
self.mode = self.nb.mode
521+
522+
def random(self, point=None, size=None, repeat=None):
523+
mu, alpha, psi = draw_values([self.mu, self.alpha, self.psi], point=point)
524+
g = generate_samples(stats.gamma.rvs, alpha, scale=mu/alpha,
525+
dist_shape=self.shape,
526+
size=size)
527+
g[g == 0] = np.finfo(float).eps # Just in case
528+
return stats.poisson.rvs(g) * (np.random.random(np.squeeze(g.shape)) < psi)
529+
530+
def logp(self, value):
531+
return tt.switch(value > 0,
532+
tt.log(self.psi) + self.nb.logp(value),
533+
tt.log((1. - self.psi) + self.psi * (self.alpha/(self.alpha+self.mu))**self.alpha))

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from ..model import Model, Point, Potential
77
from ..blocking import DictToVarBijection, DictToArrayBijection, ArrayOrdering
88
from ..distributions import (DensityDist, Categorical, Multinomial, VonMises, Dirichlet,
9-
MvStudentT, MvNormal, ZeroInflatedPoisson, ConstantDist,
10-
Poisson, Bernoulli, Beta, BetaBinomial, StudentTpos,
9+
MvStudentT, MvNormal, ZeroInflatedPoisson, ZeroInflatedNegativeBinomial,
10+
ConstantDist, Poisson, Bernoulli, Beta, BetaBinomial, StudentTpos,
1111
StudentT, Weibull, Pareto, InverseGamma, Gamma, Cauchy,
1212
HalfCauchy, Lognormal, Laplace, NegativeBinomial, Geometric,
1313
Exponential, ExGaussian, Normal, Flat, LKJCorr, Wald,
@@ -382,6 +382,9 @@ def test_constantdist():
382382
def test_zeroinflatedpoisson():
383383
checkd(ZeroInflatedPoisson, Nat, {'theta': Rplus, 'psi': Unit})
384384

385+
def test_zeroinflatednegativebinomial():
386+
checkd(ZeroInflatedNegativeBinomial, Nat, {'mu': Rplusbig, 'alpha': Rplusbig, 'psi': Unit})
387+
385388
def test_mvnormal():
386389
for n in [1, 2]:
387390
yield check_mvnormal, n

pymc3/tests/test_distributions_random.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
I, Simplex, Vector, PdMatrix)
1212

1313
from ..distributions import (DensityDist, Categorical, Multinomial, VonMises, Dirichlet,
14-
MvStudentT, MvNormal, ZeroInflatedPoisson, ConstantDist,
15-
Poisson, Bernoulli, Beta, BetaBinomial, StudentTpos,
14+
MvStudentT, MvNormal, ZeroInflatedPoisson, ZeroInflatedNegativeBinomial,
15+
ConstantDist, Poisson, Bernoulli, Beta, BetaBinomial, StudentTpos,
1616
StudentT, Weibull, Pareto, InverseGamma, Gamma, Cauchy,
1717
HalfCauchy, Lognormal, Laplace, NegativeBinomial, Geometric,
1818
Exponential, ExGaussian, Normal, Flat, LKJCorr, Wald,
@@ -202,7 +202,10 @@ def test_constant_dist(self):
202202

203203
def test_zero_inflated_poisson(self):
204204
self.check(ZeroInflatedPoisson, theta=1, psi=0.3)
205-
205+
206+
def test_zero_inflated_negative_binomial(self):
207+
self.check(ZeroInflatedNegativeBinomial, mu=1., alpha=1., psi=0.3)
208+
206209
def test_discrete_uniform(self):
207210
self.check(DiscreteUniform, lower=0., upper=10)
208211

@@ -295,6 +298,9 @@ def test_constant_dist(self):
295298

296299
def test_zero_inflated_poisson(self):
297300
self.check(ZeroInflatedPoisson, theta=1, psi=0.3)
301+
302+
def test_zero_inflated_negative_binomial(self):
303+
self.check(ZeroInflatedNegativeBinomial, mu=1., alpha=1., psi=0.3)
298304

299305
def test_discrete_uniform(self):
300306
self.check(DiscreteUniform, lower=0., upper=10)
@@ -398,6 +404,9 @@ def test_constantDist(self):
398404

399405
def test_zero_inflated_poisson(self):
400406
self.check(ZeroInflatedPoisson, theta=self.ones, psi=self.ones/2)
407+
408+
def test_zero_inflated_negative_binomial(self):
409+
self.check(ZeroInflatedNegativeBinomial, mu=self.ones, alpha=self.ones, psi=self.ones/2)
401410

402411
def test_discrete_uniform(self):
403412
self.check(DiscreteUniform,
@@ -508,6 +517,9 @@ def test_constantDist(self):
508517

509518
def test_zero_inflated_poisson(self):
510519
self.check(ZeroInflatedPoisson, theta=self.ones*2, psi=self.ones/3)
520+
521+
def test_zero_inflated_negative_binomial(self):
522+
self.check(ZeroInflatedNegativeBinomial, mu=self.ones*2, alpha=self.ones*2, psi=self.ones/3)
511523

512524
def test_discrete_uniform(self):
513525
self.check(DiscreteUniform, lower=self.zeros.astype(int),

0 commit comments

Comments
 (0)