Skip to content

Commit 254be63

Browse files
committed
Corrected tests
1 parent 909dfd8 commit 254be63

File tree

2 files changed

+110
-95
lines changed

2 files changed

+110
-95
lines changed

test/test_custom.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,10 @@ KernelFunctions.metric(::MyKernel) = SqEuclidean()
1414
@test kernelmatrix(MyKernel(), [1 2; 3 4]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4])
1515

1616
# some syntactic sugar
17-
::MyKernel)(d::Real) = kappa(κ, d)
1817
::MyKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
1918
::MyKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X, Y; obsdim = obsdim)
2019
::MyKernel)(X::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X; obsdim = obsdim)
2120

22-
@test MyKernel()(3) == SqExponentialKernel()(3)
2321
@test MyKernel()([1, 2], [3, 4]) == SqExponentialKernel()([1, 2], [3, 4])
2422
@test MyKernel()([1 2; 3 4], [5 6; 7 8]) == SqExponentialKernel()([1 2; 3 4], [5 6; 7 8])
2523
@test MyKernel()([1 2; 3 4]) == SqExponentialKernel()([1 2; 3 4])

test/test_kernels.jl

Lines changed: 110 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -3,183 +3,200 @@ using LinearAlgebra
33
using KernelFunctions
44
using SpecialFunctions
55

