Skip to content

Commit f5a2f9a

Browse files
committed
fix warpgaussian logpdf error/making neal funnel logpdf working with mooncake
1 parent 51bc5b9 commit f5a2f9a

File tree

4 files changed

+12
-20
lines changed

4 files changed

+12
-20
lines changed

example/targets/banana.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
using Distributions, Random
2-
using Plots
3-
using IrrationalConstants
4-
51
"""
62
Banana{T<:Real}
73

example/targets/cross.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using Distributions, Random
21
"""
32
Cross(μ::Real=2.0, σ::Real=0.15)
43

example/targets/neal_funnel.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Distributions, Random
2-
31
"""
42
Funnel{T<:Real}
53
@@ -45,18 +43,19 @@ Funnel(dim::Int) = Funnel(dim, 0.0, 9.0)
4543
Base.length(p::Funnel) = p.dim
4644
Base.eltype(p::Funnel{T}) where {T<:Real} = T
4745

48-
function Distributions._rand!(rng::AbstractRNG, p::Funnel, x::AbstractVecOrMat)
49-
T = eltype(x)
46+
function Distributions._rand!(rng::AbstractRNG, p::Funnel{T}, x::AbstractVecOrMat{T}) where {T<:Real}
5047
d, μ, σ = p.dim, p.μ, p.σ
5148
d == size(x, 1) || error("Dimension mismatch")
5249
x[1, :] .= randn(rng, T, size(x, 2)) .* σ .+ μ
5350
x[2:end, :] .= randn(rng, T, d - 1, size(x, 2)) .* exp.(@view(x[1, :]) ./ 2)'
5451
return x
5552
end
5653

57-
function Distributions._logpdf(p::Funnel, x::AbstractVector)
54+
function Distributions._logpdf(p::Funnel{T}, x::AbstractVector{T}) where {T<:Real}
5855
d, μ, σ = p.dim, p.μ, p.σ
59-
lpdf1 = logpdf(Normal(μ, σ), x[1])
60-
lpdfs = logpdf.(Normal.(zeros(T, d - 1), exp(x[1] / 2)), @view(x[2:end]))
61-
return lpdf1 + sum(lpdfs)
56+
x1 = x[1]
57+
x2 = x[2:end]
58+
lpdf_x1 = logpdf(Normal(μ, σ), x1)
59+
lpdf_x2_given_1 = logpdf(MvNormal(zeros(T, d-1), exp(x1)I), x2)
60+
return lpdf_x1 + lpdf_x2_given_1
6261
end

example/targets/warped_gaussian.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
using Distributions, Random, LinearAlgebra, IrrationalConstants
2-
31
"""
42
WarpedGauss{T<:Real}
53
@@ -39,11 +37,11 @@ WarpedGauss(σ1::T, σ2::T) where {T<:Real} = WarpedGauss{T}(σ1, σ2)
3937
WarpedGauss() = WarpedGauss(1.0, 0.12)
4038

4139
Base.length(p::WarpedGauss) = 2
42-
Base.eltype(p::WarpedGauss{T}) where {T<:Real} = T
40+
Base.eltype(::WarpedGauss{T}) where {T<:Real} = T
4341
Distributions.sampler(p::WarpedGauss) = p
4442

4543
# Define the transformation function φ and the inverse ϕ⁻¹ for the warped Gaussian distribution
46-
function ϕ!(p::WarpedGauss, z::AbstractVector)
44+
function ϕ!(::WarpedGauss, z::AbstractVector)
4745
length(z) == 2 || error("Dimension mismatch")
4846
x, y = z
4947
r = norm(z)
@@ -53,7 +51,7 @@ function ϕ!(p::WarpedGauss, z::AbstractVector)
5351
return z
5452
end
5553

56-
function ϕ⁻¹(p::WarpedGauss, z::AbstractVector)
54+
function ϕ⁻¹(::WarpedGauss, z::AbstractVector)
5755
length(z) == 2 || error("Dimension mismatch")
5856
x, y = z
5957
r = norm(z)
@@ -71,7 +69,7 @@ end
7169

7270
function Distributions._rand!(rng::AbstractRNG, p::WarpedGauss, x::AbstractVecOrMat)
7371
size(x, 1) == 2 || error("Dimension mismatch")
74-
σ₁, σ₂ = p.σ₁, p.σ₂
72+
σ₁, σ₂ = p.σ1, p.σ2
7573
randn!(rng, x)
7674
x .*= [σ₁, σ₂]
7775
for y in eachcol(x)
@@ -82,7 +80,7 @@ end
8280

8381
function Distributions._logpdf(p::WarpedGauss, x::AbstractVector)
8482
size(x, 1) == 2 || error("Dimension mismatch")
85-
σ₁, σ₂ = p.σ₁, p.σ₂
83+
σ₁, σ₂ = p.σ1, p.σ2
8684
S = [σ₁, σ₂] .^ 2
8785
z, logJ = ϕ⁻¹(p, x)
8886
return -sum(z .^ 2 ./ S) / 2 - IrrationalConstants.log2π - log(σ₁) - log(σ₂) + logJ

0 commit comments

Comments
 (0)