Skip to content

Commit b0acc6f

Browse files
committed
Corrected most tests
1 parent 90066dc commit b0acc6f

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

test/utils_AD.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
allapprox(x,y,tol=1e-8) = all(isapprox.(x,y,atol=tol))
2+
FDM = central_fdm(5,1)
3+
24

35
function kappa_AD(::Val{:Zygote},k::Kernel,d::Real)
46
first(Zygote.gradient(x->kappa(k,x),d))
@@ -9,29 +11,29 @@ function kappa_AD(::Val{:ForwardDiff},k::Kernel,d::Real)
911
end
1012

1113
function kappa_fdm(k::Kernel,d::Real)
12-
central_fdm(5,1)(x->kappa(k,x),d)
14+
first(FiniteDifferences.grad(FDM,x->kappa(k,x),d))
1315
end
1416

1517

1618
function transform_AD(::Val{:Zygote},t::Transform,A)
1719
ps = KernelFunctions.params(t)
18-
@test allisapprox(first(Zygote.gradient(p->transform_with_duplicate(p,t,A),ps)),
19-
central_fdm(5,1)(p->transform_with_duplicate(p,t,A),ps))
20-
@test allisapprox(first(Zygote.gradient(X->sum(transform(t,X,2)),A))
21-
.≈ central_fdm(5,1)(X->sum(transform(t,X,2)),A))
20+
@test allapprox(first(Zygote.gradient(p->transform_with_duplicate(p,t,A),ps)),
21+
first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps)))
22+
@test allapprox(first(Zygote.gradient(X->sum(transform(t,X,2)),A)),
23+
first(FiniteDifferences.grad(FDM,X->sum(transform(t,X,2)),A)))
2224
end
2325

2426
function transform_AD(::Val{:ForwardDiff},t::Transform,A)
2527
ps = KernelFunctions.params(t)
2628
if t isa ScaleTransform
27-
@test allisapprox(first(ForwardDiff.gradient(p->transform_with_duplicate(first(p),t,A),[ps])),
28-
central_fdm(5,1)(p->transform_with_duplicate(p,t,A),ps))
29+
@test allapprox(first(ForwardDiff.gradient(p->transform_with_duplicate(first(p),t,A),[ps])),
30+
first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps)))
2931
else
30-
@test allisapprox(ForwardDiff.gradient(p->transform_with_duplicate(p,t,A),ps),
31-
central_fdm(5,1)(p->transform_with_duplicate(p,t,A),ps))
32+
@test allapprox(ForwardDiff.gradient(p->transform_with_duplicate(p,t,A),ps),
33+
first(FiniteDifferences.grad(FDM,p->transform_with_duplicate(p,t,A),ps)))
3234
end
33-
@test allisapprox(ForwardDiff.gradient(X->sum(transform(t,X,2)),A),
34-
central_fdm(5,1)(X->sum(transform(t,X,2)),A))
35+
@test allapprox(ForwardDiff.gradient(X->sum(transform(t,X,2)),A),
36+
first(FiniteDifferences.grad(FDM,X->sum(transform(t,X,2)),A)))
3537
end
3638

3739
transform_with_duplicate(p,t,A) = sum(transform(KernelFunctions.duplicate(t,p),A,2))

0 commit comments

Comments
 (0)