Skip to content

Commit 879cfac

Browse files
willtebbuttWill Tebbutttheogfdevmotion
authored
test utils revamp (#159)
* Fix style * Fix convention * First pass over test set implementation * Add standardised tests to BaseKernels * Test composite kernels * Fix some tests * Fix maha * Fix sm * Fix up maha * Remove redundant file * Move existing test utils over to module * Add Gamma Exponential kernel reference * Update src/matrix/kernelpdmat.jl Co-authored-by: Théo Galy-Fajou <[email protected]> * Remove repeated code * Warn about breaking change * Update src/test_utils.jl Co-authored-by: David Widmann <[email protected]> * Bump patch * Fix up tests * Remove dead space * Fix rational quadratic parameter test * Fix some style issues * Add extra parameter check * Update src/basekernels/rationalquad.jl Co-authored-by: David Widmann <[email protected]> * Tweak check * Fix RQ convention to match EQ * Refactor tests * Fix nn issues * Fix weird printing issue * Update src/test_utils.jl Co-authored-by: David Widmann <[email protected]> * Update test/kernels/kernelsum.jl Co-authored-by: Théo Galy-Fajou <[email protected]> * Update src/test_utils.jl Co-authored-by: Théo Galy-Fajou <[email protected]> * Test FBM kernel * Fix up Gabor * Loosen dof bound * Perturb test Co-authored-by: Will Tebbutt <[email protected]> Co-authored-by: Théo Galy-Fajou <[email protected]> Co-authored-by: David Widmann <[email protected]>
1 parent 4dd8ac0 commit 879cfac

32 files changed

+412
-394
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
88
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
99
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1010
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
11+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1112
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1213
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1314
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1415
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
16+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
1618

1719
[compat]

src/KernelFunctions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ export NystromFact, nystrom
4444

4545
export spectral_mixture_kernel, spectral_mixture_product_kernel
4646

47+
export ColVecs, RowVecs
48+
4749
export MOInput
4850
export IndependentMOKernel, LatentFactorMOKernel
4951

@@ -108,6 +110,8 @@ include(joinpath("mokernels", "slfm.jl"))
108110

109111
include("zygote_adjoints.jl")
110112

113+
include("test_utils.jl")
114+
111115
function __init__()
112116
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
113117
include(joinpath("matrix", "kernelkroneckermat.jl"))

src/basekernels/gabor.jl

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,10 @@ end
5757

5858
Base.show(io::IO, κ::GaborKernel) = print(io, "Gabor Kernel (ell = ", κ.ell, ", p = ", κ.p, ")")
5959

60-
function kernelmatrix::GaborKernel, X::AbstractMatrix; obsdim::Int=defaultobs)
61-
return kernelmatrix.kernel, X; obsdim=obsdim)
62-
end
60+
kernelmatrix::GaborKernel, x::AbstractVector) = kernelmatrix.kernel, x)
6361

64-
function kernelmatrix(
65-
κ::GaborKernel, X::AbstractMatrix, Y::AbstractMatrix;
66-
obsdim::Int=defaultobs,
67-
)
68-
return kernelmatrix.kernel, X, Y; obsdim=obsdim)
62+
function kernelmatrix::GaborKernel, x::AbstractVector, y::AbstractVector)
63+
return kernelmatrix.kernel, x, y)
6964
end
7065

71-
function kerneldiagmatrix::GaborKernel, X::AbstractMatrix; obsdim::Int=defaultobs) #TODO Add test
72-
return kerneldiagmatrix.kernel, X; obsdim=obsdim)
73-
end
66+
kerneldiagmatrix::GaborKernel, x::AbstractVector) = kerneldiagmatrix.kernel, x)

