Skip to content

Commit b6ddf52

Browse files
committed
Clearer failing messages
1 parent 6b5ba4d commit b6ddf52

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

test/utils_AD.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ function gradient(f, ::Val{:FiniteDiff}, args)
2929
end
3030

3131
function compare_gradient(f, AD::Symbol, args)
32-
isapprox(gradient(f, AD, args), gradient(f, :FiniteDiff, args), atol=1e-8, rtol=1e-5)
32+
grad_AD = gradient(f, AD, args)
33+
grad_FD = gradient(f, :FiniteDiff, args)
34+
@test grad_AD grad_FD atol=1e-8 rtol=1e-5
3335
end
3436

3537
testfunction(k, A, B, dim) = sum(kernelmatrix(k, A, B, obsdim = dim))
@@ -104,40 +106,40 @@ function test_AD(AD::Symbol, kernelfunction, args = nothing, dims = [3, 3])
104106
rng = MersenneTwister(42)
105107
if k isa SimpleKernel
106108
for d in log.([eps(), rand(rng)])
107-
@test compare_gradient(AD, [d]) do x
109+
compare_gradient(AD, [d]) do x
108110
kappa(k, exp(x[1]))
109111
end
110112
end
111113
end
112114
# Testing kernel evaluations
113115
x = rand(rng, dims[1])
114116
y = rand(rng, dims[1])
115-
@test compare_gradient(AD, x) do x
117+
compare_gradient(AD, x) do x
116118
k(x, y)
117119
end
118-
@test compare_gradient(AD, y) do y
120+
compare_gradient(AD, y) do y
119121
k(x, y)
120122
end
121123
if !(args === nothing)
122-
@test compare_gradient(AD, args) do p
124+
compare_gradient(AD, args) do p
123125
kernelfunction(p)(x,y)
124126
end
125127
end
126128
# Testing kernel matrices
127129
A = rand(rng, dims...)
128130
B = rand(rng, dims...)
129131
for dim in 1:2
130-
@test compare_gradient(AD, A) do a
132+
compare_gradient(AD, A) do a
131133
testfunction(k, a, dim)
132134
end
133-
@test compare_gradient(AD, A) do a
135+
compare_gradient(AD, A) do a
134136
testfunction(k, a, B, dim)
135137
end
136-
@test compare_gradient(AD, B) do b
138+
compare_gradient(AD, B) do b
137139
testfunction(k, A, b, dim)
138140
end
139141
if !(args === nothing)
140-
@test compare_gradient(AD, args) do p
142+
compare_gradient(AD, args) do p
141143
testfunction(kernelfunction(p), A, dim)
142144
end
143145
end

0 commit comments

Comments
 (0)