Skip to content

Commit 6c4c257

Browse files
committed
create ApproximateGPs.TestUtils
1 parent 2627feb commit 6c4c257

File tree

4 files changed

+83
-63
lines changed

4 files changed

+83
-63
lines changed

src/ApproximateGPs.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,6 @@ include("LaplaceApproximationModule.jl")
2323

2424
include("deprecations.jl")
2525

26+
include("TestUtils.jl")
27+
2628
end

src/TestUtils.jl

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
"""
    ApproximateGPs.TestUtils

Shared fixtures and consistency checks used by the package's test suite:
a fixed binary-classification dataset, a Bernoulli-logistic likelihood,
a `LatentGP` builder, and a generic prediction test for approximate
posteriors.
"""
module TestUtils

using LinearAlgebra
using Random
using Test

using Distributions
# NOTE: `softplus` must be imported explicitly — it is used by
# `build_latent_gp` below and is not exported by any of the other
# packages imported here. LogExpFunctions exports it as an alias
# of `log1pexp`.
using LogExpFunctions: logistic, softplus

using AbstractGPs
using ApproximateGPs

"""
    generate_data() -> (X, Y)

Return a fixed binary-classification dataset: `X` is a range of 48 inputs
over `[0, 23.5]` and `Y` is a hard-coded vector of 0/1 labels.
"""
function generate_data()
    X = range(0, 23.5; length=48)
    # The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
    # The generating code below is only kept for illustrative purposes.
    #! format: off
    Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
    #! format: on
    # Random.seed!(1)
    # fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
    # # invlink = normcdf
    # invlink = logistic
    # ps = invlink.(fs)
    # Y = @. rand(Bernoulli(ps))
    return X, Y
end

# Bernoulli likelihood with a logistic link: p(y | f) = Bernoulli(logistic(f)).
dist_y_given_f(f) = Bernoulli(logistic(f))

"""
    build_latent_gp(theta)

Construct a `LatentGP` with a squared-exponential kernel whose variance and
lengthscale are obtained from the unconstrained parameters `theta[1]` and
`theta[2]` via `softplus` (ensuring positivity). Uses `dist_y_given_f` as
the likelihood and a jitter of `1e-8`.
"""
function build_latent_gp(theta)
    variance = softplus(theta[1])
    lengthscale = softplus(theta[2])
    kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
    return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
end

"""
    test_approximation_predictions(approx)

Run a `@testset` verifying that the approximate posterior produced by
`posterior(approx, lf(x), y)` (a) satisfies the internal AbstractGPs API,
and (b) matches exact GP regression when the likelihood is Gaussian.
"""
function test_approximation_predictions(approx)
    rng = MersenneTwister(123456)  # fixed seed for determinism
    N_cond = 5
    N_a = 6
    N_b = 7

    # Specify prior.
    f = GP(Matern32Kernel())
    # Sample from prior.
    x = collect(range(-1.0, 1.0; length=N_cond))
    noise_scale = 0.1
    fx = f(x, noise_scale^2)
    y = rand(rng, fx)

    jitter = 0.0  # not needed in Gaussian case
    lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
    f_approx_post = posterior(approx, lf(x), y)

    @testset "AbstractGPs API" begin
        a = collect(range(-1.2, 1.2; length=N_a))
        b = randn(rng, N_b)
        AbstractGPs.TestUtils.test_internal_abstractgps_interface(rng, f_approx_post, a, b)
    end

    @testset "exact GPR equivalence for Gaussian likelihood" begin
        f_exact_post = posterior(f(x, noise_scale^2), y)
        xt = vcat(x, randn(rng, 3))  # test at training and new points

        m_approx, c_approx = mean_and_cov(f_approx_post(xt))
        m_exact, c_exact = mean_and_cov(f_exact_post(xt))

        # `≈` (isapprox) — exact equality is not expected for floats.
        @test m_approx ≈ m_exact
        @test c_approx ≈ c_exact
    end
end

end

test/LaplaceApproximationModule.jl

