10 changes: 10 additions & 0 deletions src/AdvancedVI.jl
@@ -141,6 +141,16 @@ export
include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")


# Variational Families
export
VILocationScale,
MeanFieldGaussian,
FullRankGaussian

include("families/location_scale.jl")


# Optimization Routine

function optimize end
160 changes: 160 additions & 0 deletions src/families/location_scale.jl
@@ -0,0 +1,160 @@

"""
VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution
The location-scale variational family is a broad class of distributions
parameterized by the variational parameters `location` and `scale`.
It represents any distribution whose sampling path can be written as:
```julia
d = length(location)
u = rand(dist, d)
z = scale*u + location
```
"""
struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
location::L
scale ::S
dist ::D
end

Functors.@functor VILocationScale (location, scale)
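
As a quick usage sketch, a draw from a member of this family follows exactly the path in the docstring above. This assumes the package is loaded and uses the `MeanFieldGaussian` constructor defined further down in this file; the dimension and scale values are arbitrary:

```julia
using AdvancedVI, Distributions, LinearAlgebra

# Mean-field Gaussian in 3 dimensions: z = Diagonal(s)*u + m with u ~ N(0, 1)
q = MeanFieldGaussian(zeros(3), Diagonal([0.5, 1.0, 2.0]))
z = rand(q)   # one draw via scale*u + location
@assert length(z) == 3
```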

# Specialization of `Optimisers.destructure` for mean-field location-scale families.
# It is necessary because we only want to extract the diagonal elements of
# `scale <: Diagonal`, which is not the default behavior; otherwise, forward-mode AD
# becomes very inefficient.
# begin
struct RestructureMeanField{L, S<:Diagonal, D}
q::VILocationScale{L, S, D}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
VILocationScale(location, scale, re.q.dist)
end

function Optimisers.destructure(
q::VILocationScale{L, <:Diagonal, D}
) where {L, D}
@unpack location, scale, dist = q
flat = vcat(location, diag(scale))
flat, RestructureMeanField(q)
end
# end
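
A sketch of the round trip this specialization enables: only the `2n` mean-field parameters are exposed to the optimizer, instead of the `n + n^2` a dense scale would produce (the values here are arbitrary):

```julia
using AdvancedVI, LinearAlgebra
import Optimisers

q = MeanFieldGaussian(zeros(2), Diagonal([1.0, 2.0]))
flat, re = Optimisers.destructure(q)   # flat == [0.0, 0.0, 1.0, 2.0]
q2 = re(flat)                          # rebuilds an equivalent MeanFieldGaussian
@assert q2.scale == Diagonal([1.0, 2.0])
```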

Base.length(q::VILocationScale) = length(q.location)

Base.size(q::VILocationScale) = size(q.location)

Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D)

function StatsBase.entropy(q::VILocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
end
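
This is the location-scale identity H(q) = n·H(base) + log|det(scale)|. A sanity check against the closed-form Gaussian entropy from `Distributions.jl`, with arbitrary scale values:

```julia
using AdvancedVI, Distributions, LinearAlgebra

s = [0.5, 2.0]
q = MeanFieldGaussian(zeros(2), Diagonal(s))
# The covariance of q is Diagonal(s)^2, so compare against the matching MvNormal
@assert entropy(q) ≈ entropy(MvNormal(zeros(2), Diagonal(s .^ 2)))
```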

# `logpdf` is overloaded directly in addition to `_logpdf` below: generic
# checked call paths in `Distributions.jl` go through `_logpdf`, while this
# method serves direct calls without the argument checks.
function Distributions.logpdf(q::VILocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
end

function Distributions._logpdf(q::VILocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
end
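
Both methods implement the change-of-variables formula log q(z) = log p_base(scale \ (z - location)) - log|det(scale)|. For a Gaussian base this must agree with `MvNormal`; a quick check with arbitrary values:

```julia
using AdvancedVI, Distributions, LinearAlgebra

L = LowerTriangular([1.0 0.0; 0.3 0.8])
q = FullRankGaussian([1.0, -1.0], L)
z = [0.2, 0.5]
@assert logpdf(q, z) ≈ logpdf(MvNormal([1.0, -1.0], L*L'), z)
```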

function Distributions.rand(q::VILocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(dist, n_dims) + location
end

function Distributions.rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int)
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(rng, dist, n_dims, num_samples) .+ location
end

# This specialization improves AD performance of the sampling path
# for mean-field (`Diagonal` scale) families
function Distributions.rand(
rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int
) where {L, D}
@unpack location, scale, dist = q
n_dims = length(location)
scale_diag = diag(scale)
scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
end
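
A shape check for this fast path (sizes arbitrary); since the scale is `Diagonal`, the call dispatches to the specialization above rather than the dense multiply:

```julia
using AdvancedVI, Random, LinearAlgebra

rng = Random.default_rng()
q = MeanFieldGaussian(zeros(3), Diagonal(ones(3)))
xs = rand(rng, q, 5)   # 3×5 matrix, one sample per column
@assert size(xs) == (3, 5)
```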

function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVecOrMat{<:Real})
@unpack location, scale, dist = q
rand!(rng, dist, x)
x[:] = scale*x
return x .+= location
end
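
A sketch of the in-place path, which callers reach through `rand!` after `Distributions.jl`'s dimension checks:

```julia
using AdvancedVI, Random, Distributions, LinearAlgebra

q = MeanFieldGaussian(zeros(3), Diagonal(ones(3)))
z = Vector{Float64}(undef, 3)
rand!(Random.default_rng(), q, z)   # draws one sample into z without allocating
```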

Distributions.mean(q::VILocationScale) = q.location

function Distributions.var(q::VILocationScale)
C = q.scale
Diagonal(C*C')
end

function Distributions.cov(q::VILocationScale)
C = q.scale
Hermitian(C*C')
end
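
Since the scale acts as a Cholesky-like factor, the covariance is `scale*scale'`. A small check with arbitrary values; note that `var` here returns a `Diagonal` matrix rather than the vector `Distributions.jl` usually returns:

```julia
using AdvancedVI, Distributions, LinearAlgebra

C = LowerTriangular([2.0 0.0; 1.0 1.0])
q = FullRankGaussian(zeros(2), C)
@assert cov(q) == Hermitian(C*C')        # [4.0 2.0; 2.0 2.0]
@assert var(q) == Diagonal([4.0, 2.0])   # diagonal of the covariance
```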

"""
FullRankGaussian(location, scale; check_args = true)
Construct a Gaussian variational approximation with a dense covariance matrix.
# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.
# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function FullRankGaussian(
μ::AbstractVector{T},
L::LinearAlgebra.AbstractTriangular{T};
check_args::Bool = true
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
q_base = Normal{T}(zero(T), one(T))
VILocationScale(μ, L, q_base)
end
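
A usage sketch: identity initialization constructs silently, while a badly conditioned scale (here an arbitrary `1e-10` diagonal, below `sqrt(eps(Float64)) ≈ 1.5e-8`) passes the assertion but triggers the warning:

```julia
using AdvancedVI, LinearAlgebra

q = FullRankGaussian(zeros(2), LowerTriangular(Matrix{Float64}(I, 2, 2)))

# Emits the conditioning warning but still constructs:
q_bad = FullRankGaussian(zeros(2), LowerTriangular(1e-10 * Matrix{Float64}(I, 2, 2)))
```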

"""
MeanFieldGaussian(location, scale; check_args = true)
Construct a Gaussian variational approximation with a diagonal covariance matrix.
# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T},
L::Diagonal{T};
check_args::Bool = true
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
end
q_base = Normal{T}(zero(T), one(T))
VILocationScale(μ, L, q_base)
end
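
A usage sketch showing that the element type follows the inputs, which matters for the `Float32` test configurations below:

```julia
using AdvancedVI, LinearAlgebra

q32 = MeanFieldGaussian(zeros(Float32, 4), Diagonal(ones(Float32, 4)))
@assert eltype(rand(q32)) == Float32
```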
82 changes: 82 additions & 0 deletions test/inference/repgradelbo_locationscale.jl
@@ -0,0 +1,82 @@

const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress"

using Test

@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype ∈ [Float64, Float32],
(modelname, modelconstr) ∈ Dict(
:NormalMeanField => normal_meanfield,
:NormalFullRank => normal_fullrank,
),
(objname, objective) ∈ Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) ∈ Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)

seed = 0x38bef07cf9cc549d
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

q0 = if is_meanfield
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
else
L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
FullRankGaussian(zeros(realtype, n_dims), L0)
end

@testset "convergence" begin
Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = q.location
L = q.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ₀/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end

@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = q.location
L = q.scale

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = q.location
L_repl = q.scale
@test μ == μ_repl
@test L == L_repl
end
end
end

@@ -3,7 +3,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using Test

@testset "inference RepGradELBO DistributionsAD Bijectors" begin
@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype ∈ [Float64, Float32],
(modelname, modelconstr) ∈ Dict(
@@ -15,7 +15,7 @@
),
(adbackname, adbackend) ∈ Dict(
:ForwardDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)
@@ -30,23 +30,28 @@

b = Bijectors.bijector(model)
b⁻¹ = inverse(b)
μ₀ = Zeros(realtype, n_dims)
L₀ = Diagonal(Ones(realtype, n_dims))
μ0 = Zeros(realtype, n_dims)
L0 = Diagonal(Ones(realtype, n_dims))

q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
q0_η = if is_meanfield
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
else
L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
FullRankGaussian(zeros(realtype, n_dims), L0)
end
q0_z = Bijectors.transformed(q0_η, b⁻¹)

@testset "convergence" begin
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = mean(q.dist)
L = sqrt(cov(q.dist))
μ = q.dist.location
L = q.dist.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ₀/T^(1/4)
@@ -57,23 +62,23 @@
@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = mean(q.dist)
L = sqrt(cov(q.dist))
μ = q.dist.location
L = q.dist.scale

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q₀_z, T;
rng_repl, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = mean(q.dist)
L_repl = sqrt(cov(q.dist))
μ_repl = q.dist.location
L_repl = q.dist.scale
@test μ == μ_repl
@test L == L_repl
end