Skip to content

Commit 6dca57e

Browse files
authored
Improve documentation and implementation of custom distributions (#1431)
1 parent c05472c commit 6dca57e

File tree

3 files changed

+189
-77
lines changed

3 files changed

+189
-77
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.14.7"
3+
version = "0.14.8"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/Turing.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using Tracker: Tracker
1717

1818
import AdvancedVI
1919
import DynamicPPL: getspace, NoDist, NamedDist
20+
import Random
2021

2122
const PROGRESS = Ref(true)
2223

src/stdlib/distributions.jl

Lines changed: 187 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,139 +1,250 @@
1-
import Random: AbstractRNG
2-
3-
# No info
41
"""
5-
Flat <: ContinuousUnivariateDistribution
2+
Flat()
3+
4+
The *flat distribution* is the improper distribution of real numbers that has the improper
5+
probability density function
66
7-
A distribution with support and density of one
8-
everywhere.
7+
```math
8+
f(x) = 1.
9+
```
910
"""
1011
struct Flat <: ContinuousUnivariateDistribution end
1112

12-
Distributions.rand(rng::AbstractRNG, d::Flat) = rand(rng)
13-
Distributions.logpdf(d::Flat, x::Real) = zero(x)
14-
Distributions.minimum(d::Flat) = -Inf
15-
Distributions.maximum(d::Flat) = +Inf
13+
Base.minimum(::Flat) = -Inf
14+
Base.maximum(::Flat) = Inf
15+
16+
Base.rand(rng::Random.AbstractRNG, d::Flat) = rand(rng)
17+
Distributions.logpdf(::Flat, x::Real) = zero(x)
18+
19+
# TODO: only implement `logpdf(d, ::Real)` if support for Distributions < 0.24 is dropped
20+
Distributions.pdf(d::Flat, x::Real) = exp(logpdf(d, x))
1621

1722
# For vec support
18-
Distributions.logpdf(d::Flat, x::AbstractVector{<:Real}) = zero(x)
23+
Distributions.logpdf(::Flat, x::AbstractVector{<:Real}) = zero(x)
24+
Distributions.loglikelihood(::Flat, x::AbstractVector{<:Real}) = zero(eltype(x))
1925

2026
"""
2127
FlatPos(l::Real)
2228
23-
A distribution with a lower bound of `l` and a density
24-
of one at every `x` above `l`.
29+
The *positive flat distribution* with real-valued parameter `l` is the improper distribution
30+
of real numbers that has the improper probability density function
31+
32+
```math
33+
f(x) = \\begin{cases}
34+
0 & \\text{if } x \\leq l, \\\\
35+
1 & \\text{otherwise}.
36+
\\end{cases}
37+
```
2538
"""
2639
struct FlatPos{T<:Real} <: ContinuousUnivariateDistribution
2740
l::T
2841
end
2942

30-
Distributions.rand(rng::AbstractRNG, d::FlatPos) = rand(rng) + d.l
31-
Distributions.logpdf(d::FlatPos, x::Real) = x <= d.l ? -Inf : zero(x)
32-
Distributions.minimum(d::FlatPos) = d.l
33-
Distributions.maximum(d::FlatPos) = Inf
43+
Base.minimum(d::FlatPos) = d.l
44+
Base.maximum(d::FlatPos) = Inf
45+
46+
Base.rand(rng::Random.AbstractRNG, d::FlatPos) = rand(rng) + d.l
47+
function Distributions.logpdf(d::FlatPos, x::Real)
48+
z = float(zero(x))
49+
return x <= d.l ? oftype(z, -Inf) : z
50+
end
51+
52+
# TODO: only implement `logpdf(d, ::Real)` if support for Distributions < 0.24 is dropped
53+
Distributions.pdf(d::FlatPos, x::Real) = exp(logpdf(d, x))
3454

3555
# For vec support
36-
function Distributions.logpdf(d::FlatPos, x::AbstractVector{<:Real})
37-
return any(x .<= d.l) ? -Inf : zero(x)
56+
function Distributions.loglikelihood(d::FlatPos, x::AbstractVector{<:Real})
57+
lower = d.l
58+
T = float(eltype(x))
59+
return any(xi <= lower for xi in x) ? T(-Inf) : zero(T)
3860
end
3961

4062
"""
41-
BinomialLogit(n<:Real, I<:Integer)
63+
BinomialLogit(n, logitp)
64+
65+
The *Binomial distribution* with logit parameterization characterizes the number of
66+
successes in a sequence of independent trials.
67+
68+
It has two parameters: `n`, the number of trials, and `logitp`, the logit of the probability
69+
of success in an individual trial, with the distribution
70+
71+
```math
72+
P(X = k) = {n \\choose k}{(\\text{logistic}(logitp))}^k (1 - \\text{logistic}(logitp))^{n-k}, \\quad \\text{ for } k = 0,1,2, \\ldots, n.
73+
```
4274
43-
A univariate binomial logit distribution.
75+
See also: [`Binomial`](@ref)
4476
"""
45-
struct BinomialLogit{T<:Real, I<:Integer} <: DiscreteUnivariateDistribution
46-
n::I
77+
struct BinomialLogit{T<:Real,S<:Real} <: DiscreteUnivariateDistribution
78+
n::Int
4779
logitp::T
48-
end
80+
logconstant::S
4981

50-
function logpdf_binomial_logit(n, logitp, k)
51-
logcomb = -StatsFuns.log1p(n) - SpecialFunctions.logbeta(n - k + 1, k + 1)
52-
return logcomb + k * logitp - n * StatsFuns.log1pexp(logitp)
82+
function BinomialLogit{T}(n::Int, logitp::T) where T
83+
n >= 0 || error("parameter `n` has to be non-negative")
84+
logconstant = - (log1p(n) + n * StatsFuns.log1pexp(logitp))
85+
return new{T,typeof(logconstant)}(n, logitp, logconstant)
86+
end
5387
end
5488

55-
function Distributions.logpdf(d::BinomialLogit{<:Real}, k::Int)
56-
return logpdf_binomial_logit(d.n, d.logitp, k)
89+
BinomialLogit(n::Int, logitp::Real) = BinomialLogit{typeof(logitp)}(n, logitp)
90+
91+
Base.minimum(::BinomialLogit) = 0
92+
Base.maximum(d::BinomialLogit) = d.n
93+
94+
# TODO: only implement `logpdf(d, k::Real)` if support for Distributions < 0.24 is dropped
95+
Distributions.pdf(d::BinomialLogit, k::Real) = exp(logpdf(d, k))
96+
Distributions.logpdf(d::BinomialLogit, k::Real) = _logpdf(d, k)
97+
Distributions.logpdf(d::BinomialLogit, k::Integer) = _logpdf(d, k)
98+
99+
function _logpdf(d::BinomialLogit, k::Real)
100+
n, logitp, logconstant = d.n, d.logitp, d.logconstant
101+
_insupport = insupport(d, k)
102+
_k = _insupport ? round(Int, k) : 0
103+
result = logconstant + _k * logitp - SpecialFunctions.logbeta(n - _k + 1, _k + 1)
104+
return _insupport ? result : oftype(result, -Inf)
57105
end
58106

59-
function Distributions.pdf(d::BinomialLogit{<:Real}, k::Int)
60-
return exp(logpdf_binomial_logit(d.n, d.logitp, k))
107+
function Base.rand(rng::Random.AbstractRNG, d::BinomialLogit)
108+
return rand(rng, Binomial(d.n, logistic(d.logitp)))
61109
end
110+
Distributions.sampler(d::BinomialLogit) = sampler(Binomial(d.n, logistic(d.logitp)))
62111

63112
"""
64-
BernoulliLogit(p<:Real)
113+
BernoulliLogit(logitp::Real)
65114
66-
A univariate logit-parameterised Bernoulli distribution.
115+
Create a univariate logit-parameterised Bernoulli distribution.
67116
"""
68-
function BernoulliLogit(logitp::Real)
69-
return BinomialLogit(1, logitp)
70-
end
117+
BernoulliLogit(logitp::Real) = BinomialLogit(1, logitp)
71118

72119
"""
73-
OrderedLogistic(η::Any, cutpoints<:AbstractVector)
74-
75-
An ordered logistic distribution.
120+
OrderedLogistic(η, c::AbstractVector)
121+
122+
The *ordered logistic distribution* with real-valued parameter `η` and cutpoints `c` has the
123+
probability mass function
124+
125+
```math
126+
P(X = k) = \\begin{cases}
127+
1 - \\text{logistic}(\\eta - c_1) & \\text{if } k = 1, \\\\
128+
\\text{logistic}(\\eta - c_{k-1}) - \\text{logistic}(\\eta - c_k) & \\text{if } 1 < k < K, \\\\
129+
\\text{logistic}(\\eta - c_{K-1}) & \\text{if } k = K,
130+
\\end{cases}
131+
```
132+
where `K = length(c) + 1`.
76133
"""
77134
struct OrderedLogistic{T1, T2<:AbstractVector} <: DiscreteUnivariateDistribution
78-
η::T1
79-
cutpoints::T2
135+
η::T1
136+
cutpoints::T2
80137

81-
function OrderedLogistic(η, cutpoints)
82-
if !issorted(cutpoints)
83-
error("cutpoints are not sorted")
84-
end
138+
function OrderedLogistic{T1,T2}::T1, cutpoints::T2) where {T1,T2}
139+
issorted(cutpoints) || error("cutpoints are not sorted")
85140
return new{typeof(η), typeof(cutpoints)}(η, cutpoints)
86-
end
87-
88-
end
89-
90-
function Distributions.logpdf(d::OrderedLogistic, k::Int)
91-
K = length(d.cutpoints)+1
92-
c = d.cutpoints
93-
94-
if k==1
95-
logp= -softplus(-(c[k]-d.η)) #logp= log(logistic(c[k]-d.η))
96-
elseif k<K
97-
logp= log(logistic(c[k]-d.η) - logistic(c[k-1]-d.η))
98-
else
99-
logp= -softplus(c[k-1]-d.η) #logp= log(1-logistic(c[k-1]-d.η))
100141
end
142+
end
101143

