Skip to content

Commit 4aeb0e3

Browse files
committed
First draft of AD tests
1 parent a6159e1 commit 4aeb0e3

File tree

2 files changed

+128
-115
lines changed

2 files changed

+128
-115
lines changed

test/test_AD.jl

Lines changed: 75 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,101 @@
11
using KernelFunctions
2-
using Zygote, ForwardDiff
3-
using Test, LinearAlgebra
2+
using KernelFunctions: kappa
3+
using Flux: params
4+
import Zygote, ForwardDiff, ReverseDiff
5+
using Test, LinearAlgebra, Random
46
using FiniteDifferences
57

6-
dims = [10,5]
8+
include("utils_AD.jl")
9+
10+
dims = [3, 3]
11+
ν = 3.0
12+
13+
rng = MersenneTwister(42)
14+
15+
A = rand(rng, dims...)
16+
B = rand(rng, dims...)
17+
K = [zeros(dims[1], dims[1]), zeros(dims[2], dims[2])]
18+
19+
x = rand(rng, dims[1])
20+
y = rand(rng, dims[1])
21+
22+
l = rand(rng)
23+
vl = l * ones(dims[1])
24+
25+
kernels = [
26+
SqExponentialKernel(),
27+
ExponentialKernel(),
28+
MaternKernel(ν = ν),
29+
# transform(SqExponentialKernel(), l),
30+
# transform(SqExponentialKernel(), vl),
31+
# ExponentiatedKernel() + LinearKernel(),
32+
# 2.0 * PolynomialKernel() * Matern32Kernel(),
33+
]
34+
35+
ds = log.([eps(), rand(rng)])
36+
37+
testfunction(k, A, B, dim) = det(kernelmatrix(k, A, B, obsdim = dim))
38+
testfunction(k, A, dim) = det(kernelmatrix(k, A, obsdim = dim))
39+
ADs = [:Zygote, :ForwardDiff, :ReverseDiff]
740

8-
A = rand(dims...)
9-
B = rand(dims...)
10-
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
11-
kernels_noparams = [:SqExponentialKernel,:ExponentialKernel,:GammaExponentialKernel,
12-
:MaternKernel,:Matern32Kernel,:Matern52Kernel,
13-
:LinearKernel,:PolynomialKernel,
14-
:RationalQuadraticKernel,:GammaRationalQuadraticKernel,
15-
:ExponentiatedKernel]
16-
l = 2.0
17-
ds = [0.0,3.0]
18-
vl = l*ones(dims[1])
19-
testfunction(k,A,B) = det(kernelmatrix(k,A,B))
20-
testfunction(k,A) = det(kernelmatrix(k,A))
21-
ADs = [:Zygote,:ForwardDiff]
2241

2342
## Test kappa functions
43+
2444
@testset "Kappa functions" begin
25-
for AD in ADs
26-
@testset "$AD" begin
27-
for k in kernels_noparams
28-
for d in ds
29-
@eval begin @test kappa_AD(Val(Symbol($AD)),$k(),$d) ≈ kappa_fdm($k(),$d) atol=1e-8 end
45+
for k in kernels[isa.(kernels, KernelFunctions.SimpleKernel)]
46+
@testset "$k" begin
47+
@test_nowarn gradient(Val(:FiniteDiff), x -> kappa(k, exp(x[1])), ds[1]) # Check FiniteDiff does the right thing
48+
for AD in ADs
49+
@testset "$AD" begin
50+
for d in ds
51+
@test_nowarn gradient(Val(AD), x -> kappa(k, exp(x[1])), [d])
52+
@test gradient(Val(AD), x -> kappa(k, exp(x[1])), [d]) ≈ gradient(Val(:FiniteDiff), x -> kappa(k, exp(x[1])), [d]) atol=1e-8
53+
end
3054
end
3155
end
32-
# Linear -> C
33-
# Polynomial -> C,D
34-
# Gamma (etc) -> gamma
35-
#
3656
end
3757
end
3858
end
3959

