Skip to content

Commit ec86faf

Browse files
authored
Support zero variance in LogitNormal (#2002)
* Support zero variance in `LogitNormal` * Bump version to 0.25.121 * Use StatsFuns API
1 parent 642bacb commit ec86faf

File tree

3 files changed

+370
-30
lines changed

3 files changed

+370
-30
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Distributions"
22
uuid = "31c24e10-a181-5473-b8eb-7969acd0382f"
33
authors = ["JuliaStats"]
4-
version = "0.25.120"
4+
version = "0.25.121"
55

66
[deps]
77
AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8"

src/univariate/continuous/logitnormal.jl

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct LogitNormal{T<:Real} <: ContinuousUnivariateDistribution
5959
end
6060

6161
function LogitNormal::T, σ::T; check_args::Bool=true) where {T <: Real}
62-
@check_args LogitNormal (σ, σ > zero(σ))
62+
@check_args LogitNormal (σ, σ >= zero(σ))
6363
return LogitNormal{T}(μ, σ)
6464
end
6565

@@ -111,44 +111,46 @@ end
111111

112112
#### Evaluation
113113

114-
#TODO check pd and logpdf
115-
function pdf(d::LogitNormal{T}, x::Real) where T<:Real
116-
if zero(x) < x < one(x)
117-
return normpdf(d.μ, d.σ, logit(x)) / (x * (1-x))
114+
# We directly use the StatsFuns API instead of going through `Normal(...)`
115+
# to avoid overhead introduced by the parameter checks of `Normal`
116+
# Ref https://github.com/JuliaStats/Distributions.jl/pull/2003
117+
118+
function pdf(d::LogitNormal, x::Real)
119+
if x zero(x) || x oneunit(x)
120+
logitx = oftype(float(x), -Inf)
121+
z = oneunit(x * (1 - x))
118122
else
119-
return T(0)
123+
logitx = logit(x)
124+
z = x * (1 - x)
120125
end
126+
return StatsFuns.normpdf(d.μ, d.σ, logitx) / z
121127
end
122-
123-
function logpdf(d::LogitNormal{T}, x::Real) where T<:Real
124-
if zero(x) < x < one(x)
125-
lx = logit(x)
126-
return normlogpdf(d.μ, d.σ, lx) - log(x) - log1p(-x)
128+
function logpdf(d::LogitNormal, x::Real)
129+
if x zero(x) || x one(x)
130+
logitx = oftype(float(x), -Inf)
131+
z = zero(float(x))
127132
else
128-
return -T(Inf)
133+
logitx = logit(x)
134+
z = log(x * (1 - x))
129135
end
136+
return StatsFuns.normlogpdf(d.μ, d.σ, logitx) - z
130137
end
131138

132-
cdf(d::LogitNormal{T}, x::Real) where {T<:Real} =
133-
x 0 ? zero(T) : x 1 ? one(T) : normcdf(d.μ, d.σ, logit(x))
134-
ccdf(d::LogitNormal{T}, x::Real) where {T<:Real} =
135-
x 0 ? one(T) : x 1 ? zero(T) : normccdf(d.μ, d.σ, logit(x))
136-
logcdf(d::LogitNormal{T}, x::Real) where {T<:Real} =
137-
x 0 ? -T(Inf) : x 1 ? zero(T) : normlogcdf(d.μ, d.σ, logit(x))
138-
logccdf(d::LogitNormal{T}, x::Real) where {T<:Real} =
139-
x 0 ? zero(T) : x 1 ? -T(Inf) : normlogccdf(d.μ, d.σ, logit(x))
139+
cdf(d::LogitNormal, x::Real) = StatsFuns.normcdf(d.μ, d.σ, logit(clamp(x, zero(x), oneunit(x))))
140+
ccdf(d::LogitNormal, x::Real) = StatsFuns.normccdf(d.μ, d.σ, logit(clamp(x, zero(x), oneunit(x))))
141+
logcdf(d::LogitNormal, x::Real) = StatsFuns.normlogcdf(d.μ, d.σ, logit(clamp(x, zero(x), oneunit(x))))
142+
logccdf(d::LogitNormal, x::Real) = StatsFuns.normlogccdf(d.μ, d.σ, logit(clamp(x, zero(x), oneunit(x))))
140143

141-
quantile(d::LogitNormal, q::Real) = logistic(norminvcdf(d.μ, d.σ, q))
142-
cquantile(d::LogitNormal, q::Real) = logistic(norminvccdf(d.μ, d.σ, q))
143-
invlogcdf(d::LogitNormal, lq::Real) = logistic(norminvlogcdf(d.μ, d.σ, lq))
144-
invlogccdf(d::LogitNormal, lq::Real) = logistic(norminvlogccdf(d.μ, d.σ, lq))
144+
quantile(d::LogitNormal, q::Real) = logistic(StatsFuns.norminvcdf(d.μ, d.σ, q))
145+
cquantile(d::LogitNormal, q::Real) = logistic(StatsFuns.norminvccdf(d.μ, d.σ, q))
146+
invlogcdf(d::LogitNormal, lq::Real) = logistic(StatsFuns.norminvlogcdf(d.μ, d.σ, lq))
147+
invlogccdf(d::LogitNormal, lq::Real) = logistic(StatsFuns.norminvlogccdf(d.μ, d.σ, lq))
145148

146149
function gradlogpdf(d::LogitNormal, x::Real)
147-
μ, σ = params(d)
148-
_insupport = insupport(d, x)
149-
_x = _insupport ? x : zero(x)
150-
z =- logit(_x) + σ^2 * (2 * _x - 1)) /^2 * _x * (1 - _x))
151-
return _insupport ? z : oftype(z, NaN)
150+
outofsupport = x zero(x) || x oneunit(x)
151+
y = outofsupport ? zero(x) : x
152+
z = ((d.μ - logit(y)) / d.σ^2 + 2 * y - 1) / (y * (1 - y))
153+
return outofsupport ? zero(z) : z
152154
end
153155

154156
# mgf(d::LogitNormal)

0 commit comments

Comments
 (0)