Skip to content

Commit 43c06ab

Browse files
zuhengxutorfjelde
andauthored
add more synthetic targets (#20)
* add Neal's Funnel and Warped Gaussian * fixed bug in warped gaussian * add reference for warped Gauss * add Cross dsitribution * udpate docs for cross * Update example/targets/cross.jl change comment into docs Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/cross.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * update cross docs * minor ed * Update example/targets/neal_funnel.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/neal_funnel.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * doc banana using * fixing docs with latex * baanan docs with latex * add NF quick intro * Revert "add NF quick intro" This reverts commit e399274. * rm unnecesary code for cross * rm example/manifest * Update example/targets/neal_funnel.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/cross.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/cross.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * minor update to cross docs --------- Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent 45101e0 commit 43c06ab

File tree

4 files changed

+206
-8
lines changed

4 files changed

+206
-8
lines changed

example/targets/banana.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,25 @@ using IrrationalConstants
88
Multidimensional banana-shape distribution.
99
1010
# Fields
11-
- 'dim::Int': Dimension of the distribution, must be >= 2
12-
- 'b::T': Banananicity constant, the larger "|b|" the more curved the banana
13-
- 'var::T': Variance of the first dimension, must be > 0
14-
11+
$(FIELDS)
1512
1613
# Explanation
1714
1815
The banana distribution is obtained by applying a transformation ϕ to a multivariate normal
19-
distribution "N(0, diag(var, 1, 1, …, 1))". The transformation ϕ is defined as
20-
"ϕ(x₁, … , xₚ) = (x₁, x₂ - B x₁² + var*B, x₃, … , xₚ)",
16+
distribution ``\\mathcal{N}(0, \\text{diag}(var, 1, 1, …, 1))``. The transformation ϕ is defined as
17+
```math
18+
\phi(x_1, … , x_p) = (x_1, x_2 - B x_1^² + \text{var}*B, x_3, … , x_p)
19+
````
2120
which has a unit Jacobian determinant.
2221
2322
Hence the density "fb" of a p-dimensional banana distribution is given by
24-
"fb(x₁, … , xₚ) = exp[ -½x₁²/var - ½(x₂ + B x₁² - var*B)² - ½(x₃² + x₄² + … + xₚ²)] / Z",
23+
```math
24+
fb(x_1, \dots, x_p) = \exp\left[ -\frac{1}{2}\frac{x_1^2}{\text{var}} -
25+
\frac{1}{2}(x_2 + B x_1^2 - \text{var}*B)^2 - \frac{1}{2}(x_3^2 + x_4^2 + \dots
26+
+ x_p^2) \right] / Z,
27+
```
2528
where "B" is the "banananicity" constant, determining the curvature of a banana, and
26-
"Z = sqrt(var * ()^p))" is the normalization constant.
29+
``Z = \\sqrt{\\text{var} * (2\\pi)^p)}`` is the normalization constant.
2730
2831
2932
# Reference
@@ -33,8 +36,11 @@ Gareth O. Roberts and Jeffrey S. Rosenthal
3336
Journal of computational and graphical statistics, Volume 18, Number 2 (2009): 349-367.
3437
"""
3538
struct Banana{T<:Real} <: ContinuousMultivariateDistribution
39+
"Dimension of the distribution, must be >= 2"
3640
dim::Int # Dimension
41+
"Banananicity constant, the larger |b| the more curved the banana"
3742
b::T # Curvature
43+
"Variance of the first dimension, must be > 0"
3844
var::T # Variance
3945
function Banana{T}(dim::Int, b::T, var::T) where {T<:Real}
4046
dim >= 2 || error("dim must be >= 2")

example/targets/cross.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using Distributions, Random
2+
"""
3+
Cross(μ::Real=2.0, σ::Real=0.15)
4+
5+
2-dimensional Cross distribution
6+
7+
8+
# Explanation
9+
10+
The Cross distribution is a 2-dimension 4-component Gaussian distribution with a "cross"
11+
shape that is symmetric about the y- and x-axises. The mixture is defined as
12+
13+
```math
14+
\begin{aligned}
15+
p(x) =
16+
& 0.25 \mathcal{N}(x | (0, \mu), (\sigma, 1)) + \\
17+
& 0.25 \mathcal{N}(x | (\mu, 0), (1, \sigma)) + \\
18+
& 0.25 \mathcal{N}(x | (0, -\mu), (\sigma, 1)) + \\
19+
& 0.25 \mathcal{N}(x | (-\mu, 0), (1, \sigma)))
20+
\end{aligned}
21+
```
22+
23+
where ``μ`` and ``σ`` are the mean and standard deviation of the Gaussian components,
24+
respectively. See an example of the Cross distribution in Page 18 of [1].
25+
26+
# Reference
27+
[1] Zuheng Xu, Naitong Chen, Trevor Campbell
28+
"MixFlows: principled variational inference via mixed flows."
29+
International Conference on Machine Learning, 2023
30+
"""
31+
Cross() = Cross(2.0, 0.15)
32+
function Cross::T, σ::T) where {T<:Real}
33+
return MixtureModel([
34+
MvNormal([zero(μ), μ], [σ, one(σ)]),
35+
MvNormal([-μ, one(μ)], [one(σ), σ]),
36+
MvNormal([μ, one(μ)], [one(σ), σ]),
37+
MvNormal([zero(μ), -μ], [σ, one(σ)]),
38+
])
39+
end

example/targets/neal_funnel.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
using Distributions, Random
2+
3+
"""
4+
Funnel{T<:Real}
5+
6+
Multidimensional Neal's Funnel distribution
7+
8+
# Fields
9+
$(FIELDS)
10+
11+
# Explanation
12+
13+
The Neal's Funnel distribution is a p-dimensional distribution with a funnel shape,
14+
originally proposed by Radford Neal in [2].
15+
The marginal distribution of ``x_1`` is Gaussian with mean "μ" and standard
16+
deviation "σ". The conditional distribution of ``x_2, \dots, x_p | x_1`` are independent
17+
Gaussian distributions with mean 0 and standard deviation ``\\exp(x_1/2)``.
18+
The generative process is given by
19+
```math
20+
x_1 \sim \mathcal{N}(\mu, \sigma^2), \quad x_2, \ldots, x_p \sim \mathcal{N}(0, \exp(x_1))
21+
```
22+
23+
24+
# Reference
25+
[1] Stan User’s Guide:
26+
https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html#ref-Neal:2003
27+
[2] Radford Neal 2003. “Slice Sampling.” Annals of Statistics 31 (3): 705–67.
28+
"""
29+
struct Funnel{T<:Real} <: ContinuousMultivariateDistribution
30+
"Dimension of the distribution, must be >= 2"
31+
dim::Int
32+
"Mean of the first dimension"
33+
μ::T
34+
"Standard deviation of the first dimension, must be > 0"
35+
σ::T
36+
function Funnel{T}(dim::Int, μ::T, σ::T) where {T<:Real}
37+
dim >= 2 || error("dim must be >= 2")
38+
σ > 0 || error("σ must be > 0")
39+
return new{T}(dim, μ, σ)
40+
end
41+
end
42+
Funnel(dim::Int, μ::T, σ::T) where {T<:Real} = Funnel{T}(dim, μ, σ)
43+
Funnel(dim::Int, σ::T) where {T<:Real} = Funnel{T}(dim, zero(T), σ)
44+
Funnel(dim::Int) = Funnel(dim, 0.0, 9.0)
45+
46+
Base.length(p::Funnel) = p.dim
47+
Base.eltype(p::Funnel{T}) where {T<:Real} = T
48+
49+
function Distributions._rand!(rng::AbstractRNG, p::Funnel, x::AbstractVecOrMat)
50+
T = eltype(x)
51+
d, μ, σ = p.dim, p.μ, p.σ
52+
d == size(x, 1) || error("Dimension mismatch")
53+
x[1, :] .= randn(rng, T, size(x, 2)) .* σ .+ μ
54+
x[2:end, :] .= randn(rng, T, d - 1, size(x, 2)) .* exp.(@view(x[1, :]) ./ 2)'
55+
return x
56+
end
57+
58+
function Distributions._logpdf(p::Funnel, x::AbstractVector)
59+
d, μ, σ = p.dim, p.μ, p.σ
60+
lpdf1 = logpdf(Normal(μ, σ), x[1])
61+
lpdfs = logpdf.(Normal.(zeros(T, d - 1), exp(x[1] / 2)), @view(x[2:end]))
62+
return lpdf1 + sum(lpdfs)
63+
end

example/targets/warped_gaussian.jl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
using Distributions, Random, LinearAlgebra, IrrationalConstants
2+
3+
"""
4+
WarpedGauss{T<:Real}
5+
6+
2-dimensional warped Gaussian distribution
7+
8+
# Fields
9+
$(FIELDS)
10+
11+
# Explanation
12+
13+
The banana distribution is obtained by applying a transformation ϕ to a 2-dimensional normal
14+
distribution ``\\mathcal{N}(0, diag(\\sigma_1, \\sigma_2))``. The transformation ϕ(x) is defined as
15+
```math
16+
ϕ(x_1, x_2) = (r*\cos(\theta + r/2), r*\sin(\theta + r/2)),
17+
```
18+
where ``r = \\sqrt{x\_1^2 + x_2^2}``, ``\\theta = \\atan(x₂, x₁)``,
19+
and "atan(y, x) ∈ [-π, π]" is the angle, in radians, between the positive x axis and the
20+
ray to the point "(x, y)". See page 18. of [1] for reference.
21+
22+
23+
# Reference
24+
[1] Zuheng Xu, Naitong Chen, Trevor Campbell
25+
"MixFlows: principled variational inference via mixed flows."
26+
International Conference on Machine Learning, 2023
27+
"""
28+
struct WarpedGauss{T<:Real} <: ContinuousMultivariateDistribution
29+
"Standard deviation of the first dimension, must be > 0"
30+
σ1::T
31+
"Standard deviation of the second dimension, must be > 0"
32+
σ2::T
33+
function WarpedGauss{T}(σ1, σ2) where {T<:Real}
34+
σ1 > 0 || error("σ₁ must be > 0")
35+
σ2 > 0 || error("σ₂ must be > 0")
36+
return new{T}(σ1, σ2)
37+
end
38+
end
39+
WarpedGauss(σ1::T, σ2::T) where {T<:Real} = WarpedGauss{T}(σ1, σ2)
40+
WarpedGauss() = WarpedGauss(1.0, 0.12)
41+
42+
Base.length(p::WarpedGauss) = 2
43+
Base.eltype(p::WarpedGauss{T}) where {T<:Real} = T
44+
Distributions.sampler(p::WarpedGauss) = p
45+
46+
# Define the transformation function φ and the inverse ϕ⁻¹ for the warped Gaussian distribution
47+
function ϕ!(p::WarpedGauss, z::AbstractVector)
48+
length(z) == 2 || error("Dimension mismatch")
49+
x, y = z
50+
r = norm(z)
51+
θ = atan(y, x) #in [-π , π]
52+
θ -= r / 2
53+
z .= r .* [cos(θ), sin(θ)]
54+
return z
55+
end
56+
57+
function ϕ⁻¹(p::WarpedGauss, z::AbstractVector)
58+
length(z) == 2 || error("Dimension mismatch")
59+
x, y = z
60+
r = norm(z)
61+
θ = atan(y, x) #in [-π , π]
62+
# increase θ depending on r to "smear"
63+
θ += r / 2
64+
65+
# get the x,y coordinates foαtransformed point
66+
xn = r * cos(θ)
67+
yn = r * sin(θ)
68+
# compute jacobian
69+
logJ = log(r)
70+
return [xn, yn], logJ
71+
end
72+
73+
function Distributions._rand!(rng::AbstractRNG, p::WarpedGauss, x::AbstractVecOrMat)
74+
size(x, 1) == 2 || error("Dimension mismatch")
75+
σ₁, σ₂ = p.σ₁, p.σ₂
76+
randn!(rng, x)
77+
x .*= [σ₁, σ₂]
78+
for y in eachcol(x)
79+
ϕ!(p, y)
80+
end
81+
return x
82+
end
83+
84+
function Distributions._logpdf(p::WarpedGauss, x::AbstractVector)
85+
size(x, 1) == 2 || error("Dimension mismatch")
86+
σ₁, σ₂ = p.σ₁, p.σ₂
87+
S = [σ₁, σ₂] .^ 2
88+
z, logJ = ϕ⁻¹(p, x)
89+
return -sum(z .^ 2 ./ S) / 2 - IrrationalConstants.log2π - log(σ₁) - log(σ₂) + logJ
90+
end

0 commit comments

Comments
 (0)