Skip to content

Commit f70adc1

Browse files
committed
Created two function for testing any kernel, any AD and compare with FiniteDifferences.jl
1 parent b6a7901 commit f70adc1

File tree

2 files changed

+88
-94
lines changed

2 files changed

+88
-94
lines changed

test/test_AD.jl

Lines changed: 10 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,102 +1,19 @@
11
using KernelFunctions
22
using KernelFunctions: kappa, ColVecs, RowVecs
3-
using Flux: params
4-
import Zygote, ForwardDiff, ReverseDiff
5-
using Zygote: pullback
3+
import Zygote, ForwardDiff, ReverseDiff, FiniteDifferences
64
using Test, LinearAlgebra, Random
7-
using FiniteDifferences
85

96
include("utils_AD.jl")
10-
11-
dims = [3, 3]
12-
ν = 3.0
13-
14-
rng = MersenneTwister(42)
15-
16-
A = rand(rng, dims...)
17-
B = rand(rng, dims...)
18-
K = [zeros(dims[1], dims[1]), zeros(dims[2], dims[2])]
19-
20-
x = rand(rng, dims[1])
21-
y = rand(rng, dims[1])
22-
23-
l = rand(rng)
24-
vl = l * ones(dims[1])
25-
26-
kernels = [
27-
SqExponentialKernel(),
28-
ExponentialKernel(),
29-
MaternKernel= ν),
30-
transform(SqExponentialKernel(), l),
31-
transform(SqExponentialKernel(), vl),
32-
ExponentiatedKernel() + LinearKernel(),
33-
2.0 * PolynomialKernel() * Matern32Kernel(),
34-
]
35-
36-
ds = log.([eps(), rand(rng)])
37-
38-
testfunction(k, A, B, dim) = det(kernelmatrix(k, A, B, obsdim = dim))
39-
testfunction(k, A, dim) = det(kernelmatrix(k, A, obsdim = dim))
407
ADs = [:Zygote, :ForwardDiff, :ReverseDiff]
418

42-
43-
## Test kappa functions
44-
45-
@testset "Kappa functions" begin
46-
for k in kernels[isa.(kernels, KernelFunctions.SimpleKernel)]
47-
@testset "$k" begin
48-
@test_nowarn gradient(Val(:FiniteDiff), x -> kappa(k, exp(x[1])), ds[1]) # Check FiniteDiff does the right thing
49-
for AD in ADs
50-
@testset "$AD" begin
51-
for d in ds
52-
@test_nowarn gradient(Val(AD), x -> kappa(k, exp(x[1])), [d])
53-
@test gradient(Val(AD), x -> kappa(k, exp(x[1])), [d]) gradient(Val(:FiniteDiff), x -> kappa(k, exp(x[1])), [d]) atol=1e-8
54-
end
55-
end
56-
end
57-
end
58-
end
59-
end
60-
61-
@testset "Kernel evaluations" begin
62-
for k in kernels
63-
@testset "$k" begin
64-
for AD in ADs
65-
@test_nowarn gradient(Val(:FiniteDiff), x -> k(x, y), x)
66-
@testset "$AD" begin
67-
for d in ds
68-
@test_nowarn gradient(Val(AD), x -> k(x, y), x)
69-
@test gradient(Val(AD), x -> k(x, y), x) gradient(Val(:FiniteDiff), x -> k(x, y), x) atol=1e-8
70-
end
71-
end
72-
end
73-
end
74-
end
75-
end
76-
77-
@testset "Kernel Matrices" begin
78-
for k in kernels
79-
@testset "$k" begin
80-
for AD in ADs
81-
# @test_nowarn gradient(Val(:FiniteDiff), x -> k(x, y), )
82-
@testset "$AD" begin
83-
for dim in [1,2]
84-
@test_nowarn gradient(Val(AD), x -> testfunction(k, x, dim), A)
85-
@test_nowarn gradient(Val(AD), x -> testfunction(k, x, B, dim), A)
86-
@test gradient(Val(AD), x -> testfunction(k, x, B, dim), A) gradient(Val(:FiniteDiff), x -> testfunction(k, x, B, dim), A) atol=1e-8
87-
@test gradient(Val(AD), x -> testfunction(k, x, dim), A) gradient(Val(:FiniteDiff), x -> testfunction(k, x, dim), A) atol=1e-8
88-
end
89-
end
90-
end
91-
end
92-
end
93-
end
94-
95-
@testset "Params differentiation" begin
96-
for k in kernels
97-
@testset "$k" begin
98-
ps = params(k)
99-
@test_nowarn gradient(Val(:Zygote), () -> k(x, y), ps)
100-
end
9+
kname = "SEKernel_lengthscale"
10+
kfunction = () -> SEKernel()
11+
kfunction = (l -> transform(SEKernel(), first(l)))
12+
# args = nothing
13+
args = [2.0]
14+
v = test_FiniteDiff(kname, kfunction, args)
15+
if !v.anynonpass
16+
for AD in ADs
17+
test_AD(AD, kname, kfunction, args)
10118
end
10219
end