40-
@testset "Transform Operations" begin
41-
for AD in ADs
42-
@testset "$AD" begin
43-
@eval begin
44-
# Scale Transform
45-
transform_AD(Val(Symbol($AD)),ScaleTransform(l),A)
46-
# ARD Transform
47-
transform_AD(Val(Symbol($AD)),ARDTransform(vl),A)
48-
# Linear transform
49-
transform_AD(Val(Symbol($AD)), LinearTransform(rand(2,10)),A)
50-
# Chain Transform
51-
# transform_AD(Val(Symbol($AD)), LinearTransform, A)
60+
@testset "Kernel evaluations" begin
61+
for k in kernels
62+
@testset "$k" begin
63+
for AD in ADs
64+
@test_nowarn gradient(Val(:FiniteDiff), x -> k(x, y), x)
65+
@testset "$AD" begin
66+
for d in ds
67+
@test_nowarn gradient(Val(AD), x -> k(x, y), x)
68+
@test gradient(Val(AD), x -> k(x, y), x) ≈ gradient(Val(:FiniteDiff), x -> k(x, y), x) atol=1e-8
69+
end
70+
end
5271
end
5372
end
5473
end
5574
end
5675

57-
##TODO Eventually store real results in file
58-
@testset "Zygote Automatic Differentiation test" begin
59-
@testset "ARD" begin
60-
for k in kernels
61-
@testset "$k" begin
62-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),vl)[1], ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)))
63-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)))
64-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),vl)[1],ForwardDiff.gradient(x->testfunction(k(x),A),vl)))
65-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x),A)))
66-
end
67-
end
68-
end
69-
@testset "ISO" begin
70-
for k in kernels
71-
@testset "$k" begin
72-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])[1]))
73-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(l),x,B),A)))
74-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l])))
75-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x),A)[1],ForwardDiff.gradient(x->testfunction(k(l[1]),x),A)))
76+
@testset "Kernel Matrices" begin
77+
for k in kernels
78+
@testset "$k" begin
79+
for AD in ADs
80+
# @test_nowarn gradient(Val(:FiniteDiff), x -> k(x, y), )
81+
@testset "$AD" begin
82+
for dim in [1,2]
83+
@test_nowarn gradient(Val(AD), x -> testfunction(k, x, dim), A)
84+
@test_nowarn gradient(Val(AD), x -> testfunction(k, x, B, dim), A)
85+
@test gradient(Val(AD), x -> testfunction(k, x, B, dim), A) ≈ gradient(Val(:FiniteDiff), x -> testfunction(k, x, B, dim), A) atol=1e-8
86+
@test gradient(Val(AD), x -> testfunction(k, x, dim), A) ≈ gradient(Val(:FiniteDiff), x -> testfunction(k, x, dim), A) atol=1e-8
87+
end
88+
end
7689
end
7790
end
7891
end
7992
end
8093

81-
@testset "ForwardDiff AutomaticDifferentation test" begin
82-
@testset "ARD" begin
83-
for k in kernels
84-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)
85-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)
86-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x),A),vl)
87-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(vl),x),A)
88-
end
89-
end
90-
@testset "ISO" begin
91-
for k in kernels
92-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])
93-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(l),x,B),A)
94-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l])
95-
@test_nowarn ForwardDiff.gradient(x->testfunction(k(l[1]),x),A)
96-
end
97-
end
98-
end
99-
100-
101-
@testset "Tracker AutomaticDifferentation test" begin
102-
@testset "ARD" begin
103-
for k in kernels
104-
@test_broken all(Tracker.gradient(x->testfunction(k(x),A,B),vl)[1] .≈ ForwardDiff.gradient(x->testfunction(k(x),A,B),vl))
105-
@test_broken all(Tracker.gradient(x->testfunction(k(vl),x,B),A)[1] .≈ ForwardDiff.gradient(x->testfunction(k(vl),x,B),A))
106-
@test_broken all(Tracker.gradient(x->testfunction(k(x),A),vl)[1] .≈ ForwardDiff.gradient(x->testfunction(k(x),A),vl))
107-
@test_broken all.(Tracker.gradient(x->testfunction(k(vl),x),A) .≈ ForwardDiff.gradient(x->testfunction(k(vl),x),A))
108-
end
109-
end
110-
@testset "ISO" begin
111-
for k in kernels
112-
@test_broken Tracker.gradient(x->testfunction(k(x[1]),A,B),[l])
113-
@test_broken Tracker.gradient(x->testfunction(k(l),x,B),A)
114-
@test_broken Tracker.gradient(x->testfunction(k(x[1]),A),[l])
115-
@test_broken Tracker.gradient(x->testfunction(k(l),x),A)
116-
94+
@testset "Params differentiation" begin
95+
for k in kernels
96+
@testset "$k" begin
97+
ps = params(k)
98+
@test_nowarn gradient(Val(:Zygote), () -> k(x, y), ps)
11799
end
118100
end
119101
end

test/utils_AD.jl

