Skip to content

Commit b2b09bb

Browse files
committed
add tests for scale projection operators, fix bug for meanfield
1 parent 538bafe commit b2b09bb

File tree

2 files changed

+34
-2
lines changed

2 files changed

+34
-2
lines changed

src/families/location_scale.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,14 @@ function (re::RestructureMeanField)(flat::AbstractVector)
4646
n_dims = div(length(flat), 2)
4747
location = first(flat, n_dims)
4848
scale = Diagonal(last(flat, n_dims))
49-
MvLocationScale(location, scale, re.q.dist)
49+
MvLocationScale(location, scale, re.q.dist, re.q.scale_eps)
5050
end
5151

5252
function Optimisers.destructure(
5353
q::MvLocationScale{<:Diagonal, D, L}
5454
) where {D, L}
5555
@unpack location, scale, dist = q
56-
flat = vcat(location, diag(scale))
56+
flat = vcat(location, diag(scale))
5757
flat, RestructureMeanField(q)
5858
end
5959
# end

test/interface/location_scale.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,35 @@
138138
@test q == re(λ)
139139
end
140140
end
141+
142+
@testset "scale positive definite projection" begin
143+
@testset "$(string(covtype)) $(realtype) $(bijector)" for
144+
covtype = [:meanfield, :fullrank],
145+
realtype = [Float32, Float64],
146+
bijector = [nothing, :identity]
147+
148+
d = 5
149+
μ = zeros(realtype, d)
150+
ϵ = sqrt(realtype(0.5))
151+
q = if covtype == :fullrank
152+
L = LowerTriangular(Matrix{realtype}(I,d,d))
153+
FullRankGaussian(μ, L; scale_eps=ϵ)
154+
elseif covtype == :meanfield
155+
L = Diagonal(ones(realtype, d))
156+
MeanFieldGaussian(μ, L; scale_eps=ϵ)
157+
end
158+
q_trans = if isnothing(bijector)
159+
q
160+
else
161+
Bijectors.TransformedDistribution(q, identity)
162+
end
163+
g = deepcopy(q)
164+
165+
λ, re = Optimisers.destructure(q)
166+
grad, _ = Optimisers.destructure(g)
167+
opt_st = Optimisers.setup(Descent(one(realtype)), λ)
168+
_, λ′ = AdvancedVI.update_variational_params!(typeof(q), opt_st, λ, re, grad)
169+
q′ = re(λ′)
170+
@test all(diag(var(q′)) .≥ ϵ^2)
171+
end
172+
end

0 commit comments

Comments
 (0)