Skip to content

Commit 960bad2

Browse files
committed
Spread tests for all base kernels
1 parent 07631b6 commit 960bad2

20 files changed

+47
-9
lines changed

test/basekernels/constant.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
@test kappa(k,2.0) == 0.0
66
@test KernelFunctions.metric(ZeroKernel()) == KernelFunctions.Delta()
77
@test repr(k) == "Zero Kernel"
8-
test_AD("Zero", ZeroKernel)
8+
test_ADs(ZeroKernel)
99
end
1010
@testset "WhiteKernel" begin
1111
k = WhiteKernel()
@@ -15,7 +15,7 @@
1515
@test EyeKernel == WhiteKernel
1616
@test metric(WhiteKernel()) == KernelFunctions.Delta()
1717
@test repr(k) == "White Kernel"
18-
test_AD("WhiteKernel", WhiteKernel)
18+
test_ADs(WhiteKernel)
1919
end
2020
@testset "ConstantKernel" begin
2121
c = 2.0
@@ -26,6 +26,6 @@
2626
@test metric(ConstantKernel()) == KernelFunctions.Delta()
2727
@test metric(ConstantKernel(c=2.0)) == KernelFunctions.Delta()
2828
@test repr(k) == "Constant Kernel (c = $(c))"
29-
test_AD("ConstantKernel", c->ConstantKernel(c=first(c)), [c])
29+
test_ADs(c->ConstantKernel(c=first(c)), [c])
3030
end
3131
end

test/basekernels/cosine.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,5 @@
1212
@test kappa(k,x) cospi(x) atol=1e-5
1313
@test k(v1, v2) cospi(sqrt(sum(abs2.(v1-v2)))) atol=1e-5
1414
@test repr(k) == "Cosine Kernel"
15+
test_ADs(CosineKernel)
1516
end

test/basekernels/exponential.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
@test SEKernel == SqExponentialKernel
1515
@test repr(k) == "Squared Exponential Kernel"
1616
@test KernelFunctions.iskroncompatible(k) == true
17+
test_ADs(SEKernel)
1718
end
1819
@testset "ExponentialKernel" begin
1920
k = ExponentialKernel()
@@ -24,6 +25,7 @@
2425
@test repr(k) == "Exponential Kernel"
2526
@test LaplacianKernel == ExponentialKernel
2627
@test KernelFunctions.iskroncompatible(k) == true
28+
test_ADs(ExponentialKernel)
2729
end
2830
@testset "GammaExponentialKernel" begin
2931
γ = 2.0
@@ -36,7 +38,8 @@
3638
@test metric(GammaExponentialKernel=2.0)) == SqEuclidean()
3739
@test repr(k) == "Gamma Exponential Kernel (γ = $(γ))"
3840
@test KernelFunctions.iskroncompatible(k) == true
39-
41+
test_ADs-> GammaExponentialKernel(gamma=first(γ)), [γ], ADs = [:ForwardDiff, :ReverseDiff])
42+
@test_broken "Zygote gradient given γ"
4043
#Coherence :
4144
@test GammaExponentialKernel=1.0)(v1,v2) SqExponentialKernel()(v1,v2)
4245
@test GammaExponentialKernel=0.5)(v1,v2) ExponentialKernel()(v1,v2)

test/basekernels/exponentiated.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,5 @@
1010
@test k(v1,v2) exp(dot(v1,v2))
1111
@test metric(ExponentiatedKernel()) == KernelFunctions.DotProduct()
1212
@test repr(k) == "Exponentiated Kernel"
13+
test_ADs(ExponentiatedKernel)
1314
end

test/basekernels/fbm.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,6 @@
2121
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] k(x1, x2) atol=1e-5
2222

2323
@test repr(k) == "Fractional Brownian Motion Kernel (h = $(h))"
24+
test_ADs(FBMKernel, ADs = [:ReverseDiff])
25+
@test_broken "Tests failing for kernelmatrix(k, x) for ForwardDiff and Zygote"
2426
end

test/basekernels/gabor.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,6 @@
1717
@test k.ell 1.0 atol=1e-5
1818
@test k.p 1.0 atol=1e-5
1919
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
20+
test_ADs(x -> GaborKernel(ell = x[1], p = x[2]), [ell, p], ADs = [:ForwardDiff, :ReverseDiff])
21+
@test_broken "Tests failing for Zygote on differentiating through ell and p"
2022
end

test/basekernels/maha.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,6 @@
1111
@test k(v1, v2) exp(-sqmahalanobis(v1, v2, P))
1212
@test kappa(ExponentialKernel(), x) == kappa(k, x)
1313
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"
14+
# test_ADs(P -> MahalanobisKernel(P), P)
15+
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
1416
end

test/basekernels/matern.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
@test metric(MaternKernel()) == Euclidean()
1515
@test metric(MaternKernel=2.0)) == Euclidean()
1616
@test repr(k) == "Matern Kernel (ν = $(ν))"
17+
test_ADs(x->MaternKernel(nu=first(x)),[ν])
18+
@test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)"
1719
end
1820
@testset "Matern32Kernel" begin
1921
k = Matern32Kernel()
@@ -22,6 +24,7 @@
2224
@test kappa(Matern32Kernel(),x) == kappa(k,x)
2325
@test metric(Matern32Kernel()) == Euclidean()
2426
@test repr(k) == "Matern 3/2 Kernel"
27+
test_ADs(Matern32Kernel)
2528
end
2629
@testset "Matern52Kernel" begin
2730
k = Matern52Kernel()
@@ -30,6 +33,7 @@
3033
@test kappa(Matern52Kernel(),x) == kappa(k,x)
3134
@test metric(Matern52Kernel()) == Euclidean()
3235
@test repr(k) == "Matern 5/2 Kernel"
36+
test_ADs(Matern52Kernel)
3337
end
3438
@testset "Coherence Materns" begin
3539
@test kappa(MaternKernel=0.5),x) kappa(ExponentialKernel(),x)

test/basekernels/nn.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,6 @@
4343
@test_throws DimensionMismatch kernelmatrix!(A5, k, ones(4,3), ones(3,4))
4444

4545
@test k([x1], [x2]) k(x1, x2) atol=1e-5
46-
46+
test_ADs(NeuralNetworkKernel, ADs = [:ForwardDiff, :ReverseDiff])
47+
@test_broken "Zygote uncompatible with BaseKernel"
4748
end

test/basekernels/periodic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,6 @@
77
@test k(v1, v2) == k(v2, v1)
88
@test PeriodicKernel(3)(v1, v2) == PeriodicKernel(r = ones(3))(v1, v2)
99
@test repr(k) == "Periodic Kernel, length(r) = $(length(r)))"
10+
test_ADs(r->PeriodicKernel(r =r), r, ADs = [:ForwardDiff, :ReverseDiff])
11+
@test_broken "Undefined adjoint for Sinus metric"
1012
end

0 commit comments

Comments
 (0)