Skip to content

Commit a99e429

Browse files
committed
Robustified logp calculations in zero-inflated distributions
1 parent a2f589c commit a99e429

File tree

1 file changed

+50
-23
lines changed

1 file changed

+50
-23
lines changed

pymc3/distributions/discrete.py

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .dist_math import bound, factln, binomln, betaln, logpow
99
from .distribution import Discrete, draw_values, generate_samples, reshape_sampled
1010
from pymc3.math import tround
11+
from ..math import logsumexp
1112

1213
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'DiscreteWeibull',
1314
'Poisson', 'NegativeBinomial', 'ConstantDist', 'Constant',
@@ -593,7 +594,7 @@ class ZeroInflatedPoisson(Discrete):
593594
594595
.. math::
595596
596-
f(x \mid \theta, \psi) = \left\{ \begin{array}{l}
597+
f(x \mid \psi, \theta) = \left\{ \begin{array}{l}
597598
(1-\psi) + \psi e^{-\theta}, \text{if } x = 0 \\
598599
\psi \frac{e^{-\theta}\theta^x}{x!}, \text{if } x=1,2,3,\ldots
599600
\end{array} \right.
@@ -606,15 +607,14 @@ class ZeroInflatedPoisson(Discrete):
606607
607608
Parameters
608609
----------
610+
psi : float
611+
Expected proportion of Poisson variates (0 < psi < 1)
609612
theta : float
610613
Expected number of occurrences during the given interval
611614
(theta >= 0).
612-
psi : float
613-
Expected proportion of Poisson variates (0 < psi < 1)
614-
615615
"""
616616

617-
def __init__(self, theta, psi, *args, **kwargs):
617+
def __init__(self, psi, theta, *args, **kwargs):
618618
super(ZeroInflatedPoisson, self).__init__(*args, **kwargs)
619619
self.theta = theta = tt.as_tensor_variable(theta)
620620
self.psi = psi = tt.as_tensor_variable(psi)
@@ -630,9 +630,17 @@ def random(self, point=None, size=None, repeat=None):
630630
return reshape_sampled(sampled, size, self.shape)
631631

632632
def logp(self, value):
633-
return tt.switch(value > 0,
634-
tt.log(self.psi) + self.pois.logp(value),
635-
tt.log((1. - self.psi) + self.psi * tt.exp(-self.theta)))
633+
psi = self.psi
634+
theta = self.theta
635+
636+
logp_val = tt.switch(value > 0,
637+
logsumexp(tt.log(psi) + self.pois.logp(value)),
638+
logsumexp(tt.log((1. - psi) + psi * tt.exp(-theta))))
639+
640+
return bound(logp_val.sum(),
641+
0 <= value,
642+
0 <= psi, psi <= 1,
643+
0 <= theta)
636644

637645
def _repr_latex_(self, name=None, dist=None):
638646
if dist is None:
@@ -650,7 +658,7 @@ class ZeroInflatedBinomial(Discrete):
650658
651659
.. math::
652660
653-
f(x \mid n, p, \psi) = \left\{ \begin{array}{l}
661+
f(x \mid \psi, n, p) = \left\{ \begin{array}{l}
654662
(1-\psi) + \psi (1-p)^{n}, \text{if } x = 0 \\
655663
\psi {n \choose x} p^x (1-p)^{n-x}, \text{if } x=1,2,3,\ldots,n
656664
\end{array} \right.
@@ -663,16 +671,16 @@ class ZeroInflatedBinomial(Discrete):
663671
664672
Parameters
665673
----------
674+
psi : float
675+
Expected proportion of Binomial variates (0 < psi < 1)
666676
n : int
667677
Number of Bernoulli trials (n >= 0).
668678
p : float
669679
Probability of success in each trial (0 < p < 1).
670-
psi : float
671-
Expected proportion of Binomial variates (0 < psi < 1)
672680
673681
"""
674682

675-
def __init__(self, n, p, psi, *args, **kwargs):
683+
def __init__(self, psi, n, p, *args, **kwargs):
676684
super(ZeroInflatedBinomial, self).__init__(*args, **kwargs)
677685
self.n = n = tt.as_tensor_variable(n)
678686
self.p = p = tt.as_tensor_variable(p)
@@ -689,9 +697,18 @@ def random(self, point=None, size=None, repeat=None):
689697
return reshape_sampled(sampled, size, self.shape)
690698

691699
def logp(self, value):
692-
return tt.switch(value > 0,
693-
tt.log(self.psi) + self.bin.logp(value),
694-
tt.log((1. - self.psi) + self.psi * tt.pow(1 - self.p, self.n)))
700+
psi = self.psi
701+
p = self.p
702+
n = self.n
703+
704+
logp_val = tt.switch(value > 0,
705+
logsumexp(tt.log(psi) + self.bin.logp(value)),
706+
logsumexp(tt.log((1. - psi) + psi * tt.pow(1 - p, n))))
707+
708+
return bound(logp_val.sum(),
709+
0 <= value, value <= n,
710+
0 <= psi, psi <= 1,
711+
0 <= p, p <= 1)
695712

696713
def _repr_latex_(self, name=None, dist=None):
697714
if dist is None:
@@ -703,7 +720,7 @@ def _repr_latex_(self, name=None, dist=None):
703720
get_variable_name(n),
704721
get_variable_name(p),
705722
get_variable_name(psi))
706-
723+
707724

708725
class ZeroInflatedNegativeBinomial(Discrete):
709726
R"""
@@ -715,7 +732,7 @@ class ZeroInflatedNegativeBinomial(Discrete):
715732
716733
.. math::
717734
718-
f(x \mid \mu, \alpha, \psi) = \left\{ \begin{array}{l}
735+
f(x \mid \psi, \mu, \alpha) = \left\{ \begin{array}{l}
719736
(1-\psi) + \psi \left (\frac{\alpha}{\alpha+\mu} \right) ^\alpha, \text{if } x = 0 \\
720737
\psi \frac{\Gamma(x+\alpha)}{x! \Gamma(\alpha)} \left (\frac{\alpha}{\mu+\alpha} \right)^\alpha \left( \frac{\mu}{\mu+\alpha} \right)^x, \text{if } x=1,2,3,\ldots
721738
\end{array} \right.
@@ -728,15 +745,16 @@ class ZeroInflatedNegativeBinomial(Discrete):
728745
729746
Parameters
730747
----------
748+
psi : float
749+
Expected proportion of NegativeBinomial variates (0 < psi < 1)
731750
mu : float
732751
Poission distribution parameter (mu > 0).
733752
alpha : float
734753
Gamma distribution parameter (alpha > 0).
735-
psi : float
736-
Expected proportion of NegativeBinomial variates (0 < psi < 1)
754+
737755
"""
738756

739-
def __init__(self, mu, alpha, psi, *args, **kwargs):
757+
def __init__(self, psi, mu, alpha, *args, **kwargs):
740758
super(ZeroInflatedNegativeBinomial, self).__init__(*args, **kwargs)
741759
self.mu = mu = tt.as_tensor_variable(mu)
742760
self.alpha = alpha = tt.as_tensor_variable(alpha)
@@ -755,9 +773,18 @@ def random(self, point=None, size=None, repeat=None):
755773
return reshape_sampled(sampled, size, self.shape)
756774

757775
def logp(self, value):
758-
return tt.switch(value > 0,
759-
tt.log(self.psi) + self.nb.logp(value),
760-
tt.log((1. - self.psi) + self.psi * (self.alpha / (self.alpha + self.mu))**self.alpha))
776+
alpha = self.alpha
777+
mu = self.mu
778+
psi = self.psi
779+
780+
logp_val = tt.switch(value > 0,
781+
logsumexp(tt.log(psi) + self.nb.logp(value)),
782+
logsumexp(tt.log((1. - psi) + psi * (alpha / (alpha + mu))**alpha)))
783+
784+
return bound(logp_val.sum(),
785+
0 <= value,
786+
0 <= psi, psi <= 1,
787+
mu > 0, alpha > 0)
761788

762789
def _repr_latex_(self, name=None, dist=None):
763790
if dist is None:

0 commit comments

Comments
 (0)