Skip to content

Commit 5ced9c2

Browse files
authored
add indirection for update step, add projection for LocationScale (#65)
* add indirection for update step, add projection for `LocationScale` * add projection for `Bijectors` with `MvLocationScale`
1 parent 95a83c3 commit 5ced9c2

File tree

5 files changed

+127
-22
lines changed

5 files changed

+127
-22
lines changed

ext/AdvancedVIBijectorsExt.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,37 @@ module AdvancedVIBijectorsExt
44
if isdefined(Base, :get_extension)
55
using AdvancedVI
66
using Bijectors
7+
using LinearAlgebra
8+
using Optimisers
79
using Random
810
else
911
using ..AdvancedVI
1012
using ..Bijectors
13+
using ..LinearAlgebra
14+
using ..Optimisers
1115
using ..Random
1216
end
1317

18+
function AdvancedVI.update_variational_params!(
19+
::Type{<:Bijectors.TransformedDistribution{<:AdvancedVI.MvLocationScale}},
20+
opt_st,
21+
params,
22+
restructure,
23+
grad
24+
)
25+
opt_st, params = Optimisers.update!(opt_st, params, grad)
26+
q = restructure(params)
27+
ϵ = q.dist.scale_eps
28+
29+
# Project the scale matrix to the set of positive definite triangular matrices
30+
diag_idx = diagind(q.dist.scale)
31+
@. q.dist.scale[diag_idx] = max(q.dist.scale[diag_idx], ϵ)
32+
33+
params, _ = Optimisers.destructure(q)
34+
35+
opt_st, params
36+
end
37+
1438
function AdvancedVI.reparam_with_entropy(
1539
rng ::Random.AbstractRNG,
1640
q ::Bijectors.TransformedDistribution,

src/AdvancedVI.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,33 @@ Evaluate the value and gradient of a function `f` at `θ` using the automatic di
3737
"""
3838
function value_and_gradient! end
3939

40+
# Update for gradient descent step
41+
"""
42+
update_variational_params!(family_type, opt_st, params, restructure, grad)
43+
44+
Update variational distribution according to the update rule in the optimizer state `opt_st` and the variational family `family_type`.
45+
46+
This is a wrapper around `Optimisers.update!` to provide some indirection.
47+
For example, depending on the optimizer and the variational family, this may do additional things such as applying projection or proximal mappings.
48+
Same as the default behavior of `Optimisers.update!`, `params` and `opt_st` may be updated by the routine and are no longer valid after calling this functino.
49+
Instead, the return values should be used.
50+
51+
# Arguments
52+
- `family_type::Type`: Type of the variational family `typeof(restructure(params))`.
53+
- `opt_st`: Optimizer state returned by `Optimisers.setup`.
54+
- `params`: Current set of parameters to be updated.
55+
- `restructure`: Callable for restructuring the varitional distribution from `params`.
56+
- `grad`: Gradient to be used by the update rule of `opt_st`.
57+
58+
# Returns
59+
- `opt_st`: Updated optimizer state.
60+
- `params`: Updated parameters.
61+
"""
62+
function update_variational_params! end
63+
64+
update_variational_params!(::Type, opt_st, params, restructure, grad) =
65+
Optimisers.update!(opt_st, params, grad)
66+
4067
# estimators
4168
"""
4269
AbstractVariationalObjective

src/families/location_scale.jl

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,21 @@ represented as follows:
1414
```
1515
"""
1616
struct MvLocationScale{
17-
S, D <: ContinuousDistribution, L
17+
S, D <: ContinuousDistribution, L, E <: Real
1818
} <: ContinuousMultivariateDistribution
19-
location::L
20-
scale ::S
21-
dist ::D
19+
location ::L
20+
scale ::S
21+
dist ::D
22+
scale_eps::E
23+
end
24+
25+
function MvLocationScale(
26+
location ::AbstractVector{T},
27+
scale ::AbstractMatrix{T},
28+
dist ::ContinuousDistribution;
29+
scale_eps::T = sqrt(eps(T))
30+
) where {T <: Real}
31+
MvLocationScale(location, scale, dist, scale_eps)
2232
end
2333

2434
Functors.@functor MvLocationScale (location, scale)
@@ -36,14 +46,14 @@ function (re::RestructureMeanField)(flat::AbstractVector)
3646
n_dims = div(length(flat), 2)
3747
location = first(flat, n_dims)
3848
scale = Diagonal(last(flat, n_dims))
39-
MvLocationScale(location, scale, re.q.dist)
49+
MvLocationScale(location, scale, re.q.dist, re.q.scale_eps)
4050
end
4151

4252
function Optimisers.destructure(
4353
q::MvLocationScale{<:Diagonal, D, L}
4454
) where {D, L}
4555
@unpack location, scale, dist = q
46-
flat = vcat(location, diag(scale))
56+
flat = vcat(location, diag(scale))
4757
flat, RestructureMeanField(q)
4858
end
4959
# end
@@ -57,17 +67,17 @@ Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
5767
function StatsBase.entropy(q::MvLocationScale)
5868
@unpack location, scale, dist = q
5969
n_dims = length(location)
60-
n_dims*convert(eltype(location), entropy(dist)) + first(logdet(scale))
70+
n_dims*convert(eltype(location), entropy(dist)) + logdet(scale)
6171
end
6272

6373
function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
6474
@unpack location, scale, dist = q
65-
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
75+
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
6676
end
6777

6878
function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
6979
@unpack location, scale, dist = q
70-
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logdet(scale))
80+
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - logdet(scale)
7181
end
7282

7383
function Distributions.rand(q::MvLocationScale)
@@ -128,14 +138,11 @@ Construct a Gaussian variational approximation with a dense covariance matrix.
128138
function FullRankGaussian(
129139
μ::AbstractVector{T},
130140
L::LinearAlgebra.AbstractTriangular{T};
131-
check_args::Bool = true
141+
scale_eps::T = sqrt(eps(T))
132142
) where {T <: Real}
133-
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
134-
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
135-
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
136-
end
143+
@assert minimum(diag(L)) sqrt(scale_eps) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
137144
q_base = Normal{T}(zero(T), one(T))
138-
MvLocationScale(μ, L, q_base)
145+
MvLocationScale(μ, L, q_base, scale_eps)
139146
end
140147

141148
"""
@@ -153,12 +160,25 @@ Construct a Gaussian variational approximation with a diagonal covariance matrix
153160
function MeanFieldGaussian(
154161
μ::AbstractVector{T},
155162
L::Diagonal{T};
156-
check_args::Bool = true
163+
scale_eps::T = sqrt(eps(T)),
157164
) where {T <: Real}
158-
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor"
159-
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
160-
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
161-
end
165+
@assert minimum(diag(L)) sqrt(eps(eltype(L))) "Initial scale is too small (smallest diagonal value is $(minimum(diag(L)))). This might result in unstable optimization behavior."
162166
q_base = Normal{T}(zero(T), one(T))
163-
MvLocationScale(μ, L, q_base)
167+
MvLocationScale(μ, L, q_base, scale_eps)
168+
end
169+
170+
function update_variational_params!(
171+
::Type{<:MvLocationScale}, opt_st, params, restructure, grad
172+
)
173+
opt_st, params = Optimisers.update!(opt_st, params, grad)
174+
q = restructure(params)
175+
ϵ = q.scale_eps
176+
177+
# Project the scale matrix to the set of positive definite triangular matrices
178+
diag_idx = diagind(q.scale)
179+
@. q.scale[diag_idx] = max(q.scale[diag_idx], ϵ)
180+
181+
params, _ = Optimisers.destructure(q)
182+
183+
opt_st, params
164184
end

src/optimize.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,9 @@ function optimize(
8080
stat = merge(stat, stat′)
8181

8282
grad = DiffResults.gradient(grad_buf)
83-
opt_st, params = Optimisers.update!(opt_st, params, grad)
83+
opt_st, params = update_variational_params!(
84+
typeof(q_init), opt_st, params, restructure, grad
85+
)
8486

8587
if !isnothing(callback)
8688
stat′ = callback(

test/interface/location_scale.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,35 @@
138138
@test q == re(λ)
139139
end
140140
end
141+
142+
@testset "scale positive definite projection" begin
143+
@testset "$(string(covtype)) $(realtype) $(bijector)" for
144+
covtype = [:meanfield, :fullrank],
145+
realtype = [Float32, Float64],
146+
bijector = [nothing, :identity]
147+
148+
d = 5
149+
μ = zeros(realtype, d)
150+
ϵ = sqrt(realtype(0.5))
151+
q = if covtype == :fullrank
152+
L = LowerTriangular(Matrix{realtype}(I,d,d))
153+
FullRankGaussian(μ, L; scale_eps=ϵ)
154+
elseif covtype == :meanfield
155+
L = Diagonal(ones(realtype, d))
156+
MeanFieldGaussian(μ, L; scale_eps=ϵ)
157+
end
158+
q_trans = if isnothing(bijector)
159+
q
160+
else
161+
Bijectors.TransformedDistribution(q, identity)
162+
end
163+
g = deepcopy(q)
164+
165+
λ, re = Optimisers.destructure(q)
166+
grad, _ = Optimisers.destructure(g)
167+
opt_st = Optimisers.setup(Descent(one(realtype)), λ)
168+
_, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad)
169+
q′ = re(λ′)
170+
@test all(diag(var(q′)) .≥ ϵ^2)
171+
end
172+
end

0 commit comments

Comments
 (0)