Skip to content

Commit b927e80

Browse files
committed
fix reduce computational cost of tests, use more sophisticated tests
1 parent 47aba44 commit b927e80

File tree

6 files changed

+72
-57
lines changed

6 files changed

+72
-57
lines changed

test/inference/repgradelbo_distributionsad.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,16 @@
2121
rng = StableRNG(seed)
2222

2323
modelstats = modelconstr(rng, realtype)
24-
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
24+
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
2525

26-
T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
26+
T = 1000
27+
η = 1e-3
28+
opt = Optimisers.Descent(realtype(η))
29+
30+
# For small enough η, the error of SGD, Δλ, is bounded as
31+
# Δλ ≤ ρ^T Δλ0 + O(η),
32+
# where ρ = 1 - ημ, μ is the strong convexity constant.
33+
contraction_rate = 1 - η*strong_convexity
2734

2835
μ0 = Zeros(realtype, n_dims)
2936
L0 = Diagonal(Ones(realtype, n_dims))
@@ -33,7 +40,7 @@
3340
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
3441
q, stats, _ = optimize(
3542
rng, model, objective, q0, T;
36-
optimizer = Optimisers.Adam(realtype(η)),
43+
optimizer = opt,
3744
show_progress = PROGRESS,
3845
adtype = adtype,
3946
)
@@ -42,7 +49,7 @@
4249
L = sqrt(cov(q))
4350
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
4451

45-
@test Δλ ≤ Δλ0/T^(1/4)
52+
@test Δλ ≤ contraction_rate^(T/2)*Δλ0
4653
@test eltype(μ) == eltype(μ_true)
4754
@test eltype(L) == eltype(L_true)
4855
end
@@ -51,7 +58,7 @@
5158
rng = StableRNG(seed)
5259
q, stats, _ = optimize(
5360
rng, model, objective, q0, T;
54-
optimizer = Optimisers.Adam(realtype(η)),
61+
optimizer = opt,
5562
show_progress = PROGRESS,
5663
adtype = adtype,
5764
)
@@ -61,7 +68,7 @@
6168
rng_repl = StableRNG(seed)
6269
q, stats, _ = optimize(
6370
rng_repl, model, objective, q0, T;
64-
optimizer = Optimisers.Adam(realtype(η)),
71+
optimizer = opt,
6572
show_progress = PROGRESS,
6673
adtype = adtype,
6774
)

test/inference/repgradelbo_locationscale.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,16 @@
2222
rng = StableRNG(seed)
2323

2424
modelstats = modelconstr(rng, realtype)
25-
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
25+
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
2626

27-
T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
27+
T = 1000
28+
η = 1e-3
29+
opt = Optimisers.Descent(realtype(η))
30+
31+
# For small enough η, the error of SGD, Δλ, is bounded as
32+
# Δλ ≤ ρ^T Δλ0 + O(η),
33+
# where ρ = 1 - ημ, μ is the strong convexity constant.
34+
contraction_rate = 1 - η*strong_convexity
2835

2936
q0 = if is_meanfield
3037
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
@@ -37,7 +44,7 @@
3744
Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
3845
q, stats, _ = optimize(
3946
rng, model, objective, q0, T;
40-
optimizer = Optimisers.Adam(realtype(η)),
47+
optimizer = opt,
4148
show_progress = PROGRESS,
4249
adtype = adtype,
4350
)
@@ -46,7 +53,7 @@
4653
L = q.scale
4754
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
4855

49-
@test Δλ ≤ Δλ0/T^(1/4)
56+
@test Δλ ≤ contraction_rate^(T/2)*Δλ0
5057
@test eltype(μ) == eltype(μ_true)
5158
@test eltype(L) == eltype(L_true)
5259
end
@@ -55,7 +62,7 @@
5562
rng = StableRNG(seed)
5663
q, stats, _ = optimize(
5764
rng, model, objective, q0, T;
58-
optimizer = Optimisers.Adam(realtype(η)),
65+
optimizer = opt,
5966
show_progress = PROGRESS,
6067
adtype = adtype,
6168
)
@@ -65,7 +72,7 @@
6572
rng_repl = StableRNG(seed)
6673
q, stats, _ = optimize(
6774
rng_repl, model, objective, q0, T;
68-
optimizer = Optimisers.Adam(realtype(η)),
75+
optimizer = opt,
6976
show_progress = PROGRESS,
7077
adtype = adtype,
7178
)

test/inference/repgradelbo_locationscale_bijectors.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
rng = StableRNG(seed)
2222

2323
modelstats = modelconstr(rng, realtype)
24-
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats
24+
@unpack model, μ_true, L_true, n_dims, strong_convexity, is_meanfield = modelstats
2525

26-
T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
26+
T = 1000
27+
η = 1e-3
28+
opt = Optimisers.Descent(realtype(η))
2729

2830
b = Bijectors.bijector(model)
2931
b⁻¹ = inverse(b)
@@ -38,11 +40,16 @@
3840
end
3941
q0_z = Bijectors.transformed(q0_η, b⁻¹)
4042

43+
# For small enough η, the error of SGD, Δλ, is bounded as
44+
# Δλ ≤ ρ^T Δλ0 + O(η),
45+
# where ρ = 1 - ημ, μ is the strong convexity constant.
46+
contraction_rate = 1 - η*strong_convexity
47+
4148
@testset "convergence" begin
4249
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
4350
q, stats, _ = optimize(
4451
rng, model, objective, q0_z, T;
45-
optimizer = Optimisers.Adam(realtype(η)),
52+
optimizer = opt,
4653
show_progress = PROGRESS,
4754
adtype = adtype,
4855
)
@@ -51,7 +58,7 @@
5158
L = q.dist.scale
5259
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
5360

