Skip to content

Commit c6ef433

Browse files
Add unified interface test utils (#19)
* Add unified interface test utils * Clean up comments * Add docstring * Apply suggestions from code review Co-authored-by: willtebbutt <[email protected]> * Add documentation for arguments * Reversing patch bump made last time. * Remove irrelevant comment Co-authored-by: willtebbutt <[email protected]>
1 parent 775140a commit c6ef433

File tree

7 files changed

+64
-64
lines changed

7 files changed

+64
-64
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GPLikelihoods"
22
uuid = "6031954c-0455-49d7-b3b9-3e1c99afaf40"
33
authors = ["willtebbutt <[email protected]>"]
4-
version = "0.1.1"
4+
version = "0.1.0"
55

66
[deps]
77
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"

test/likelihoods/bernoulli.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,4 @@
11
@testset "BernoulliLikelihood" begin
2-
rng = MersenneTwister(123)
3-
gp = GP(SqExponentialKernel())
4-
x = rand(rng, 10)
5-
y = rand(rng, 10)
62
lik = BernoulliLikelihood()
7-
lgp = LatentGP(gp, lik, 1e-5)
8-
lfgp = lgp(x)
9-
10-
@test typeof(lik(rand(rng, lfgp.fx))) <: Distribution
11-
@test length(rand(rng, lik(rand(rng, lfgp.fx)))) == 10
12-
@test Functors.functor(lik)[1] == ()
3+
test_interface(lik, SqExponentialKernel(), rand(10))
134
end

test/likelihoods/categorical.jl

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,8 @@
11
@testset "CategoricalLikelihood" begin
2-
rng = MersenneTwister(123)
3-
gp = GP(IndependentMOKernel(SqExponentialKernel()))
2+
lik = CategoricalLikelihood()
43
IN_DIM = 3
54
OUT_DIM = 4
65
N = 10
7-
x = [rand(rng, IN_DIM) for _=1:N]
8-
X = MOInput(x, OUT_DIM)
9-
lik = CategoricalLikelihood()
10-
lgp = LatentGP(gp, lik, 1e-5)
11-
lfgp = lgp(X)
12-
13-
Y = rand(rng, lfgp.fx)
14-
15-
y = [Y[[i + j*N for j in 0:(OUT_DIM - 1)]] for i in 1:N]
16-
# Replace with mo_inverse_transform once it is merged
17-
18-
@test length(lik(rand(3)).p) == 4
19-
@test lik(y) isa Distribution
20-
@test length(rand(rng, lik(y))) == 10
21-
@test Functors.functor(lik)[1] == ()
6+
X = MOInput([rand(IN_DIM) for _ in 1:N], OUT_DIM)
7+
test_interface(lik, IndependentMOKernel(SqExponentialKernel()), X)
228
end

test/likelihoods/gaussian.jl

Lines changed: 4 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,13 @@
11
@testset "GaussianLikelihood" begin
2-
rng = MersenneTwister(123)
3-
gp = GP(SqExponentialKernel())
4-
x = rand(rng, 10)
5-
y = rand(rng, 10)
62
lik = GaussianLikelihood(1e-5)
7-
lgp = LatentGP(gp, lik, 1e-5)
8-
lfgp = lgp(x)
9-
10-
@test lik(rand(rng, lfgp.fx)) isa Distribution
11-
@test length(rand(rng, lik(rand(rng, lfgp.fx)))) == 10
12-
@test keys(Functors.functor(lik)[1]) == (:σ²,)
3+
test_interface(lik, SqExponentialKernel(), rand(10); functor_args=(:σ²,))
134
end
145

156
@testset "HeteroscedasticGaussianLikelihood" begin
16-
rng = MersenneTwister(123)
17-
gp = GP(IndependentMOKernel(SqExponentialKernel()))
7+
lik = HeteroscedasticGaussianLikelihood()
188
IN_DIM = 3
199
OUT_DIM = 2 # one for the mean the other for the log-standard deviation
2010
N = 10
21-
x = [rand(rng, IN_DIM) for _ in 1:N]
22-
X = MOInput(x, OUT_DIM)
23-
lik = HeteroscedasticGaussianLikelihood()
24-
lgp = LatentGP(gp, lik, 1e-5)
25-
lfgp = lgp(X)
26-
27-
Y = rand(rng, lfgp.fx)
28-
29-
y = [Y[[i + j*N for j in 0:(OUT_DIM - 1)]] for i in 1:N]
30-
# Replace with mo_inverse_transform once it is merged
31-
32-
@test lik(y) isa Distribution
33-
@test length(rand(rng, lik(y))) == 10
34-
@test Functors.functor(lik)[1] == ()
11+
X = MOInput([rand(IN_DIM) for _ in 1:N], OUT_DIM)
12+
test_interface(lik, IndependentMOKernel(SqExponentialKernel()), X)
3513
end

test/likelihoods/poisson.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,4 @@
11
@testset "PoissonLikelihood" begin
2-
rng = MersenneTwister(123)
3-
gp = GP(SqExponentialKernel())
4-
x = rand(rng, 10)
5-
y = rand(rng, 10)
62
lik = PoissonLikelihood()
7-
lgp = LatentGP(gp, lik, 1e-5)
8-
lfgp = lgp(x)
9-
10-
@test lik(rand(rng, lfgp.fx)) isa Distribution
11-
@test length(rand(rng, lik(rand(rng, lfgp.fx)))) == 10
12-
@test Functors.functor(lik)[1] == ()
3+
test_interface(lik, SqExponentialKernel(), rand(10))
134
end

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using Distributions
77

88
@testset "GPLikelihoods.jl" begin
99

10+
include("test_utils.jl")
11+
1012
@testset "likelihoods" begin
1113
include("likelihoods/bernoulli.jl")
1214
include("likelihoods/categorical.jl")

test/test_utils.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
function test_interface(
2+
rng::AbstractRNG, lik, k::Kernel, x::AbstractVector; functor_args=(),
3+
)
4+
gp = GP(k)
5+
lgp = LatentGP(gp, lik, 1e-5)
6+
lfgp = lgp(x)
7+
8+
# Check if likelihood produces a distribution
9+
@test lik(rand(rng, lfgp.fx)) isa Distribution
10+
11+
N = length(x)
12+
y = rand(rng, lfgp.fx)
13+
14+
if x isa MOInput
15+
# TODO: replace with mo_inverse_transform
16+
N = length(x.x)
17+
y = [y[[i + j*N for j in 0:(x.out_dim - 1)]] for i in 1:N]
18+
end
19+
20+
# Check if the likelihood samples are of correct length
21+
@test length(rand(rng, lik(y))) == N
22+
23+
# Check if functor works properly
24+
if functor_args == ()
25+
@test Functors.functor(lik)[1] == functor_args
26+
else
27+
@test keys(Functors.functor(lik)[1]) == functor_args
28+
end
29+
end
30+
31+
"""
32+
test_interface(lik, k::Kernel, x::AbstractVector; functor_args=())
33+
34+
This function provides unified method to check the interface of the various likelihoods
35+
defined. It checks if the likelihood produces a distribution, length of likelihood
36+
samples is correct and if the functor works as intended.
37+
...
38+
# Arguments
39+
- `lik`: the likelihood to test the interface of
40+
- `k::Kernel`: the kernel to use for the GP
41+
- `x::AbstractVector`: intputs to compute the likelihood on
42+
- `functor_args=()`: a collection of symbols of arguments to match functor parameters with.
43+
...
44+
"""
45+
function test_interface(
46+
lik,
47+
k::KernelFunctions.Kernel,
48+
x::AbstractVector;
49+
kwargs...
50+
)
51+
test_interface(Random.GLOBAL_RNG, lik, k, x; kwargs...)
52+
end

0 commit comments

Comments
 (0)