Lines changed: 5 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,7 @@
11
@testset "laplace" begin
2-
function generate_data()
3-
X = range(0, 23.5; length=48)
4-
# The random number generator changed in 1.6->1.7. The following vector was generated in Julia 1.6.
5-
# The generating code below is only kept for illustrative purposes.
6-
#! format: off
7-
Y = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
8-
#! format: on
9-
# Random.seed!(1)
10-
# fs = @. 3 * sin(10 + 0.6X) + sin(0.1X) - 1
11-
# # invlink = normcdf
12-
# invlink = logistic
13-
# ps = invlink.(fs)
14-
# Y = [rand(Bernoulli(p)) for p in ps]
15-
return X, Y
16-
end
17-
18-
dist_y_given_f(f) = Bernoulli(logistic(f))
19-
20-
function build_latent_gp(theta)
21-
variance = softplus(theta[1])
22-
lengthscale = softplus(theta[2])
23-
kernel = variance * with_lengthscale(SqExponentialKernel(), lengthscale)
24-
return LatentGP(GP(kernel), dist_y_given_f, 1e-8)
25-
end
2+
generate_data = ApproximateGPs.TestUtils.generate_data
3+
dist_y_given_f = ApproximateGPs.TestUtils.dist_y_given_f
4+
build_latent_gp = ApproximateGPs.TestUtils.build_latent_gp
265

276
function optimize_elbo(
287
build_latent_gp,
@@ -49,43 +28,8 @@
4928
end
5029

5130
@testset "predictions" begin
52-
rng = MersenneTwister(123456)
53-
N_cond = 5
54-
N_a = 6
55-
N_b = 7
56-
57-
# Specify prior.
58-
f = GP(Matern32Kernel())
59-
# Sample from prior.
60-
x = collect(range(-1.0, 1.0; length=N_cond))
61-
noise_scale = 0.1
62-
fx = f(x, noise_scale^2)
63-
y = rand(rng, fx)
64-
65-
jitter = 0.0 # not needed in Gaussian case
66-
lf = LatentGP(f, f -> Normal(f, noise_scale), jitter)
67-
# in Gaussian case, Laplace converges to f_opt in one step; we need the
68-
# second step to compute the cache at f_opt rather than f_init!
69-
f_approx_post = posterior(LaplaceApproximation(; maxiter=2), lf(x), y)
70-
71-
@testset "AbstractGPs API" begin
72-
a = collect(range(-1.2, 1.2; length=N_a))
73-
b = randn(rng, N_b)
74-
AbstractGPs.TestUtils.test_internal_abstractgps_interface(
75-
rng, f_approx_post, a, b
76-
)
77-
end
78-
79-
@testset "equivalence to exact GPR for Gaussian likelihood" begin
80-
f_exact_post = posterior(f(x, noise_scale^2), y)
81-
xt = vcat(x, randn(rng, 3)) # test at training and new points
82-
83-
m_approx, c_approx = mean_and_cov(f_approx_post(xt))
84-
m_exact, c_exact = mean_and_cov(f_exact_post(xt))
85-
86-
@test m_approx ≈ m_exact
87-
@test c_approx ≈ c_exact
88-
end
31+
approx = LaplaceApproximation(; maxiter=2)
32+
ApproximateGPs.TestUtils.test_approximation_predictions(approx)
8933
end
9034

9135
@testset "gradients" begin

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
using Random
22
using Test
3-
using ApproximateGPs
43
using Flux
54
using IterTools
65
using AbstractGPs
7-
using AbstractGPs: LatentFiniteGP, TestUtils
86
using Distributions
97
using LogExpFunctions: logistic
108
using LinearAlgebra
@@ -14,6 +12,8 @@ using Zygote
1412
using ChainRulesCore
1513
using ChainRulesTestUtils
1614
using FiniteDifferences
15+
16+
using ApproximateGPs
1717
using ApproximateGPs: SparseVariationalApproximationModule, LaplaceApproximationModule
1818

1919
# Writing tests:

0 commit comments

Comments
 (0)