Skip to content

Commit fb3f410

Browse files
authored
create ApproximateGPs.TestUtils (#117)
* create ApproximateGPs.TestUtils * revert reexporting AbstractGPs
1 parent 6a5877f commit fb3f410

File tree

6 files changed

+139
-76
lines changed

6 files changed

+139
-76
lines changed

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ApproximateGPs"
22
uuid = "298c2ebc-0411-48ad-af38-99e88101b606"
33
authors = ["JuliaGaussianProcesses Team"]
4-
version = "0.3.2"
4+
version = "0.3.3"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
@@ -13,11 +13,14 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
GPLikelihoods = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
1414
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
1515
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
16+
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
1617
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1719
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1820
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1921
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2022
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
23+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2124

2225
[compat]
2326
AbstractGPs = "0.3, 0.4, 0.5"

src/ApproximateGPs.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module ApproximateGPs
22

33
using Reexport
44

5+
@reexport using AbstractGPs
56
@reexport using GPLikelihoods
67

78
include("API.jl")
@@ -22,4 +23,6 @@ include("LaplaceApproximationModule.jl")
2223

2324
include("deprecations.jl")
2425

26+
include("TestUtils.jl")
27+
2528
end

src/TestUtils.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
module TestUtils
2+
3+
using LinearAlgebra
4+
using Random
5+
using Test
6+
7+
using Distributions
8+
using LogExpFunctions: logistic, softplus
9+
10+
using AbstractGPs
11+
using ApproximateGPs
12+
13+
function generate_data()
14+
X = range(0, 23.5; length=48)
15+
# The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
16+
# The generating code below is only kept for illustrative purposes.
17+
#! format: off
18+
Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
19+
#! format: on
20+
# Random.seed!(1)
21+
# fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
22+
# # invlink = normcdf
23+
# invlink = logistic
24+
# ps = invlink.(fs)
25+
# Y = @. rand(Bernoulli(ps))
26+
return X, Y
27+
end
28+
29+
dist_y_given_f(f) = Bernoulli(logistic(f))
30+
31+
function build_latent_gp(theta)
32+
variance = softplus(theta[1])
33+
lengthscale = softplus(theta[2])
34+
kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
35+
return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
36+
end
37+
38+
"""
39+
test_approx_lml(approx)
40+
41+
Test whether in the conjugate case `approx_lml(approx, LatentGP(f,
42+
GaussianLikelihood(), jitter)(x), y)` gives approximately the same answer as
43+
the log marginal likelihood in exact GP regression.
44+
45+
!!! todo
46+
Not yet implemented.
47+
48+
Will not necessarily work for approximations that rely on optimization such
49+
as `SparseVariationalApproximation`.
50+
51+
!!! todo
52+
Also test gradients (for hyperparameter optimization).
53+
"""
54+
function test_approx_lml end
55+
56+
"""
57+
test_approximation_predictions(approx)
58+
59+
Test whether the prediction interface for `approx` works and whether in the
60+
conjugate case `posterior(approx, LatentGP(f, GaussianLikelihood(), jitter)(x), y)`
61+
gives approximately the same answer as the exact GP regression posterior.
62+
63+
!!! note
64+
Should be satisfied by all approximate inference methods, but note that
65+
this does not currently apply for some approximations which rely on
66+
optimization such as `SparseVariationalApproximation`.
67+
68+
!!! warning
69+
Do not rely on this as the only test of a new approximation!
70+
71+
See `test_approx_lml`.
72+
"""
73+
function test_approximation_predictions(approx)
74+
rng = MersenneTwister(123456)
75+
N_cond = 5
76+
N_a = 6
77+
N_b = 7
78+
79+
# Specify prior.
80+
f = GP(Matern32Kernel())
81+
# Sample from prior.
82+
x = collect(range(-1.0, 1.0; length=N_cond))
83+
# TODO: Change to x = ColVecs(rand(2, N_cond)) once #109 is fixed
84+
noise_scale = 0.1
85+
fx = f(x, noise_scale^2)
86+
y = rand(rng, fx)
87+
88+
jitter = 0.0 # not needed in Gaussian case
89+
lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
90+
f_approx_post = posterior(approx, lf(x), y)
91+
92+
@testset "AbstractGPs API" begin
93+
a = collect(range(-1.2, 1.2; length=N_a))
94+
b = randn(rng, N_b)
95+
AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f_approx_post, a, b)
96+
end
97+
98+
@testset "exact GPR equivalence for Gaussian likelihood" begin
99+
f_exact_post = posterior(f(x, noise_scale^2), y)
100+
xt = vcat(x, randn(rng, 3)) # test at training and new points
101+
102+
m_approx, c_approx = mean_and_cov(f_approx_post(xt))
103+
m_exact, c_exact = mean_and_cov(f_exact_post(xt))
104+
105+
@test m_approx m_exact
106+
@test c_approx c_exact
107+
end
108+
end
109+
110+
end

test/LaplaceApproximationModule.jl