102-
return logp
144+
function OrderedLogistic(η, cutpoints::AbstractVector)
145+
return OrderedLogistic{typeof(η),typeof(cutpoints)}(η, cutpoints)
103146
end
104147

105-
Distributions.pdf(d::OrderedLogistic, k::Int) = exp(logpdf(d,k))
148+
Base.minimum(d::OrderedLogistic) = 0
149+
Base.maximum(d::OrderedLogistic) = length(d.cutpoints) + 1
106150

107-
function Distributions.rand(rng::AbstractRNG, d::OrderedLogistic)
108-
cutpoints = d.cutpoints
109-
η = d.η
151+
# TODO: only implement `logpdf(d, k::Real)` if support for Distributions < 0.24 is dropped
152+
Distributions.pdf(d::OrderedLogistic, k::Real) = exp(logpdf(d, k))
153+
Distributions.logpdf(d::OrderedLogistic, k::Real) = _logpdf(d, k)
154+
Distributions.logpdf(d::OrderedLogistic, k::Integer) = _logpdf(d, k)
110155

111-
K = length(cutpoints)+1
112-
c = vcat(-Inf, cutpoints, Inf)
156+
function _logpdf(d::OrderedLogistic, k::Real)
157+
η, cutpoints = d.η, d.cutpoints
158+
K = length(cutpoints) + 1
113159

