Skip to content

Commit 65505a1

Browse files
committed
Tests for Zygote and ForwardDiff
1 parent 9c83324 commit 65505a1

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

test/testAD.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
using Zygote, ForwardDiff
2+
3+
dims = [10,5]
4+
5+
A = rand(dims...)
6+
B = rand(dims...)
7+
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
8+
kernels = [SquaredExponentialKernel]
9+
l = 2.0
10+
vl = l*ones(dims[1])
11+
testfunction(k,A,B) = sum(kernelmatrix(k,A,B))
12+
testfunction(k,A) = sum(kernelmatrix(k,A))
13+
14+
15+
#For debugging
16+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
17+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
18+
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
19+
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
20+
21+
##Eventually store real results in file
22+
23+
@testset "Zygote Automatic Differentiation test" begin
24+
@testset "ARD" begin
25+
for k in kernels
26+
@test Zygote.gradient(x->testfunction(k(x),A,B),vl)
27+
@test Zygote.gradient(x->testfunction(k(vl),x,B),A)
28+
@test Zygote.gradient(x->testfunction(k(x),A),vl)
29+
@test Zygote.gradient(x->testfunction(k(vl),x),A)
30+
end
31+
end
32+
@testset "ISO" begin
33+
for k in kernels
34+
@test Zygote.gradient(x->testfunction(k(x),A,B),l)
35+
@test Zygote.gradient(x->testfunction(k(l),x,B),A)
36+
@test Zygote.gradient(x->testfunction(k(x),A),l)
37+
@test Zygote.gradient(x->testfunction(k(l),x),A)
38+
39+
end
40+
end
41+
end
42+
43+
@testset "ForwardDiff AutomaticDifferentation test" begin
44+
@testset "ARD" begin
45+
for k in kernels
46+
@test ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)
47+
@test ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)
48+
@test ForwardDiff.gradient(x->testfunction(k(x),A),vl)
49+
@test ForwardDiff.gradient(x->testfunction(k(vl),x),A)
50+
end
51+
end
52+
@testset "ISO" begin
53+
for k in kernels
54+
@test ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])
55+
@test ForwardDiff.gradient(x->testfunction(k(l),x,B),A)
56+
@test ForwardDiff.gradient(x->testfunction(k(x),A),[l])
57+
@test ForwardDiff.gradient(x->testfunction(k(l[1]),x),A)
58+
59+
end
60+
end
61+
end

0 commit comments

Comments
 (0)