src/basekernels/nn.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ function kernelmatrix(::NeuralNetworkKernel, x::RowVecs, y::RowVecs)
4242
X_2 = sum(x.X .* x.X; dims=2)
4343
Y_2 = sum(y.X .* y.X; dims=2)
4444
XY = x.X * y.X'
45-
return asin.(XY ./ sqrt.((X_2 .+ 1)' * (Y_2 .+ 1)))
45+
return asin.(XY ./ sqrt.((X_2 .+ 1) * (Y_2 .+ 1)'))
4646
end
4747

4848
function kernelmatrix(::NeuralNetworkKernel, x::RowVecs)
4949
X_2_1 = sum(x.X .* x.X; dims=2) .+ 1
5050
XX = x.X * x.X'
51-
return asin.(XX ./ sqrt.(X_2_1' * X_2_1))
51+
return asin.(XX ./ sqrt.(X_2_1 * X_2_1'))
5252
end
5353

5454
Base.show(io::IO, κ::NeuralNetworkKernel) = print(io, "Neural Network Kernel")

src/basekernels/periodic.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,6 @@ metric(κ::PeriodicKernel) = Sinus(κ.r)
2626

2727
kappa::PeriodicKernel, d::Real) = exp(- 0.5d)
2828

29-
Base.show(io::IO, κ::PeriodicKernel) = print(io, "Periodic Kernel (length(r) = ", length.r), ")")
29+
function Base.show(io::IO, κ::PeriodicKernel)
30+
print(io, "Periodic Kernel, length(r) = $(length.r))")
31+
end

src/basekernels/rationalquad.jl

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,62 @@
11
"""
2-
RationalQuadraticKernel(; α = 2.0)
2+
RationalQuadraticKernel(; α=2.0)
33
44
The rational-quadratic kernel is a Mercer kernel given by the formula:
55
```
6-
κ(x,y)=(1+||xy||²)^(-α)
6+
κ(x, y) = (1 + ||xy||² / (2α))^(-α)
77
```
8-
where `α` is a shape parameter of the Euclidean distance. Check [`GammaRationalQuadraticKernel`](@ref) for a generalization.
8+
where `α` is a shape parameter of the Euclidean distance. Check
9+
[`GammaRationalQuadraticKernel`](@ref) for a generalization.
910
"""
1011
struct RationalQuadraticKernel{Tα<:Real} <: SimpleKernel
1112
α::Vector{Tα}
1213
function RationalQuadraticKernel(;alpha::T=2.0, α::T=alpha) where {T}
13-
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
14+
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 0")
1415
return new{T}([α])
1516
end
1617
end
1718

1819
@functor RationalQuadraticKernel
1920

20-
kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/first.α))^(-first.α))
21+
function kappa::RationalQuadraticKernel, d²::T) where {T<:Real}
22+
return (one(T) +/ (2 * first.α)))^(-first.α))
23+
end
24+
2125
metric(::RationalQuadraticKernel) = SqEuclidean()
2226

23-
Base.show(io::IO, κ::RationalQuadraticKernel) = print(io, "Rational Quadratic Kernel (α = ", first.α), ")")
27+
function Base.show(io::IO, κ::RationalQuadraticKernel)
28+
print(io, "Rational Quadratic Kernel (α = $(first.α)))")
29+
end
2430

2531
"""
26-
`GammaRationalQuadraticKernel([ρ=1.0[,α=2.0[,γ=2.0]]])`
32+
`GammaRationalQuadraticKernel([α=2.0 [, γ=2.0]])`
33+
2734
The Gamma-rational-quadratic kernel is an isotropic Mercer kernel given by the formula:
2835
```
29-
κ(x,y)=(1+ρ^(2γ)||x−y||^(2γ)/α)^(-α)
36+
κ(x, y) = (1 + ||x−y||^γ / α)^(-α)
3037
```
3138
where `α` is a shape parameter of the Euclidean distance and `γ` is another shape parameter.
3239
"""
3340
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: SimpleKernel
3441
α::Vector{Tα}
3542
γ::Vector{Tγ}
36-
function GammaRationalQuadraticKernel(;alpha::Tα=2.0, gamma::Tγ=2.0, α::Tα=alpha, γ::Tγ=gamma) where {Tα<:Real, Tγ<:Real}
37-
@check_args(GammaRationalQuadraticKernel, α, α > one(Tα), "α > 1")
38-
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(Tγ), "γ >= 1")
43+
function GammaRationalQuadraticKernel(
44+
;alpha::Tα=2.0, gamma::Tγ=2.0, α::Tα=alpha, γ::Tγ=gamma,
45+
) where {Tα<:Real, Tγ<:Real}
46+
@check_args(GammaRationalQuadraticKernel, α, α > zero(Tα), "α > 0")
47+
@check_args(GammaRationalQuadraticKernel, γ, zero(γ) < γ <= 2, "0 < γ <= 2")
3948
return new{Tα, Tγ}([α], [γ])
4049
end
4150
end
4251