test/utils_AD.jl

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FDM = central_fdm(5, 1)
1+
FDM = FiniteDifferences.central_fdm(5, 1)
22

33
function gradient(::Val{:Zygote}, f::Function, args)
44
first(Zygote.gradient(f, args))
@@ -19,3 +19,80 @@ end
1919
function gradient(::Val{:FiniteDiff}, f::Function, args)
2020
first(FiniteDifferences.grad(FDM, f, args))
2121
end
22+
23+
24+
testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim))
25+
testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim))
26+
27+
function test_FiniteDiff(kernelname, kernelfunction, args = nothing)
28+
# Init arguments :
29+
k = if args === nothing
30+
kernelfunction()
31+
else
32+
kernelfunction(args)
33+
end
34+
dims = [3, 3]
35+
rng = MersenneTwister(42)
36+
@testset "FiniteDifferences with $(kernelname)" begin
37+
if k isa SimpleKernel
38+
for d in log.([eps(), rand(rng)])
39+
@test_nowarn gradient(Val(:FiniteDiff), x -> kappa(k, exp(first(x))), [d])
40+
end
41+
end
42+
## Testing Kernel Functions
43+
x = rand(rng, dims[1])
44+
y = rand(rng, dims[1])
45+
@test_nowarn gradient(Val(:FiniteDiff), x -> k(x, y), x)
46+
if !(args === nothing)
47+
@test_nowarn gradient(Val(:FiniteDiff), p -> kernelfunction(p)(x, y), args)
48+
end
49+
## Testing Kernel Matrices
50+
A = rand(rng, dims...)
51+
B = rand(rng, dims...)
52+
for dim in 1:2
53+
@test_nowarn gradient(Val(:FiniteDiff), a -> testfunction(k, a, dim), A)
54+
@test_nowarn gradient(Val(:FiniteDiff), a -> testfunction(k, a, B, dim), A)
55+
@test_nowarn gradient(Val(:FiniteDiff), b -> testfunction(k, A, b, dim), B)
56+
if !(args === nothing)
57+
@test_nowarn gradient(Val(:FiniteDiff), p -> testfunction(kernelfunction(p), A, B, dim), args)
58+
end
59+
end
60+
end
61+
end
62+
63+
function test_AD(AD, kernelname, kernelfunction, args = nothing)
64+
@testset "Testing $(kernelname) with AD : $(AD)" begin
65+
# Test kappa function
66+
dims = [3, 3]
67+
k = if args === nothing
68+
kernelfunction()
69+
else
70+
kernelfunction(args)
71+
end
72+
rng = MersenneTwister(42)
73+
if k isa SimpleKernel
74+
for d in log.([eps(), rand(rng)])
75+
@test gradient(Val(AD), x -> kappa(k, exp(x[1])), [d]) gradient(Val(:FiniteDiff), x -> kappa(k, exp(x[1])), [d]) atol=1e-8
76+
end
77+
end
78+
# Testing kernel evaluations
79+
x = rand(rng, dims[1])
80+
y = rand(rng, dims[1])
81+
@test gradient(Val(AD), x -> k(x, y), x) gradient(Val(:FiniteDiff), x -> k(x, y), x) atol=1e-8
82+
@test gradient(Val(AD), y -> k(x, y), y) gradient(Val(:FiniteDiff), y -> k(x, y), y) atol=1e-8
83+
if !(args === nothing)
84+
@test gradient(Val(AD), p -> kernelfunction(p)(x,y), args) gradient(Val(:FiniteDiff), p -> kernelfunction(p)(x, y), args) atol=1e-8
85+
end
86+
# Testing kernel matrices
87+
A = rand(rng, dims...)
88+
B = rand(rng, dims...)
89+
for dim in 1:2
90+
@test gradient(Val(AD), x -> testfunction(k, x, dim), A) gradient(Val(:FiniteDiff), x -> testfunction(k, x, dim), A) atol=1e-8
91+
@test gradient(Val(AD), a -> testfunction(k, a, B, dim), A) gradient(Val(:FiniteDiff), a -> testfunction(k, a, B, dim), A) atol=1e-8
92+
@test gradient(Val(AD), b -> testfunction(k, A, b, dim), B) gradient(Val(:FiniteDiff), b -> testfunction(k, A, b, dim), B) atol=1e-8
93+
if !(args === nothing)
94+
@test gradient(Val(AD), p -> testfunction(kernelfunction(p), A, dim), args) gradient(Val(:FiniteDiff), p -> testfunction(kernelfunction(p), A, dim), args) atol=1e-8
95+
end
96+
end
97+
end
98+
end

0 commit comments

Comments
 (0)