Lines changed: 5 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,7 @@
11
@testset "laplace" begin
2-
function generate_data()
3-
X = range(0, 23.5; length=48)
4-
# The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
5-
# The generating code below is only kept for illustrative purposes.
6-
#! format: off
7-
Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
8-
#! format: on
9-
# Random.seed!(1)
10-
# fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
11-
# # invlink = normcdf
12-
# invlink = logistic
13-
# ps = invlink.(fs)
14-
# Y = [rand(Bernoulli(p)) for p in ps]
15-
return X, Y
16-
end
17-
18-
dist_y_given_f(f) = Bernoulli(logistic(f))
19-
20-
function build_latent_gp(theta)
21-
variance = softplus(theta[1])
22-
lengthscale = softplus(theta[2])
23-
kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
24-
return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
25-
end
2+
generate_data = ApproximateGPs.TestUtils.generate_data
3+
dist_y_given_f = ApproximateGPs.TestUtils.dist_y_given_f
4+
build_latent_gp = ApproximateGPs.TestUtils.build_latent_gp
265

276
function optimize_elbo(
287
build_latent_gp,
@@ -49,43 +28,8 @@
4928
end
5029

5130
@testset "predictions" begin
52-
rng = MersenneTwister(123456)
53-
N_cond = 5
54-
N_a = 6
55-
N_b = 7
56-
57-
# Specify prior.
58-
f = GP(Matern32Kernel())
59-
# Sample from prior.
60-
x = collect(range(-1.0, 1.0; length=N_cond))
61-
noise_scale = 0.1
62-
fx = f(x, noise_scale^2)
63-
y = rand(rng, fx)
64-
65-
jitter = 0.0 # not needed in Gaussian case
66-
lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
67-
# in Gaussian case, Laplace converges to f_opt in one step; we need the
68-
# second step to compute the cache at f_opt rather than f_init!
69-
f_approx_post = posterior(LaplaceApproximation(; maxiter=2), lf(x), y)
70-
71-
@testset "AbstractGPs API" begin
72-
a = collect(range(-1.2, 1.2; length=N_a))
73-
b = randn(rng, N_b)
74-
AbstractGPs.TestUtils.test_internal_abstractgps_interface(
75-
rng, f_approx_post, a, b
76-
)
77-
end
78-
79-
@testset "equivalence to exact GPR for Gaussian likelihood" begin
80-
f_exact_post = posterior(f(x, noise_scale^2), y)
81-
xt = vcat(x, randn(rng, 3)) # test at training and new points
82-
83-
m_approx, c_approx = mean_and_cov(f_approx_post(xt))
84-
m_exact, c_exact = mean_and_cov(f_exact_post(xt))
85-
86-
@test m_approx m_exact
87-
@test c_approx c_exact
88-
end
31+
approx = LaplaceApproximation(; maxiter=2)
32+
ApproximateGPs.TestUtils.test_approximation_predictions(approx)
8933
end
9034

9135
@testset "gradients" begin

test/SparseVariationalApproximationModule.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
b = randn(rng, N_b)
2929

3030
@testset "AbstractGPs interface - Centered" begin
31-
TestUtils.test_internal_abstractgps_interface(rng, f_approx_post_Centered, a, b)
31+
AbstractGPs.TestUtils.test_internal_abstractgps_interface(
32+
rng, f_approx_post_Centered, a, b
33+
)
3234
end
3335

3436
@testset "NonCentered" begin
@@ -50,7 +52,7 @@
5052
f_approx_post_non_Centered = posterior(approx_non_Centered)
5153

5254
@testset "AbstractGPs interface - NonCentered" begin
53-
TestUtils.test_internal_abstractgps_interface(
55+
AbstractGPs.TestUtils.test_internal_abstractgps_interface(
5456
rng, f_approx_post_non_Centered, a, b
5557
)
5658
end
@@ -170,7 +172,7 @@
170172

171173
# Train the SVGP model
172174
data = [(x, y)]
173-
opt = ADAM(0.001)
175+
opt = Flux.ADAM(0.001)
174176

175177
svgp_ps = Flux.params(svgp_model)
176178

test/runtests.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
1+
using LinearAlgebra
12
using Random
23
using Test
3-
using ApproximateGPs
4-
using Flux
5-
using IterTools
6-
using AbstractGPs
7-
using AbstractGPs: LatentFiniteGP, TestUtils
8-
using Distributions
9-
using LogExpFunctions: logistic
10-
using LinearAlgebra
11-
using PDMats
12-
using Optim
13-
using Zygote
4+
145
using ChainRulesCore
156
using ChainRulesTestUtils
7+
using Distributions
168
using FiniteDifferences
9+
using Flux: Flux
10+
using IterTools
11+
using LogExpFunctions: softplus
12+
using Optim
13+
using PDMats
14+
using Zygote
15+
16+
using AbstractGPs
17+
using ApproximateGPs
1718
using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule
1819

1920
# Writing tests:

0 commit comments

Comments
 (0)