Skip to content

Commit b6a7901

Browse files
committed
Removing unnecessary functions and uncommented all cases
1 parent 4aeb0e3 commit b6a7901

File tree

2 files changed

+6
-54
lines changed

2 files changed

+6
-54
lines changed

test/test_AD.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
using KernelFunctions
2-
using KernelFunctions: kappa
2+
using KernelFunctions: kappa, ColVecs, RowVecs
33
using Flux: params
44
import Zygote, ForwardDiff, ReverseDiff
5+
using Zygote: pullback
56
using Test, LinearAlgebra, Random
67
using FiniteDifferences
78

@@ -26,10 +27,10 @@ kernels = [
2627
SqExponentialKernel(),
2728
ExponentialKernel(),
2829
MaternKernel= ν),
29-
# transform(SqExponentialKernel(), l),
30-
# transform(SqExponentialKernel(), vl),
31-
# ExponentiatedKernel() + LinearKernel(),
32-
# 2.0 * PolynomialKernel() * Matern32Kernel(),
30+
transform(SqExponentialKernel(), l),
31+
transform(SqExponentialKernel(), vl),
32+
ExponentiatedKernel() + LinearKernel(),
33+
2.0 * PolynomialKernel() * Matern32Kernel(),
3334
]
3435

3536
ds = log.([eps(), rand(rng)])

test/utils_AD.jl

Lines changed: 0 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
allapprox(x, y, tol = 1e-8) = all(isapprox.(x, y, atol = tol))
21
FDM = central_fdm(5, 1)
32

43
function gradient(::Val{:Zygote}, f::Function, args)
@@ -20,51 +19,3 @@ end
2019
function gradient(::Val{:FiniteDiff}, f::Function, args)
2120
first(FiniteDifferences.grad(FDM, f, args))
2221
end
23-
24-
25-
26-
function transform_AD(::Val{:Zygote}, t::Transform, A)
27-
ps = KernelFunctions.params(t)
28-
@test allapprox(
29-
first(Zygote.gradient(p -> transform_with_duplicate(p, t, A), ps)),
30-
first(FiniteDifferences.grad(
31-
FDM,
32-
p -> transform_with_duplicate(p, t, A),
33-
ps,
34-
)),
35-
)
36-
@test allapprox(
37-
first(Zygote.gradient(X -> sum(transform(t, X, 2)), A)),
38-
first(FiniteDifferences.grad(FDM, X -> sum(transform(t, X, 2)), A)),
39-
)
40-
end
41-
42-
function transform_AD(::Val{:ForwardDiff}, t::Transform, A)
43-
ps = KernelFunctions.params(t)
44-
if t isa ScaleTransform
45-
@test allapprox(
46-
first(ForwardDiff.gradient(
47-
p -> transform_with_duplicate(first(p), t, A),
48-
[ps],
49-
)),
50-
first(FiniteDifferences.grad(
51-
FDM,
52-
p -> transform_with_duplicate(p, t, A),
53-
ps,
54-
)),
55-
)
56-
else
57-
@test allapprox(
58-
ForwardDiff.gradient(p -> transform_with_duplicate(p, t, A), ps),
59-
first(FiniteDifferences.grad(
60-
FDM,
61-
p -> transform_with_duplicate(p, t, A),
62-
ps,
63-
)),
64-
)
65-
end
66-
@test allapprox(
67-
ForwardDiff.gradient(X -> sum(transform(t, X, 2)), A),
68-
first(FiniteDifferences.grad(FDM, X -> sum(transform(t, X, 2)), A)),
69-
)
70-
end

0 commit comments

Comments
 (0)