114-
ps = [logistic- i[1]) - logistic- i[2]) for i in zip(c[1:(end-1)],c[2:end])]
160+
_insupport = insupport(d, k)
161+
_k = _insupport ? round(Int, k) : 1
162+
logp = unsafe_logpdf_ordered_logistic(η, cutpoints, K, _k)
115163

164+
return _insupport ? logp : oftype(logp, -Inf)
165+
end
166+
167+
function Base.rand(rng::Random.AbstractRNG, d::OrderedLogistic)
168+
η, cutpoints = d.η, d.cutpoints
169+
K = length(cutpoints) + 1
170+
# evaluate probability mass function
171+
ps = map(1:K) do k
172+
exp(unsafe_logpdf_ordered_logistic(η, cutpoints, K, k))
173+
end
116174
k = rand(rng, Categorical(ps))
175+
return k
176+
end
177+
function Distributions.sampler(d::OrderedLogistic)
178+
η, cutpoints = d.η, d.cutpoints
179+
K = length(cutpoints) + 1
180+
# evaluate probability mass function
181+
ps = map(1:K) do k
182+
exp(unsafe_logpdf_ordered_logistic(η, cutpoints, K, k))
183+
end
184+
return sampler(Categorical(ps))
185+
end
117186

118-
if all(x -> x > zero(x), ps)
119-
return(k)
120-
else
121-
return(-Inf)
187+
# unsafe version without bounds checking
188+
function unsafe_logpdf_ordered_logistic(η, cutpoints, K, k::Int)
189+
@inbounds begin
190+
logp = if k == 1
191+
-StatsFuns.log1pexp- cutpoints[k])
192+
elseif k < K
193+
tmp = StatsFuns.log1pexp(cutpoints[k-1] - η)
194+
-tmp + StatsFuns.log1mexp(tmp - StatsFuns.log1pexp(cutpoints[k] - η))
195+
else
196+
-StatsFuns.log1pexp(cutpoints[k-1] - η)
197+
end
122198
end
199+
return logp
123200
end
124201