Lines changed: 53 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,70 @@
1-
allapprox(x,y,tol=1e-8) = all(isapprox.(x,y,atol=tol))
2-
FDM = central_fdm(5,1)
1+
allapprox(x, y, tol = 1e-8) = all(isapprox.(x, y, atol = tol))
2+
FDM = central_fdm(5, 1)
33

4+
function gradient(::Val{:Zygote}, f::Function, args)
5+
first(Zygote.gradient(f, args))
6+
end
7+
8+
function gradient(::Val{:Zygote}, f::Function, args::Zygote.Params)
9+
Zygote.gradient(f, args)
10+
end
411

5-
function kappa_AD(::Val{:Zygote},k::Kernel,d::Real)
6-
first(Zygote.gradient(x->kappa(k,x),d))
12+
function gradient(::Val{:ForwardDiff}, f::Function, args)
13+
ForwardDiff.gradient(f, args)
714
end
815

9-
function kappa_AD(::Val{:ForwardDiff},k::Kernel,d::Real)
10-
first(ForwardDiff.gradient(x->kappa(k,first(x)),[d]))
16+
function gradient(::Val{:ReverseDiff}, f::Function, args)
17+
ReverseDiff.gradient(f, args)
1118
end
1219

13-
function kappa_fdm(k::Kernel,d::Real)
14-
first(FiniteDifferences.grad(FDM,x->kappa(k,x),d))
20+
function gradient(::Val{:FiniteDiff}, f::Function, args)
21+
first(FiniteDifferences.grad(FDM, f, args))
1522
end
1623

1724

18-
function transform_AD(::Val{:Zygote},t::Transform,A)
25+
26+
function transform_AD(::Val{:Zygote}, t::Transform, A)
1927
ps = KernelFunctions.params(t)
20-
@test allapprox(first(Zygote.gradient(p->transform_with_duplicate(p,t,A),ps)),
21-
first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps)))
22-
@test allapprox(first(Zygote.gradient(X->sum(transform(t,X,2)),A)),
23-
first(FiniteDifferences.grad(FDM,X->sum(transform(t,X,2)),A)))
28+
@test allapprox(
29+
first(Zygote.gradient(p -> transform_with_duplicate(p, t, A), ps)),
30+
first(FiniteDifferences.grad(
31+
FDM,
32+
p -> transform_with_duplicate(p, t, A),
33+
ps,
34+
)),
35+
)
36+
@test allapprox(
37+
first(Zygote.gradient(X -> sum(transform(t, X, 2)), A)),
38+
first(FiniteDifferences.grad(FDM, X -> sum(transform(t, X, 2)), A)),
39+
)
2440
end
2541

26-
function transform_AD(::Val{:ForwardDiff},t::Transform,A)
42+
function transform_AD(::Val{:ForwardDiff}, t::Transform, A)
2743
ps = KernelFunctions.params(t)
2844
if t isa ScaleTransform
29-
@test allapprox(first(ForwardDiff.gradient(p->transform_with_duplicate(first(p),t,A),[ps])),
30-
first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps)))
45+
@test allapprox(
46+
first(ForwardDiff.gradient(
47+
p -> transform_with_duplicate(first(p), t, A),
48+
[ps],
49+
)),
50+
first(FiniteDifferences.grad(
51+
FDM,
52+
p -> transform_with_duplicate(p, t, A),
53+
ps,
54+
)),
55+
)
3156
else
32-
@test allapprox(ForwardDiff.gradient(p->transform_with_duplicate(p,t,A),ps),
33-
first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps)))
57+
@test allapprox(
58+
ForwardDiff.gradient(p -> transform_with_duplicate(p, t, A), ps),
59+
first(FiniteDifferences.grad(
60+
FDM,
61+
p -> transform_with_duplicate(p, t, A),
62+
ps,
63+
)),
64+
)
3465
end
35-
@test allapprox(ForwardDiff.gradient(X->sum(transform(t,X,2)),A),
36-
first(FiniteDifferences.grad(FDM,X->sum(transform(t,X,2)),A)))
66+
@test allapprox(
67+
ForwardDiff.gradient(X -> sum(transform(t, X, 2)), A),
68+
first(FiniteDifferences.grad(FDM, X -> sum(transform(t, X, 2)), A)),
69+
)
3770
end
38-
39-
transform_with_duplicate(p,t,A) = sum(transform(KernelFunctions.duplicate(t,p),A,2))

0 commit comments

Comments
 (0)