54-
@test Δλ ≤ Δλ0/T^(1/4)
61+
@test Δλ ≤ contraction_rate^(T/2)*Δλ0
5562
@test eltype(μ) == eltype(μ_true)
5663
@test eltype(L) == eltype(L_true)
5764
end
@@ -60,7 +67,7 @@
6067
rng = StableRNG(seed)
6168
q, stats, _ = optimize(
6269
rng, model, objective, q0_z, T;
63-
optimizer = Optimisers.Adam(realtype(η)),
70+
optimizer = opt,
6471
show_progress = PROGRESS,
6572
adtype = adtype,
6673
)
@@ -70,7 +77,7 @@
7077
rng_repl = StableRNG(seed)
7178
q, stats, _ = optimize(
7279
rng_repl, model, objective, q0_z, T;
73-
optimizer = Optimisers.Adam(realtype(η)),
80+
optimizer = opt,
7481
show_progress = PROGRESS,
7582
adtype = adtype,
7683
)

test/models/normal.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,24 +20,28 @@ end
2020
function normal_fullrank(rng::Random.AbstractRNG, realtype::Type)
2121
n_dims = 5
2222

23-
μ = randn(rng, realtype, n_dims)
24-
L = tril(I + ones(realtype, n_dims, n_dims))/2
25-
Σ = L*L' |> Hermitian
23+
σ0 = realtype(0.3)
24+
μ = Fill(realtype(5), n_dims)
25+
L = Matrix(σ0*I, n_dims, n_dims)
26+
Σ = L*L' |> Hermitian
2627

2728
model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0)))
2829

29-
TestModel(model, μ, L, n_dims, false)
30+
TestModel(model, μ, LowerTriangular(L), n_dims, 1/σ0^2, false)
3031
end
3132

3233
function normal_meanfield(rng::Random.AbstractRNG, realtype::Type)
3334
n_dims = 5
3435

35-
μ = randn(rng, realtype, n_dims)
36-
σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
36+
σ0 = realtype(0.3)
37+
μ = Fill(realtype(5), n_dims)
38+
#randn(rng, realtype, n_dims)
39+
σ = Fill(σ0, n_dims)
40+
#log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
3741

3842
model = TestNormal(μ, Diagonal(σ.^2))
3943

4044
L = σ |> Diagonal
4145

42-
TestModel(model, μ, L, n_dims, true)
46+
TestModel(model, μ, L, n_dims, 1/σ0^2, true)
4347
end

test/models/normallognormal.jl

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -26,40 +26,29 @@ function Bijectors.bijector(model::NormalLogNormal)
2626
[1:1, 2:1+length(μ_y)])
2727
end
2828

29-
function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)
30-
n_dims = 5
31-
32-
μ_x = randn(rng, realtype)
33-
σ_x = ℯ
34-
μ_y = randn(rng, realtype, n_dims)
35-
L_y = tril(I + ones(realtype, n_dims, n_dims))/2
36-
Σ_y = L_y*L_y' |> Hermitian
37-
38-
model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0)))
39-
40-
Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1)
41-
Σ[1,1] = σ_x^2
42-
Σ[2:end,2:end] = Σ_y
43-
Σ = Σ |> Hermitian
44-
45-
μ = vcat(μ_x, μ_y)
46-
L = cholesky(Σ).L
47-
48-
TestModel(model, μ, L, n_dims+1, false)
29+
function normallognormal_fullrank(::Random.AbstractRNG, realtype::Type)
30+
n_y_dims = 5
31+
32+
σ0 = realtype(0.3)
33+
μ = Fill(realtype(5.0), n_y_dims+1)
34+
L = Matrix(σ0*I, n_y_dims+1, n_y_dims+1)
35+
Σ = L*L' |> Hermitian
36+
37+
model = NormalLogNormal(
38+
μ[1], L[1,1], μ[2:end], PDMat(Σ[2:end,2:end], Cholesky(L[2:end,2:end], 'L', 0))
39+
)
40+
TestModel(model, μ, LowerTriangular(L), n_y_dims+1, 1/σ0^2, false)
4941
end
5042

51-
function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)
52-
n_dims = 5
53-
54-
μ_x = randn(rng, realtype)
55-
σ_x = ℯ
56-
μ_y = randn(rng, realtype, n_dims)
57-
σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1)
43+
function normallognormal_meanfield(::Random.AbstractRNG, realtype::Type)
44+
n_y_dims = 5
5845

59-
model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2))
46+
σ0 = realtype(0.3)
47+
μ = Fill(realtype(5), n_y_dims + 1)
48+
σ = Fill(σ0, n_y_dims + 1)
49+
L = Diagonal(σ)
6050

61-
μ = vcat(μ_x, μ_y)
62-
L = vcat(σ_x, σ_y) |> Diagonal
51+
model = NormalLogNormal(μ[1], σ[1], μ[2:end], Diagonal(σ[2:end].^2))
6352

64-
TestModel(model, μ, L, n_dims+1, true)
53+
TestModel(model, μ, L, n_y_dims+1, 1/σ0^2, true)
6554
end

test/runtests.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ using AdvancedVI
2525
const GROUP = get(ENV, "GROUP", "All")
2626

2727
# Models for Inference Tests
28-
struct TestModel{M,L,S}
28+
struct TestModel{M,L,S,SC}
2929
model::M
3030
μ_true::L
3131
L_true::S
3232
n_dims::Int
33+
strong_convexity::SC
3334
is_meanfield::Bool
3435
end
3536
include("models/normal.jl")

0 commit comments

Comments
 (0)