Skip to content

Commit b7ff456

Browse files
committed
Tests cosmetics
1 parent a5efbb9 commit b7ff456

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

test/testAD.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,29 +7,32 @@ dims = [10,5]
77
A = rand(dims...)
88
B = rand(dims...)
99
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
10-
kernels = [SquaredExponentialKernel]
10+
kernels = [SquaredExponentialKernel,MaternKernel]
1111
l = 2.0
1212
vl = l*ones(dims[1])
1313
testfunction(k,A,B) = sum(kernelmatrix(k,A,B))
1414
testfunction(k,A) = sum(kernelmatrix(k,A))
1515

16-
testfunction(SquaredExponentialKernel(vl),A)
1716
##Eventually store real results in file
1817
@testset "Zygote Automatic Differentiation test" begin
1918
@testset "ARD" begin
2019
for k in kernels
21-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),vl)[1], ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)))
22-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)))
23-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),vl)[1],ForwardDiff.gradient(x->testfunction(k(x),A),vl)))
24-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x),A)))
20+
@testset "$k" begin
21+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),vl)[1], ForwardDiff.gradient(x->testfunction(k(x),A,B),vl)))
22+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x,B),A)))
23+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),vl)[1],ForwardDiff.gradient(x->testfunction(k(x),A),vl)))
24+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(vl),x),A)[1],ForwardDiff.gradient(x->testfunction(k(vl),x),A)))
25+
end
2526
end
2627
end
2728
@testset "ISO" begin
2829
for k in kernels
29-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])[1]))
30-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(l),x,B),A)))
31-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l])))
32-
@test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x),A)[1],ForwardDiff.gradient(x->testfunction(k(l[1]),x),A)))
30+
@testset "$k" begin
31+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A,B),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A,B),[l])[1]))
32+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x,B),A)[1],ForwardDiff.gradient(x->testfunction(k(l),x,B),A)))
33+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(x),A),l)[1],ForwardDiff.gradient(x->testfunction(k(x[1]),A),[l])))
34+
@test all(isapprox.(Zygote.gradient(x->testfunction(k(l),x),A)[1],ForwardDiff.gradient(x->testfunction(k(l[1]),x),A)))
35+
end
3336
end
3437
end
3538
end

0 commit comments

Comments
 (0)