4352
@functor GammaRationalQuadraticKernel
4453

45-
kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^first.γ)/first.α))^(-first.α))
54+
function kappa::GammaRationalQuadraticKernel, d²::Real)
55+
return (one(d²) +^(first.γ) / 2) / first.α))^(-first.α))
56+
end
57+
4658
metric(::GammaRationalQuadraticKernel) = SqEuclidean()
4759

48-
Base.show(io::IO, κ::GammaRationalQuadraticKernel) = print(io, "Gamma Rational Quadratic Kernel (α = ", first.α), ", γ = ", first.γ), ")")
60+
function Base.show(io::IO, κ::GammaRationalQuadraticKernel)
61+
print(io, "Gamma Rational Quadratic Kernel (α = $(first.α)), γ = $(first.γ)))")
62+
end

src/basekernels/sm.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ function spectral_mixture_kernel(
5454
γs::AbstractMatrix{<:Real},
5555
ωs::AbstractMatrix{<:Real}
5656
)
57-
spectral_mixture_kernel(SqExponentialKernel(), αs, γs, ωs)
57+
return spectral_mixture_kernel(SqExponentialKernel(), αs, γs, ωs)
5858
end
5959

6060
"""
@@ -95,14 +95,14 @@ function spectral_mixture_product_kernel(
9595
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
9696
end
9797
return TensorProduct(spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :))
98-
for (α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs)))
98+
for (α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs)))
9999
end
100100

101101
function spectral_mixture_product_kernel(
102102
αs::AbstractMatrix{<:Real},
103103
γs::AbstractMatrix{<:Real},
104104
ωs::AbstractMatrix{<:Real}
105105
)
106-
spectral_mixture_product_kernel(SqExponentialKernel(), αs, γs, ωs)
106+
return spectral_mixture_product_kernel(SqExponentialKernel(), αs, γs, ωs)
107107
end
108108

src/test_utils.jl

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
module TestUtils
2+
3+
const __ATOL = 1e-9
4+
5+
using LinearAlgebra
6+
using KernelFunctions
7+
using Random
8+
using Test
9+
10+
"""
11+
test_interface(
12+
k::Kernel,
13+
x0::AbstractVector,
14+
x1::AbstractVector,
15+
x2::AbstractVector;
16+
atol=__ATOL,
17+
)
18+
19+
Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`.
20+
`x0` and `x1` should be of the same length with different values, while `x0` and `x2` should
21+
be of different lengths.
22+
23+
test_interface([rng::AbstractRNG], k::Kernel, T::Type{<:AbstractVector}; atol=__ATOL)
24+
25+
`test_interface` offers certain types of test data generation to make running these tests
26+
require less code for common input types. For example, `Vector{<:Real}`, `ColVecs{<:Real}`,
27+
and `RowVecs{<:Real}` are supported. For other input vector types, please provide the data
28+
manually.
29+
"""
30+
function test_interface(
31+
k::Kernel,
32+
x0::AbstractVector,
33+
x1::AbstractVector,
34+
x2::AbstractVector;
35+
atol=__ATOL,
36+
)
37+
# TODO: uncomment the tests of ternary kerneldiagmatrix.
38+
39+
# Ensure that we have the required inputs.
40+
@assert length(x0) == length(x1)
41+
@assert length(x0) length(x2)
42+
43+
# Check that kerneldiagmatrix basically works.
44+
# @test kerneldiagmatrix(k, x0, x1) isa AbstractVector
45+
# @test length(kerneldiagmatrix(k, x0, x1)) == length(x0)
46+
47+
# Check that pairwise basically works.
48+
@test kernelmatrix(k, x0, x2) isa AbstractMatrix
49+
@test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2))
50+
51+
# Check that elementwise is consistent with pairwise.
52+
# @test kerneldiagmatrix(k, x0, x1) ≈ diag(kernelmatrix(k, x0, x1)) atol=atol
53+
54+
# Check additional binary elementwise properties for kernels.
55+
# @test kerneldiagmatrix(k, x0, x1) ≈ kerneldiagmatrix(k, x1, x0)
56+
@test kernelmatrix(k, x0, x2) kernelmatrix(k, x2, x0)' atol=atol
57+
58+
# Check that unary elementwise basically works.
59+
@test kerneldiagmatrix(k, x0) isa AbstractVector
60+
@test length(kerneldiagmatrix(k, x0)) == length(x0)
61+
62+
# Check that unary pairwise basically works.
63+
@test kernelmatrix(k, x0) isa AbstractMatrix
64+
@test size(kernelmatrix(k, x0)) == (length(x0), length(x0))
65+
@test kernelmatrix(k, x0) kernelmatrix(k, x0)' atol=atol
66+
67+
# Check that unary elementwise is consistent with unary pairwise.
68+
@test kerneldiagmatrix(k, x0) diag(kernelmatrix(k, x0)) atol=atol
69+
70+
# Check that unary pairwise produces a positive definite matrix (approximately).
71+
@test eigmin(Matrix(kernelmatrix(k, x0))) > -atol
72+
73+
# Check that unary elementwise / pairwise are consistent with the binary versions.
74+
# @test kerneldiagmatrix(k, x0) ≈ kerneldiagmatrix(k, x0, x0) atol=atol
75+
@test kernelmatrix(k, x0) kernelmatrix(k, x0, x0) atol=atol
76+
77+
# Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`.
78+
@test k(first(x0), first(x1)) isa Real
79+
@test kernelmatrix(k, x0, x2) [k(xl, xr) for xl in x0, xr in x2]
80+
81+
tmp = Matrix{Float64}(undef, length(x0), length(x2))
82+
@test kernelmatrix!(tmp, k, x0, x2) kernelmatrix(k, x0, x2)
83+
84+
tmp_square = Matrix{Float64}(undef, length(x0), length(x0))
85+
@test kernelmatrix!(tmp_square, k, x0) kernelmatrix(k, x0)
86+
87+
tmp_diag = Vector{Float64}(undef, length(x0))
88+
@test kerneldiagmatrix!(tmp_diag, k, x0) kerneldiagmatrix(k, x0)
89+
end
90+
91+
function test_interface(
92+
rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs...
93+
) where {T<:Real}
94+
test_interface(k, randn(rng, T, 3), randn(rng, T, 3), randn(rng, T, 2); kwargs...)
95+
end
96+
97+
function test_interface(
98+
rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs...,
99+
) where {T<:Real}
100+
test_interface(
101+
k,
102+
ColVecs(randn(rng, T, dim_in, 3)),
103+
ColVecs(randn(rng, T, dim_in, 3)),
104+
ColVecs(randn(rng, T, dim_in, 2));
105+
kwargs...,
106+
)
107+
end
108+
109+
function test_interface(
110+
rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs...,
111+
) where {T<:Real}
112+
test_interface(
113+
k,
114+
RowVecs(randn(rng, T, 3, dim_in)),
115+
RowVecs(randn(rng, T, 3, dim_in)),
116+
RowVecs(randn(rng, T, 2, dim_in));
117+
kwargs...,
118+
)
119+
end
120+
121+
function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...)
122+
test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
123+
end
124+
125+
function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...)
126+
test_interface(rng, k, Vector{T}; kwargs...)
127+
test_interface(rng, k, ColVecs{T}; kwargs...)
128+
test_interface(rng, k, RowVecs{T}; kwargs...)
129+
end
130+
131+
function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...)
132+
test_interface(Random.GLOBAL_RNG, k, T; kwargs...)
133+
end
134+
135+
end # module

src/transform/lineartransform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ end
2929
(t::LinearTransform)(x::Real) = vec(t.A * x)
3030
(t::LinearTransform)(x::AbstractVector{<:Real}) = t.A * x
3131

32-
_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * x')
32+
_map(t::LinearTransform, x::AbstractVector{<:Real}) = ColVecs(t.A * collect(x'))
3333
_map(t::LinearTransform, x::ColVecs) = ColVecs(t.A * x.X)
3434
_map(t::LinearTransform, x::RowVecs) = RowVecs(x.X * t.A')
3535

src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ struct RowVecs{T, TX<:AbstractMatrix{T}, S} <: AbstractVector{S}
7070
end
7171
end
7272

73+
RowVecs(x::AbstractVector) = RowVecs(reshape(x, :, 1))
74+
7375
Base.size(D::RowVecs) = (size(D.X, 1),)
7476
Base.getindex(D::RowVecs, i::Int) = view(D.X, i, :)
7577
Base.getindex(D::RowVecs, i::CartesianIndex{1}) = view(D.X, i, :)

0 commit comments

Comments
 (0)