125202
"""
126-
Numerically stable Poisson log likelihood.
127-
* `logλ`: log of rate parameter
203+
LogPoisson(logλ)
204+
205+
The *Poisson distribution* with logarithmic parameterization of the rate parameter
206+
descibes the number of independent events occurring within a unit time interval, given the
207+
average rate of occurrence ``exp(logλ)``.
208+
209+
The distribution has the probability mass function
210+
211+
```math
212+
P(X = k) = \\frac{e^{k \\cdot logλ}{k!} e^{-e^{logλ}}, \\quad \\text{ for } k = 0,1,2,\\ldots.
213+
```
214+
215+
See also: [`Poisson`](@ref)
128216
"""
129-
struct LogPoisson{T<:Real} <: DiscreteUnivariateDistribution
217+
struct LogPoisson{T<:Real,S} <: DiscreteUnivariateDistribution
130218
logλ::T
219+
λ::S
220+
221+
function LogPoisson{T}(logλ::T) where T
222+
λ = exp(logλ)
223+
return new{T,typeof(λ)}(logλ, λ)
224+
end
131225
end
132226

133-
function Distributions.logpdf(lp::LogPoisson, k::Int)
134-
return k * lp.logλ - exp(lp.logλ) - SpecialFunctions.loggamma(k + 1)
227+
LogPoisson(logλ::Real) = LogPoisson{typeof(logλ)}(logλ)
228+
229+
Base.minimum(d::LogPoisson) = 0
230+
Base.maximum(d::LogPoisson) = Inf
231+
232+
function _logpdf(d::LogPoisson, k::Real)
233+
_insupport = insupport(d, k)
234+
_k = _insupport ? round(Int, k) : 0
235+
logp = _k * d.logλ - d.λ - SpecialFunctions.loggamma(_k + 1)
236+
237+
return _insupport ? logp : oftype(logp, -Inf)
135238
end
136239

240+
# TODO: only implement `logpdf(d, ::Real)` if support for Distributions < 0.24 is dropped
241+
Distributions.pdf(d::LogPoisson, k::Real) = exp(logpdf(d, k))
242+
Distributions.logpdf(d::LogPoisson, k::Integer) = _logpdf(d, k)
243+
Distributions.logpdf(d::LogPoisson, k::Real) = _logpdf(d, k)
244+
245+
Base.rand(rng::Random.AbstractRNG, d::LogPoisson) = rand(rng, Poisson(d.λ))
246+
Distributions.sampler(d::LogPoisson) = sampler(Poisson(d.λ))
247+
137248
Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0
138249
Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool) = 0
139250
function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool)

0 commit comments

Comments
 (0)