Skip to content

Commit 577518f

Browse files
committed
Rewrote testing code
1 parent d586967 commit 577518f

File tree

1 file changed

+55
-20
lines changed

1 file changed

+55
-20
lines changed

test/utils_AD.jl

Lines changed: 55 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11

2-
FDM = FiniteDifferences.central_fdm(5, 1)
2+
const FDM = FiniteDifferences.central_fdm(5, 1)
33

4-
function gradient(::Val{:Zygote}, f::Function, args)
4+
gradient(f, s::Symbol, args) = gradient(f, Val(s), args)
5+
6+
function gradient(f, ::Val{:Zygote}, args)
57
g = first(Zygote.gradient(f, args))
68
if isnothing(g)
79
if args isa AbstractArray{<:Real}
@@ -14,18 +16,21 @@ function gradient(::Val{:Zygote}, f::Function, args)
1416
end
1517
end
1618

17-
function gradient(::Val{:ForwardDiff}, f::Function, args)
19+
function gradient(f, ::Val{:ForwardDiff}, args)
1820
ForwardDiff.gradient(f, args)
1921
end
2022

21-
function gradient(::Val{:ReverseDiff}, f::Function, args)
23+
function gradient(f, ::Val{:ReverseDiff}, args)
2224
ReverseDiff.gradient(f, args)
2325
end
2426

25-
function gradient(::Val{:FiniteDiff}, f::Function, args)
27+
function gradient(f, ::Val{:FiniteDiff}, args)
2628
first(FiniteDifferences.grad(FDM, f, args))
2729
end
2830

31+
function compare_gradient(f, AD::Symbol, args)
32+
isapprox(gradient(f, AD, args), gradient(f, :FiniteDiff, args), atol=1e-8, rtol=1e-5)
33+
end
2934

3035
testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim))
3136
testfunction(k, A, dim) = sum(kernelmatrix(k, A, obsdim = dim))
@@ -50,25 +55,39 @@ function test_FiniteDiff(kernelfunction, args = nothing, dims = [3, 3])
5055
@testset "FiniteDifferences" begin
5156
if k isa SimpleKernel
5257
for d in log.([eps(), rand(rng)])
53-
@test_nowarn gradient(Val(:FiniteDiff), x -> kappa(k, exp(first(x))), [d])
58+
@test_nowarn gradient(:FiniteDiff, [d]) do x
59+
kappa(k, exp(first(x)))
60+
end
5461
end
5562
end
5663
## Testing Kernel Functions
5764
x = rand(rng, dims[1])
5865
y = rand(rng, dims[1])
59-
@test_nowarn gradient(Val(:FiniteDiff), x -> k(x, y), x)
66+
@test_nowarn gradient(:FiniteDiff, x) do x
67+
k(x, y)
68+
end
6069
if !(args === nothing)
61-
@test_nowarn gradient(Val(:FiniteDiff), p -> kernelfunction(p)(x, y), args)
70+
@test_nowarn gradient(:FiniteDiff, args) do p
71+
kernelfunction(p)(x, y)
72+
end
6273
end
6374
## Testing Kernel Matrices
6475
A = rand(rng, dims...)
6576
B = rand(rng, dims...)
6677
for dim in 1:2
67-
@test_nowarn gradient(Val(:FiniteDiff), a -> testfunction(k, a, dim), A)
68-
@test_nowarn gradient(Val(:FiniteDiff), a -> testfunction(k, a, B, dim), A)
69-
@test_nowarn gradient(Val(:FiniteDiff), b -> testfunction(k, A, b, dim), B)
78+
@test_nowarn gradient(:FiniteDiff, A) do a
79+
testfunction(k, a, dim)
80+
end
81+
@test_nowarn gradient(:FiniteDiff , A) do a
82+
testfunction(k, a, B, dim)
83+
end
84+
@test_nowarn gradient(:FiniteDiff, B) do b
85+
testfunction(k, A, b, dim)
86+
end
7087
if !(args === nothing)
71-
@test_nowarn gradient(Val(:FiniteDiff), p -> testfunction(kernelfunction(p), A, B, dim), args)
88+
@test_nowarn gradient(:FiniteDiff, args) do p
89+
testfunction(kernelfunction(p), A, B, dim)
90+
end
7291
end
7392
end
7493
end
@@ -85,26 +104,42 @@ function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3])
85104
rng = MersenneTwister(42)
86105
if k isa SimpleKernel
87106
for d in log.([eps(), rand(rng)])
88-
@test gradient(Val(AD), x -> kappa(k, exp(x[1])), [d]) gradient(Val(:FiniteDiff), x -> kappa(k, exp(x[1])), [d]) atol=1e-8 rtol=1e-5
107+
@test compare_gradient(AD, [d]) do x
108+
kappa(k, exp(x[1])
109+
end
89110
end
90111
end
91112
# Testing kernel evaluations
92113
x = rand(rng, dims[1])
93114
y = rand(rng, dims[1])
94-
@test gradient(Val(AD), x -> k(x, y), x) gradient(Val(:FiniteDiff), x -> k(x, y), x) atol=1e-8 rtol=1e-5
95-
@test gradient(Val(AD), y -> k(x, y), y) gradient(Val(:FiniteDiff), y -> k(x, y), y) atol=1e-8 rtol=1e-5
115+
@test compare_gradient(AD, x) do x
116+
k(x, y)
117+
end
118+
@test compare_gradient(AD, y) do y
119+
k(x, y)
120+
end
96121
if !(args === nothing)
97-
@test gradient(Val(AD), p -> kernelfunction(p)(x,y), args) gradient(Val(:FiniteDiff), p -> kernelfunction(p)(x, y), args) atol=1e-8 rtol=1e-5
122+
@test compare_gradient(AD, args) do p
123+
kernelfunction(p)(x,y)
124+
end
98125
end
99126
# Testing kernel matrices
100127
A = rand(rng, dims...)
101128
B = rand(rng, dims...)
102129
for dim in 1:2
103-
@test gradient(Val(AD), x -> testfunction(k, x, dim), A) gradient(Val(:FiniteDiff), x -> testfunction(k, x, dim), A) atol=1e-8 rtol=1e-5
104-
@test gradient(Val(AD), a -> testfunction(k, a, B, dim), A) gradient(Val(:FiniteDiff), a -> testfunction(k, a, B, dim), A) atol=1e-8 rtol=1e-5
105-
@test gradient(Val(AD), b -> testfunction(k, A, b, dim), B) gradient(Val(:FiniteDiff), b -> testfunction(k, A, b, dim), B) atol=1e-8 rtol=1e-5
130+
@test compare_gradient(AD, A) do a
131+
testfunction(k, a, dim)
132+
end
133+
@test conpare_gradient(AD, A) do a
134+
testfunction(k, a, B, dim)
135+
end
136+
@test compare_gradient(AD, B) do b
137+
testfunction(k, A, b, dim)
138+
end
106139
if !(args === nothing)
107-
@test gradient(Val(AD), p -> testfunction(kernelfunction(p), A, dim), args) gradient(Val(:FiniteDiff), p -> testfunction(kernelfunction(p), A, dim), args) atol=1e-8 rtol=1e-5
140+
@test compare_gradient(AD, args) do p
141+
testfunction(kernelfunction(p), AD, A, dim)
142+
end
108143
end
109144
end
110145
end

0 commit comments

Comments
 (0)