|
1 | 1 | using KernelFunctions
|
2 |
| -using Zygote, ForwardDiff |
3 |
| -using Test, LinearAlgebra |
| 2 | +using KernelFunctions: kappa |
| 3 | +using Flux: params |
| 4 | +import Zygote, ForwardDiff, ReverseDiff |
| 5 | +using Test, LinearAlgebra, Random |
4 | 6 | using FiniteDifferences
|
5 | 7 |
|
# Shared fixtures for the automatic-differentiation test sets in this file.
# NOTE(review): `gradient(Val(backend), f, args)` is presumably defined in
# utils_AD.jl, which is not visible from here -- confirm.
include("utils_AD.jl")

dims = [3, 3]
ν = 3.0  # value passed as the `ν` keyword of MaternKernel below

# Seeded generator: every random fixture below is drawn from it, in this order,
# so the inputs are reproducible from run to run.
rng = MersenneTwister(42)

A = rand(rng, dims...)
B = rand(rng, dims...)
# Zero matrices sized after `dims`; not referenced by the tests visible here.
K = [zeros(dims[1], dims[1]), zeros(dims[2], dims[2])]

x = rand(rng, dims[1])
y = rand(rng, dims[1])

l = rand(rng)
vl = l * ones(dims[1])

# Kernels under test. The commented-out entries are currently disabled.
kernels = [
    SqExponentialKernel(),
    ExponentialKernel(),
    MaternKernel(ν = ν),
    # transform(SqExponentialKernel(), l),
    # transform(SqExponentialKernel(), vl),
    # ExponentiatedKernel() + LinearKernel(),
    # 2.0 * PolynomialKernel() * Matern32Kernel(),
]

# Log-space inputs for the kappa tests, which evaluate kappa at `exp(d)`:
# `eps()` (machine epsilon) and one random draw in (0, 1).
ds = log.([eps(), rand(rng)])

# Scalar-valued functions of the input matrices (determinant of the kernel
# matrix along observation dimension `dim`), used as differentiation targets.
testfunction(k, A, B, dim) = det(kernelmatrix(k, A, B, obsdim = dim))
testfunction(k, A, dim) = det(kernelmatrix(k, A, obsdim = dim))

# AD backends compared against finite differences.
ADs = [:Zygote, :ForwardDiff, :ReverseDiff]
22 | 41 |
|
## Test kappa functions

# For every SimpleKernel in `kernels`, compare the derivative of
# `d -> kappa(k, exp(d))` computed by each AD backend in `ADs` against the
# finite-difference result, at each log-space point in `ds`.
@testset "Kappa functions" begin
    for k in filter(k -> k isa KernelFunctions.SimpleKernel, kernels)
        @testset "$k" begin
            # Inputs are wrapped in one-element vectors, hence the `v[1]`.
            f = v -> kappa(k, exp(v[1]))
            grad = (backend, d) -> gradient(Val(backend), f, [d])

            # Check FiniteDiff does the right thing. Uses the same vector call
            # shape as the comparisons below (the previous check passed a
            # bare scalar instead).
            @test_nowarn grad(:FiniteDiff, first(ds))

            for AD in ADs
                @testset "$AD" begin
                    for d in ds
                        @test_nowarn grad(AD, d)
                        @test grad(AD, d) ≈ grad(:FiniteDiff, d) atol=1e-8
                    end
                end
            end
        end
    end
end
|
39 | 59 |
|
# For every kernel, compare the gradient of `u -> k(u, y)` at `x` computed by
# each AD backend in `ADs` against the finite-difference result.
@testset "Kernel evaluations" begin
    for k in kernels
        @testset "$k" begin
            f = u -> k(u, y)
            grad = backend -> gradient(Val(backend), f, x)

            # The finite-difference reference does not depend on the AD
            # backend, so it is sanity-checked once per kernel.
            @test_nowarn grad(:FiniteDiff)

            for AD in ADs
                @testset "$AD" begin
                    # No loop over `ds` here: the evaluation point is the fixed
                    # vector `x`, so iterating over `ds` only repeated
                    # identical tests.
                    @test_nowarn grad(AD)
                    @test grad(AD) ≈ grad(:FiniteDiff) atol=1e-8
                end
            end
        end
    end
end
|
56 | 75 |
|
# For every kernel, AD backend and observation dimension, differentiate the
# determinant of the kernel matrix (see `testfunction`) with respect to the
# first input matrix, evaluated at `A`, and compare against finite differences.
@testset "Kernel Matrices" begin
    for k in kernels
        @testset "$k" begin
            for AD in ADs
                @testset "$AD" begin
                    for dim in 1:2
                        # One-input form: kernelmatrix(k, X).
                        single = X -> testfunction(k, X, dim)
                        # Two-input form: kernelmatrix(k, X, B), with B fixed.
                        paired = X -> testfunction(k, X, B, dim)
                        grad = (backend, f) -> gradient(Val(backend), f, A)

                        @test_nowarn grad(AD, single)
                        @test_nowarn grad(AD, paired)
                        @test grad(AD, paired) ≈ grad(:FiniteDiff, paired) atol=1e-8
                        @test grad(AD, single) ≈ grad(:FiniteDiff, single) atol=1e-8
                    end
                end
            end
        end
    end
end
|
80 | 93 |
|
# For every kernel, collect its parameters with `params` (imported from Flux at
# the top of the file) and check that taking the Zygote gradient of the
# zero-argument closure `() -> k(x, y)` with respect to them emits no warning.
# Only the absence of warnings is checked; the gradient values are not compared
# against a reference.
@testset "Params differentiation" begin
    for k in kernels
        @testset "$k" begin
            ps = params(k)
            @test_nowarn gradient(Val(:Zygote), () -> k(x, y), ps)
        end
    end
end
|
0 commit comments