diff --git a/Project.toml b/Project.toml
index 3f4ac3f1a1..4edcc7ff14 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,7 +6,9 @@ version = "0.25.15"
 [deps]
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
+LambertW = "984bce1d-4616-540c-a9ee-88d1112d94c9"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
diff --git a/docs/src/fit.md b/docs/src/fit.md
index eb91aa39be..8ce259abad 100644
--- a/docs/src/fit.md
+++ b/docs/src/fit.md
@@ -54,6 +54,7 @@ The `fit_mle` method has been implemented for the following distributions:
 - [`InverseGaussian`](@ref)
 - [`Uniform`](@ref)
 - [`Weibull`](@ref)
+- [`ZeroInflatedPoisson`](@ref)
 
 **Multivariate:**
 
diff --git a/docs/src/univariate.md b/docs/src/univariate.md
index 0418a38faf..765bbdcefb 100644
--- a/docs/src/univariate.md
+++ b/docs/src/univariate.md
@@ -159,6 +159,7 @@ NegativeBinomial
 Poisson
 PoissonBinomial
 Skellam
+ZeroInflatedPoisson
 ```
 
 ### Vectorized evaluation
diff --git a/src/Distributions.jl b/src/Distributions.jl
index d33a026557..700a5d2546 100644
--- a/src/Distributions.jl
+++ b/src/Distributions.jl
@@ -23,7 +23,7 @@ import StatsBase: kurtosis, skewness, entropy, mode, modes,
 
 import PDMats: dim, PDMat, invquad
 
-using SpecialFunctions
+using SpecialFunctions, LambertW
 
 import ChainRulesCore
 
@@ -161,6 +161,7 @@ export
     WalleniusNoncentralHypergeometric,
     Weibull,
     Wishart,
+    ZeroInflatedPoisson,
     ZeroMeanIsoNormal,
     ZeroMeanIsoNormalCanon,
     ZeroMeanDiagNormal,
@@ -239,6 +240,7 @@ export
     quantile,           # inverse of cdf (defined for p in (0,1))
     qqbuild,            # build a paired quantiles data structure for qqplots
     rate,               # get the rate parameter
+    excessprob,         # get the excess probability of zeros parameter (ZeroInflatedPoisson)
     sampler,            # create a Sampler object for efficient samples
     scale,              # get the scale parameter
     scale!,             # provide storage for the scale parameter (used in multivariate distribution mvlognormal)
@@ -334,7 +336,7 @@ Supported distributions:
     QQPair, Rayleigh, Skellam, Soliton, StudentizedRange, SymTriangularDist, TDist, TriangularDist,
     Triweight, Truncated, TruncatedNormal, Uniform, UnivariateGMM,
     VonMises, VonMisesFisher, WalleniusNoncentralHypergeometric, Weibull,
-    Wishart, ZeroMeanIsoNormal, ZeroMeanIsoNormalCanon,
+    Wishart, ZeroInflatedPoisson, ZeroMeanIsoNormal, ZeroMeanIsoNormalCanon,
     ZeroMeanDiagNormal, ZeroMeanDiagNormalCanon, ZeroMeanFullNormal,
     ZeroMeanFullNormalCanon
 
diff --git a/src/univariate/discrete/zeroinflatedpoisson.jl b/src/univariate/discrete/zeroinflatedpoisson.jl
new file mode 100644
index 0000000000..fa6f49102e
--- /dev/null
+++ b/src/univariate/discrete/zeroinflatedpoisson.jl
@@ -0,0 +1,178 @@
+"""
+    ZeroInflatedPoisson(λ, p)
+
+A *zero-inflated Poisson distribution* is a mixture of two processes. With probability
+`1 - p`, a count is drawn from a Poisson distribution with rate `λ`, which describes the
+number of independent events occurring within a unit time interval:
+
+```math
+P(X = k) = (1 - p) \\frac{\\lambda^k}{k!} e^{-\\lambda}, \\quad \\text{ for } k = 1,2,\\ldots.
+```
+
+With probability `p`, an additional Bernoulli process produces an excess zero, so zeros
+arise from both processes:
+
+```math
+P(X = 0) = p + (1 - p) e^{-\\lambda}
+```
+
+As `p` approaches 0, the distribution tends toward `Poisson(λ)`.
+
+```julia
+ZeroInflatedPoisson()       # rate parameter 1 and excess-zero probability 0
+ZeroInflatedPoisson(λ)      # rate parameter λ and excess-zero probability 0
+ZeroInflatedPoisson(λ, p)   # rate parameter λ and excess-zero probability p
+
+params(d)      # Get the parameters, i.e. (λ, p)
+rate(d)        # Get the rate parameter, i.e. λ
+excessprob(d)  # Get the excess-zero probability, i.e. p
+mean(d)        # Get the mean of the mixture distribution
+var(d)         # Get the variance of the mixture distribution
+```
+
+External links:
+
+* [Zero-inflated Poisson Regression on UCLA IDRE Statistical Consulting](https://stats.idre.ucla.edu/stata/dae/zero-inflated-poisson-regression/)
+* [Zero-inflated model on Wikipedia](https://en.wikipedia.org/wiki/Zero-inflated_model)
+* McElreath, R. (2020). Statistical Rethinking: A Bayesian Course with Examples in R and Stan (2nd ed.). Chapman and Hall/CRC. https://doi.org/10.1201/9780429029608
+"""
+struct ZeroInflatedPoisson{T<:Real} <: DiscreteUnivariateDistribution
+    λ::T
+    p::T
+
+    function ZeroInflatedPoisson{T}(λ::T, p::T) where {T <: Real}
+        return new{T}(λ, p)
+    end
+end
+
+# Checked outer constructor; failures are reported under this distribution's own name.
+function ZeroInflatedPoisson(λ::T, p::T; check_args = true) where {T <: Real}
+    if check_args
+        @check_args(ZeroInflatedPoisson, λ >= zero(λ))
+        @check_args(ZeroInflatedPoisson, zero(p) <= p <= one(p))
+    end
+    return ZeroInflatedPoisson{T}(λ, p)
+end
+
+ZeroInflatedPoisson(λ::Real, p::Real) = ZeroInflatedPoisson(promote(λ, p)...)
+ZeroInflatedPoisson(λ::Integer, p::Integer) = ZeroInflatedPoisson(float(λ), float(p))
+ZeroInflatedPoisson(λ::Real) = ZeroInflatedPoisson(λ, 0.0)
+ZeroInflatedPoisson() = ZeroInflatedPoisson(1.0, 0.0, check_args = false)
+
+@distr_support ZeroInflatedPoisson 0 (d.λ == zero(typeof(d.λ)) ? 0 : Inf)
+
+### Statistics
+
+mean(d::ZeroInflatedPoisson) = (1 - d.p) * d.λ
+
+var(d::ZeroInflatedPoisson) = d.λ * (1 - d.p) * (1 + d.p * d.λ)
+
+#### Conversions
+
+function convert(::Type{ZeroInflatedPoisson{T}}, λ::Real, p::Real) where {T<:Real}
+    return ZeroInflatedPoisson(T(λ), T(p))
+end
+
+function convert(::Type{ZeroInflatedPoisson{T}}, d::ZeroInflatedPoisson{S}) where {T <: Real, S <: Real}
+    return ZeroInflatedPoisson(T(d.λ), T(d.p), check_args = false)
+end
+
+#### Parameters
+
+params(d::ZeroInflatedPoisson) = (d.λ, d.p,)
+partype(::ZeroInflatedPoisson{T}) where {T} = T
+
+rate(d::ZeroInflatedPoisson) = d.λ
+excessprob(d::ZeroInflatedPoisson) = d.p
+
+#### Evaluation
+
+function logpdf(d::ZeroInflatedPoisson, y::Real)
+    if iszero(y)
+        # log(p + (1 - p) * exp(-λ)), accumulated in log space
+        return logaddexp(log(d.p), log1p(-d.p) - d.λ)
+    else
+        return log1p(-d.p) + logpdf(Poisson(d.λ), y)
+    end
+end
+
+function cdf(d::ZeroInflatedPoisson, x::Real)
+    # Smallest `p` for which P(X = 0) = p + (1 - p) * exp(-λ) stays non-negative
+    deflat_limit = -1.0 / expm1(d.λ)
+
+    if x < 0
+        return 0.0
+    elseif d.p < deflat_limit
+        # Only reachable when argument checks were skipped at construction
+        return NaN
+    else
+        return d.p + (1 - d.p) * cdf(Poisson(d.λ), x)
+    end
+end
+
+function quantile(d::ZeroInflatedPoisson, q::Real)
+    # Every probability level up to `p` is covered by the excess zeros.
+    q <= d.p && return 0
+
+    # Only reachable when argument checks were skipped: no integer can
+    # represent NaN, so this conversion raises an `InexactError`.
+    d.p < -1.0 / expm1(d.λ) && return convert(Int, NaN)
+
+    # Rescale the remaining mass onto the Poisson component. `q >= 1` is forwarded
+    # as well, so that level is answered by `quantile(Poisson(λ), ...)` itself.
+    qp = (q - d.p) / (1 - d.p)
+    return quantile(Poisson(d.λ), qp)
+end
+
+#### Fitting
+
+struct ZeroInflatedPoissonStats <: SufficientStats
+    sx::Float64 # (weighted) sum of x
+    p0::Float64 # (weighted) proportion of zeros
+    tw::Float64 # total sample weight
+end
+
+function suffstats(::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}) where {T<:Integer}
+    return ZeroInflatedPoissonStats(sum(x), mean(iszero, x), length(x))
+end
+
+# weighted
+function suffstats(
+    ::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}, w::AbstractArray{Float64}
+) where {T<:Integer}
+    length(x) == length(w) || throw(DimensionMismatch("Inconsistent array lengths."))
+    sx = 0.0
+    tw = 0.0
+    w0 = 0.0 # total weight carried by the zero observations
+    for (xi, wi) in zip(x, w)
+        sx += xi * wi
+        tw += wi
+        w0 += iszero(xi) * wi
+    end
+    # Normalise by the total weight so that `p0` is a proportion, as in the unweighted method.
+    return ZeroInflatedPoissonStats(sx, w0 / tw, tw)
+end
+
+function fit_mle(::Type{<:ZeroInflatedPoisson}, ss::ZeroInflatedPoissonStats)
+    m = ss.sx / ss.tw     # mean of all observations
+    s = m / (1 - ss.p0)   # mean of the nonzero observations
+
+    # Interior solution of λ / (1 - exp(-λ)) = s, written with the Lambert W function.
+    # For s <= 1 (or an all-zero sample, where s is NaN) the root collapses to zero.
+    λhat = s > 1 ? s + lambertw(-s * exp(-s)) : zero(s)
+
+    # With fewer zeros than a plain Poisson predicts, 1 - m / λhat would be negative;
+    # the likelihood is then maximised on the boundary p = 0 with λ equal to the mean.
+    λhat > m || return ZeroInflatedPoisson(m, zero(m))
+    return ZeroInflatedPoisson(λhat, 1 - m / λhat)
+end
+
+function fit_mle(::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}) where T<:Real
+    pstat = suffstats(ZeroInflatedPoisson, x)
+    return fit_mle(ZeroInflatedPoisson, pstat)
+end
+
+function fit_mle(::Type{<:ZeroInflatedPoisson}, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Real
+    pstat = suffstats(ZeroInflatedPoisson, x, w)
+    return fit_mle(ZeroInflatedPoisson, pstat)
+end
diff --git a/src/univariates.jl b/src/univariates.jl
index f0b9920877..4022218b2a 100644
--- a/src/univariates.jl
+++ b/src/univariates.jl
@@ -681,7 +681,8 @@ const discrete_distributions = [
     "poisson",
     "skellam",
     "soliton",
-    "poissonbinomial"
+    "poissonbinomial",
+    "zeroinflatedpoisson"
 ]
 
 const continuous_distributions = [
diff --git a/test/ref/discrete/zeroinflatedpoisson.R b/test/ref/discrete/zeroinflatedpoisson.R
new file mode 100644
index 0000000000..02fffda6d7
--- /dev/null
+++ b/test/ref/discrete/zeroinflatedpoisson.R
@@ -0,0 +1,32 @@
+
+ZeroInflatedPoisson <- R6Class("ZeroInflatedPoisson",
+    inherit = DiscreteDistribution,
+    public = list(
+        names = c("lambda", "p"),
+        lambda = NA,
+        p = NA,
+        initialize = function(lambda = 1, p = 0) {
+            self$lambda <- lambda
+            self$p <- p
+        },
+        supp = function() { c(0, Inf) },
+        properties = function() {
+            lam <- self$lambda
+            p <- self$p
+            list(rate = lam,
+                 excessprob = p,
+                 mean = (1 - p) * lam,
+                 var = lam * (1 - p) * (1 + p * lam)
+            )
+        },
+        pdf = function(x, log=FALSE) {
+            VGAM::dzipois(x, self$lambda, pstr0 = self$p, log = log)
+        },
+        cdf = function(x) {
+            VGAM::pzipois(x, self$lambda, pstr0 = self$p)
+        },
+        quan = function(v) {
+            VGAM::qzipois(v, self$lambda, pstr0 = self$p)
+        }
+    )
+)
diff --git a/test/ref/discrete_test.lst b/test/ref/discrete_test.lst
index f130624af1..d9a1f10d47 100644
--- a/test/ref/discrete_test.lst
+++ b/test/ref/discrete_test.lst
@@ -68,3 +68,14 @@ WalleniusNoncentralHypergeometric(8, 6, 10, 0.1)
 WalleniusNoncentralHypergeometric(40, 30, 50, 1)
 WalleniusNoncentralHypergeometric(40, 30, 50, 0.5)
 WalleniusNoncentralHypergeometric(40, 30, 50, 2)
+
+ZeroInflatedPoisson()
+ZeroInflatedPoisson(1.0)
+ZeroInflatedPoisson(0.5, 0.0)
+ZeroInflatedPoisson(0.5, 1.0)
+ZeroInflatedPoisson(2.0, 0.0)
+ZeroInflatedPoisson(2.0, 1.0)
+ZeroInflatedPoisson(10.0, 0.0)
+ZeroInflatedPoisson(10.0, 1.0)
+ZeroInflatedPoisson(80.0, 0.0)
+ZeroInflatedPoisson(80.0, 1.0)
diff --git a/test/ref/rdistributions.R b/test/ref/rdistributions.R
index 699f1f7182..aa42cf6f00 100644
--- a/test/ref/rdistributions.R
+++ b/test/ref/rdistributions.R
@@ -33,6 +33,7 @@ source("discrete/negativebinomial.R")
 source("discrete/noncentralhypergeometric.R")
 source("discrete/poisson.R")
 source("discrete/skellam.R")
+source("discrete/zeroinflatedpoisson.R")
 
 #################################################
 #
diff --git a/test/ref/readme.md b/test/ref/readme.md
index 4c7ed5b375..89badbcb07 100644
--- a/test/ref/readme.md
+++ b/test/ref/readme.md
@@ -15,7 +15,7 @@ in addition to the R language itself:
 | stringr | For string parsing |
 | R6 | OOP for implementing distributions |
 | extraDistr | A number of distributions |
-| VGAM | For ``Frechet`` and ``Levy`` |
+| VGAM | For ``Frechet``, ``Levy`` and ``ZeroInflatedPoisson`` |
 | distr | For ``Arcsine`` |
 | chi | For ``Chi`` |
 | circular | For ``VonMises`` |