6-
x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
6+
x = rand() * 2;
7+
v1 = rand(3);
8+
v2 = rand(3);
9+
id = IdentityTransform();
710
@testset "Kappa functions of kernels" begin
811
@testset "Constant" begin
912
@testset "ZeroKernel" begin
1013
k = ZeroKernel()
1114
@test eltype(k) == Any
12-
@test kappa(k,2.0) == 0.0
15+
@test kappa(k, 2.0) == 0.0
1316
end
1417
@testset "WhiteKernel" begin
1518
k = WhiteKernel()
1619
@test eltype(k) == Any
17-
@test kappa(k,1.0) == 1.0
18-
@test kappa(k,0.0) == 0.0
20+
@test kappa(k, 1.0) == 1.0
21+
@test kappa(k, 0.0) == 0.0
1922
@test EyeKernel == WhiteKernel
2023
end
2124
@testset "ConstantKernel" begin
2225
c = 2.0
23-
k = ConstantKernel(c=c)
26+
k = ConstantKernel(c = c)
2427
@test eltype(k) == Any
25-
@test kappa(k,1.0) == c
26-
@test kappa(k,0.5) == c
28+
@test kappa(k, 1.0) == c
29+
@test kappa(k, 0.5) == c
2730
end
2831
end
2932
@testset "Cosine" begin
3033
k = CosineKernel()
3134
@test eltype(k) == Any
32-
@test kappa(k, 1.0) -1.0 atol=1e-5
33-
@test kappa(k, 2.0) 1.0 atol=1e-5
34-
@test kappa(k, 1.5) 0.0 atol=1e-5
35-
@test kappa(k,x) cospi(x) atol=1e-5
36-
@test k(v1, v2) cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5
35+
@test kappa(k, 1.0) -1.0 atol = 1e-5
36+
@test kappa(k, 2.0) 1.0 atol = 1e-5
37+
@test kappa(k, 1.5) 0.0 atol = 1e-5
38+
@test kappa(k, x) cospi(x) atol = 1e-5
39+
@test k(v1, v2) cospi(sqrt(sum(abs2.(v1 - v2)))) atol = 1e-5
3740
end
3841
@testset "Exponential" begin
3942
@testset "SqExponentialKernel" begin
4043
k = SqExponentialKernel()
41-
@test kappa(k,x) exp(-x)
42-
@test k(v1,v2) exp(-norm(v1-v2)^2)
43-
@test kappa(SqExponentialKernel(),x) == kappa(k,x)
44+
@test kappa(k, x) exp(-x)
45+
@test k(v1, v2) exp(-norm(v1 - v2)^2)
46+
@test kappa(SqExponentialKernel(), x) == kappa(k, x)
4447
end
4548
@testset "ExponentialKernel" begin
4649
k = ExponentialKernel()
47-
@test kappa(k,x) exp(-x)
48-
@test k(v1,v2) exp(-norm(v1-v2))
49-
@test kappa(ExponentialKernel(),x) == kappa(k,x)
50+
@test kappa(k, x) exp(-x)
51+
@test k(v1, v2) exp(-norm(v1 - v2))
52+
@test kappa(ExponentialKernel(), x) == kappa(k, x)
5053
end
5154
@testset "GammaExponentialKernel" begin
5255
γ = 2.0
53-
k = GammaExponentialKernel=γ)
54-
@test kappa(k,x) exp(-(x)^(γ))
55-
@test k(v1,v2) exp(-norm(v1-v2)^(2γ))
56-
@test kappa(GammaExponentialKernel(),x) == kappa(k,x)
57-
@test GammaExponentialKernel(gamma=γ).γ == [γ]
56+
k = GammaExponentialKernel = γ)
57+
@test kappa(k, x) exp(-(x)^(γ))
58+
@test k(v1, v2) exp(-norm(v1 - v2)^(2γ))
59+
@test kappa(GammaExponentialKernel(), x) == kappa(k, x)
60+
@test GammaExponentialKernel(gamma = γ).γ == [γ]
5861
#Coherence :
59-
@test KernelFunctions._kernel(GammaExponentialKernel=1.0),v1,v2) KernelFunctions._kernel(SqExponentialKernel(),v1,v2)
60-
@test KernelFunctions._kernel(GammaExponentialKernel=0.5),v1,v2) KernelFunctions._kernel(ExponentialKernel(),v1,v2)
62+
@test KernelFunctions._kernel(
63+
GammaExponentialKernel= 1.0),
64+
v1,
65+
v2,
66+
) KernelFunctions._kernel(SqExponentialKernel(), v1, v2)
67+
@test KernelFunctions._kernel(
68+
GammaExponentialKernel= 0.5),
69+
v1,
70+
v2,
71+
) KernelFunctions._kernel(ExponentialKernel(), v1, v2)
6172
end
6273
end
6374
@testset "Exponentiated" begin
6475
@testset "ExponentiatedKernel" begin
6576
k = ExponentiatedKernel()
66-
@test kappa(k,x) exp(x)
67-
@test kappa(k,-x) exp(-x)
68-
@test k(v1,v2) exp(dot(v1,v2))
77+
@test kappa(k, x) exp(x)
78+
@test kappa(k, -x) exp(-x)
79+
@test k(v1, v2) exp(dot(v1, v2))
6980
end
7081
end
7182
@testset "Matern" begin
7283
@testset "MaternKernel" begin
7384
ν = 2.0
74-
k = MaternKernel=ν)
75-
matern(x,ν) = 2^(1-ν)/gamma(ν)*(sqrt(2ν)*x)^ν*besselk(ν,sqrt(2ν)*x)
76-
@test MaternKernel(nu=ν).ν == [ν]
77-
@test kappa(k,x) matern(x,ν)
78-
@test kappa(k,0.0) == 1.0
79-
@test kappa(MaternKernel=ν),x) == kappa(k,x)
85+
k = MaternKernel= ν)
86+
matern(x, ν) =
87+
2^(1 - ν) / gamma(ν) *
88+
(sqrt(2ν) * x)^ν *
89+
besselk(ν, sqrt(2ν) * x)
90+
@test MaternKernel(nu = ν).ν == [ν]
91+
@test kappa(k, x) matern(x, ν)
92+
@test kappa(k, 0.0) == 1.0
93+
@test kappa(MaternKernel= ν), x) == kappa(k, x)
8094
end
8195
@testset "Matern32Kernel" begin
8296
k = Matern32Kernel()
83-
@test kappa(k,x) (1+sqrt(3)*x)exp(-sqrt(3)*x)
84-
@test k(v1,v2) (1+sqrt(3)*norm(v1-v2))exp(-sqrt(3)*norm(v1-v2))
85-
@test kappa(Matern32Kernel(),x) == kappa(k,x)
97+
@test kappa(k, x) (1 + sqrt(3) * x) * exp(-sqrt(3) * x)
98+
@test k(v1, v2)
99+
(1 + sqrt(3) * norm(v1 - v2)) * exp(-sqrt(3) * norm(v1 - v2))
100+
@test kappa(Matern32Kernel(), x) == kappa(k, x)
86101
end
87102
@testset "Matern52Kernel" begin
88103
k = Matern52Kernel()
89-
@test kappa(k,x) (1+sqrt(5)*x+5/3*x^2)exp(-sqrt(5)*x)
90-
@test k(v1,v2) (1+sqrt(5)*norm(v1-v2)+5/3*norm(v1-v2)^2)exp(-sqrt(5)*norm(v1-v2))
91-
@test kappa(Matern52Kernel(),x) == kappa(k,x)
104+
@test kappa(k, x)
105+
(1 + sqrt(5) * x + 5 / 3 * x^2) * exp(-sqrt(5) * x)
106+
@test k(v1, v2)
107+
(1 + sqrt(5) * norm(v1 - v2) + 5 / 3 * norm(v1 - v2)^2) *
108+
exp(-sqrt(5) * norm(v1 - v2))
109+
@test kappa(Matern52Kernel(), x) == kappa(k, x)
92110
end
93111
@testset "Coherence Materns" begin
94-
@test kappa(MaternKernel=0.5),x) kappa(ExponentialKernel(),x)
95-
@test kappa(MaternKernel=1.5),x) kappa(Matern32Kernel(),x)
96-
@test kappa(MaternKernel=2.5),x) kappa(Matern52Kernel(),x)
112+
@test kappa(MaternKernel= 0.5), x)
113+
kappa(ExponentialKernel(), x)
114+
@test kappa(MaternKernel= 1.5), x) kappa(Matern32Kernel(), x)
115+
@test kappa(MaternKernel= 2.5), x) kappa(Matern52Kernel(), x)
97116
end
98117
end
99118
@testset "Polynomial" begin
100-
c = randn();
119+
c = randn()
101120
@testset "LinearKernel" begin
102121
k = LinearKernel()
103-
@test kappa(k,x) x
104-
@test k(v1,v2) dot(v1,v2)
105-
@test kappa(LinearKernel(),x) == kappa(k,x)
122+
@test kappa(k, x) x
123+
@test k(v1, v2) dot(v1, v2)
124+
@test kappa(LinearKernel(), x) == kappa(k, x)
106125
end
107126
@testset "PolynomialKernel" begin
108127
k = PolynomialKernel()
109-
@test kappa(k,x) x^2
110-
@test k(v1,v2) dot(v1,v2)^2
111-
@test kappa(PolynomialKernel(),x) == kappa(k,x)
128+
@test kappa(k, x) x^2
129+
@test k(v1, v2) dot(v1, v2)^2
130+
@test kappa(PolynomialKernel(), x) == kappa(k, x)
112131
#Coherence test
113-
@test kappa(PolynomialKernel(d=1.0,c=c),x) kappa(LinearKernel(c=c),x)
132+
@test kappa(PolynomialKernel(d = 1.0, c = c), x)
133+
kappa(LinearKernel(c = c), x)
114134
end
115135
end
116-
@testset "Mahalanobis" begin
117-
P = rand(3,3)
118-
k = MahalanobisKernel(P)
119-
@test kappa(k,x) == exp(-x)
120-
@test k(v1,v2) exp(-sqmahalanobis(v1,v2, k.P))
121-
@test kappa(ExponentialKernel(),x) == kappa(k,x)
122-
end
123136
@testset "RationalQuadratic" begin
124137
@testset "RationalQuadraticKernel" begin
125138
α = 2.0
126-
k = RationalQuadraticKernel=α)
127-
@test RationalQuadraticKernel(alpha=α).α == [α]
128-
@test kappa(k,x) (1.0+x/2.0)^-2
129-
@test k(v1,v2) (1.0+norm(v1-v2)^2/2.0)^-2
130-
@test kappa(RationalQuadraticKernel=α),x) == kappa(k,x)
139+
k = RationalQuadraticKernel = α)
140+
@test RationalQuadraticKernel(alpha = α).α == [α]
141+
@test kappa(k, x) (1.0 + x / 2.0)^-2
142+
@test k(v1, v2) (1.0 + norm(v1 - v2)^2 / 2.0)^-2
143+
@test kappa(RationalQuadraticKernel = α), x) == kappa(k, x)
131144
end
132145
@testset "GammaRationalQuadraticKernel" begin
133146
k = GammaRationalQuadraticKernel()
134-
@test kappa(k,x) (1.0+x^2.0/2.0)^-2
135-
@test k(v1,v2) (1.0+norm(v1-v2)^4.0/2.0)^-2
136-
@test kappa(GammaRationalQuadraticKernel(),x) == kappa(k,x)
147+
@test kappa(k, x) (1.0 + x^2.0 / 2.0)^-2
148+
@test k(v1, v2) (1.0 + norm(v1 - v2)^4.0 / 2.0)^-2
149+
@test kappa(GammaRationalQuadraticKernel(), x) == kappa(k, x)
137150
a = 1.0 + rand()
138-
@test GammaRationalQuadraticKernel(alpha=a).α == [a]
151+
@test GammaRationalQuadraticKernel(alpha = a).α == [a]
139152
#Coherence test
140-
@test kappa(GammaRationalQuadraticKernel=a,γ=1.0),x) kappa(RationalQuadraticKernel=a),x)
153+
@test kappa(GammaRationalQuadraticKernel= a, γ = 1.0), x)
154+
kappa(RationalQuadraticKernel= a), x)
141155
end
142156
end
143157
@testset "Transformed/Scaled Kernel" begin
144158
s = rand()
145159
v = rand(3)
146160
k = SqExponentialKernel()
147-
kt = TransformedKernel(k,ScaleTransform(s))
148-
ktard = TransformedKernel(k,ARDTransform(v))
149-
ks = ScaledKernel(k,s)
150-
@test kappa(kt,v1,v2) == kappa(transform(k,ScaleTransform(s)),v1,v2)
151-
@test kappa(kt,v1,v2) == kappa(transform(k,s),v1,v2)
152-
@test kappa(kt,v1,v2) kappa(k,s*v1,s*v2) atol=1e-5
153-
@test kappa(ktard,v1,v2) kappa(transform(k,ARDTransform(v)),v1,v2) atol=1e-5
154-
@test kappa(ktard,v1,v2) == kappa(transform(k,v),v1,v2)
155-
@test kappa(ktard,v1,v2) == kappa(k,v.*v1,v.*v2)
161+
kt = TransformedKernel(k, ScaleTransform(s))
162+
ktard = TransformedKernel(k, ARDTransform(v))
163+
ks = ScaledKernel(k, s)
164+
@test kappa(kt, v1, v2) ==
165+
kappa(transform(k, ScaleTransform(s)), v1, v2)
166+
@test kappa(kt, v1, v2) == kappa(transform(k, s), v1, v2)
167+
@test kappa(kt, v1, v2) kappa(k, s * v1, s * v2) atol = 1e-5
168+
@test kappa(ktard, v1, v2)
169+
kappa(transform(k, ARDTransform(v)), v1, v2) atol = 1e-5
170+
@test kappa(ktard, v1, v2) == kappa(transform(k, v), v1, v2)
171+
@test kappa(ktard, v1, v2) == kappa(k, v .* v1, v .* v2)
156172
@test KernelFunctions.metric(kt) == KernelFunctions.metric(k)
157-
@test kappa(ks,x) == s*kappa(k,x)
158-
@test kappa(ks,x) == kappa(s*k,x)
173+
@test kappa(ks, x) == s * kappa(k, x)
174+
@test kappa(ks, x) == kappa(s * k, x)
159175
end
160176
@testset "KernelCombinations" begin
161177
k1 = LinearKernel()
162178
k2 = SqExponentialKernel()
163179
k3 = RationalQuadraticKernel()
164-
X = rand(2,2)
180+
X = rand(2, 2)
165181
@testset "KernelSum" begin
166-
w = [2.0,0.5]
167-
k = KernelSum([k1,k2],w)
168-
ks1 = 2.0*k1
169-
ks2 = 0.5*k2
182+
w = [2.0, 0.5]
183+
k = KernelSum([k1, k2], w)
184+
ks1 = 2.0 * k1
185+
ks2 = 0.5 * k2
170186
@test length(k) == 2
171-
@test kappa(k,v1,v2) == kappa(2.0*k1+0.5*k2,v1,v2)
172-
@test kappa(k+k3,v1,v2) kappa(k3+k,v1,v2)
173-
@test kappa(k1+k2,v1,v2) == kappa(KernelSum([k1,k2]),v1,v2)
174-
@test kappa(k+ks1,v1,v2) kappa(ks1+k,v1,v2)
175-
@test kappa(k+k,v1,v2) == kappa(KernelSum([k1,k2,k1,k2],vcat(w,w)),v1,v2)
187+
@test kappa(k, v1, v2) == kappa(2.0 * k1 + 0.5 * k2, v1, v2)
188+
@test kappa(k + k3, v1, v2) kappa(k3 + k, v1, v2)
189+
@test kappa(k1 + k2, v1, v2) == kappa(KernelSum([k1, k2]), v1, v2)
190+
@test kappa(k + ks1, v1, v2) kappa(ks1 + k, v1, v2)
191+
@test kappa(k + k, v1, v2) ==
192+
kappa(KernelSum([k1, k2, k1, k2], vcat(w, w)), v1, v2)
176193
end
177194
@testset "KernelProduct" begin
178-
k = KernelProduct([k1,k2])
195+
k = KernelProduct([k1, k2])
179196
@test length(k) == 2
180-
@test kappa(k,v1,v2) == kappa(k1*k2,v1,v2)
181-
@test kappa(k*k,v1,v2) kappa(k,v1,v2)^2
182-
@test kappa(k*k3,v1,v2) kappa(k3*k,v1,v2)
197+
@test kappa(k, v1, v2) == kappa(k1 * k2, v1, v2)
198+
@test kappa(k * k, v1, v2) kappa(k, v1, v2)^2
199+
@test kappa(k * k3, v1, v2) kappa(k3 * k, v1, v2)
183200
end
184201
end
185202
end

0 commit comments

Comments
 (0)