@@ -59,7 +59,7 @@ struct LogitNormal{T<:Real} <: ContinuousUnivariateDistribution
5959end
6060
6161function LogitNormal (μ:: T , σ:: T ; check_args:: Bool = true ) where {T <: Real }
62- @check_args LogitNormal (σ, σ > zero (σ))
62+ @check_args LogitNormal (σ, σ >= zero (σ))
6363 return LogitNormal {T} (μ, σ)
6464end
6565
@@ -111,44 +111,46 @@ end
111111
112112# ### Evaluation
113113
114- # TODO check pd and logpdf
115- function pdf (d:: LogitNormal{T} , x:: Real ) where T<: Real
116- if zero (x) < x < one (x)
117- return normpdf (d. μ, d. σ, logit (x)) / (x * (1 - x))
114+ # We directly use the StatsFuns API instead of going through `Normal(...)`
115+ # to avoid overhead introduced by the parameter checks of `Normal`
116+ # Ref https://github.com/JuliaStats/Distributions.jl/pull/2003
117+
118+ function pdf (d:: LogitNormal , x:: Real )
119+ if x ≤ zero (x) || x ≥ oneunit (x)
120+ logitx = oftype (float (x), - Inf )
121+ z = oneunit (x * (1 - x))
118122 else
119- return T (0 )
123+ logitx = logit (x)
124+ z = x * (1 - x)
120125 end
126+ return StatsFuns. normpdf (d. μ, d. σ, logitx) / z
121127end
122-
123- function logpdf (d:: LogitNormal{T} , x:: Real ) where T<: Real
124- if zero (x) < x < one (x)
125- lx = logit (x)
126- return normlogpdf (d. μ, d. σ, lx) - log (x) - log1p (- x)
128+ function logpdf (d:: LogitNormal , x:: Real )
129+ if x ≤ zero (x) || x ≥ one (x)
130+ logitx = oftype (float (x), - Inf )
131+ z = zero (float (x))
127132 else
128- return - T (Inf )
133+ logitx = logit (x)
134+ z = log (x * (1 - x))
129135 end
136+ return StatsFuns. normlogpdf (d. μ, d. σ, logitx) - z
130137end
131138
132- cdf (d:: LogitNormal{T} , x:: Real ) where {T<: Real } =
133- x ≤ 0 ? zero (T) : x ≥ 1 ? one (T) : normcdf (d. μ, d. σ, logit (x))
134- ccdf (d:: LogitNormal{T} , x:: Real ) where {T<: Real } =
135- x ≤ 0 ? one (T) : x ≥ 1 ? zero (T) : normccdf (d. μ, d. σ, logit (x))
136- logcdf (d:: LogitNormal{T} , x:: Real ) where {T<: Real } =
137- x ≤ 0 ? - T (Inf ) : x ≥ 1 ? zero (T) : normlogcdf (d. μ, d. σ, logit (x))
138- logccdf (d:: LogitNormal{T} , x:: Real ) where {T<: Real } =
139- x ≤ 0 ? zero (T) : x ≥ 1 ? - T (Inf ) : normlogccdf (d. μ, d. σ, logit (x))
139+ cdf (d:: LogitNormal , x:: Real ) = StatsFuns. normcdf (d. μ, d. σ, logit (clamp (x, zero (x), oneunit (x))))
140+ ccdf (d:: LogitNormal , x:: Real ) = StatsFuns. normccdf (d. μ, d. σ, logit (clamp (x, zero (x), oneunit (x))))
141+ logcdf (d:: LogitNormal , x:: Real ) = StatsFuns. normlogcdf (d. μ, d. σ, logit (clamp (x, zero (x), oneunit (x))))
142+ logccdf (d:: LogitNormal , x:: Real ) = StatsFuns. normlogccdf (d. μ, d. σ, logit (clamp (x, zero (x), oneunit (x))))
140143
141- quantile (d:: LogitNormal , q:: Real ) = logistic (norminvcdf (d. μ, d. σ, q))
142- cquantile (d:: LogitNormal , q:: Real ) = logistic (norminvccdf (d. μ, d. σ, q))
143- invlogcdf (d:: LogitNormal , lq:: Real ) = logistic (norminvlogcdf (d. μ, d. σ, lq))
144- invlogccdf (d:: LogitNormal , lq:: Real ) = logistic (norminvlogccdf (d. μ, d. σ, lq))
144+ quantile (d:: LogitNormal , q:: Real ) = logistic (StatsFuns . norminvcdf (d. μ, d. σ, q))
145+ cquantile (d:: LogitNormal , q:: Real ) = logistic (StatsFuns . norminvccdf (d. μ, d. σ, q))
146+ invlogcdf (d:: LogitNormal , lq:: Real ) = logistic (StatsFuns . norminvlogcdf (d. μ, d. σ, lq))
147+ invlogccdf (d:: LogitNormal , lq:: Real ) = logistic (StatsFuns . norminvlogccdf (d. μ, d. σ, lq))
145148
146149function gradlogpdf (d:: LogitNormal , x:: Real )
147- μ, σ = params (d)
148- _insupport = insupport (d, x)
149- _x = _insupport ? x : zero (x)
150- z = (μ - logit (_x) + σ^ 2 * (2 * _x - 1 )) / (σ^ 2 * _x * (1 - _x))
151- return _insupport ? z : oftype (z, NaN )
150+ outofsupport = x ≤ zero (x) || x ≥ oneunit (x)
151+ y = outofsupport ? zero (x) : x
152+ z = ((d. μ - logit (y)) / d. σ^ 2 + 2 * y - 1 ) / (y * (1 - y))
153+ return outofsupport ? zero (z) : z
152154end
153155
154156# mgf(d::LogitNormal)
0 commit comments