|
2 | 2 |
|
3 | 3 | rng = MersenneTwister(123456)
|
4 | 4 | dims = [10,5]
|
5 |
| - |
6 |
| - A = rand(rng, dims...) |
7 |
| - B = rand(rng, dims...) |
| 5 | + vA = [rand(rng, dims[1]) for _ in 1:dims[2]] |
| 6 | + A = hcat(vA...) |
| 7 | + vB = [rand(rng, dims[1]) for _ in 1:dims[2]] |
| 8 | + B = hcat(vB...) |
| 9 | + x = rand(rng, dims[1]) |
| 10 | + X = collect(reshape(x, 1, :)) |
| 11 | + y = rand(rng, dims[2]) |
| 12 | + Y = collect(reshape(y, 1 , :)) |
| 13 | + KX = zeros(dims[1], dims[1]) |
| 14 | + KXY = zeros(dims[1], dims[2]) |
8 | 15 | C = rand(rng, 8, 9)
|
9 | 16 | K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
|
10 | 17 | Kdiag = [zeros(dims[1]),zeros(dims[2])]
|
11 | 18 | s = rand(rng)
|
12 | 19 | k = SqExponentialKernel()
|
| 20 | + struct baseSE <: KernelFunctions.BaseKernel end |
| 21 | + (k::baseSE)(x, y) = exp(-evaluate(SqEuclidean(), x, y)) |
| 22 | + newk = baseSE() |
13 | 23 | kt = transform(SqExponentialKernel(),s)
|
14 | 24 |
|
15 | 25 | @testset "Kernel Matrix Operations" begin
|
16 | 26 | @testset "Inplace Kernel Matrix" begin
|
| 27 | + @test kernelmatrix!(KX, k, x) ≈ kernelmatrix!(KX, k, X) |
| 28 | + @test kernelmatrix!(KXY, k, x, y) ≈ kernelmatrix!(KXY, k, X, Y) |
| 29 | + @test kernelmatrix!(K[2], k, vA) ≈ kernelmatrix(k, A) atol = 1e-5 |
| 30 | + @test kernelmatrix!(K[2], k, vA, vB) ≈ kernelmatrix(k, A, B) atol = 1e-5 |
17 | 31 | for obsdim in [1,2]
|
| 32 | + @show obsdim |
18 | 33 | @test kernelmatrix!(K[obsdim], k, A, B, obsdim = obsdim) == kernelmatrix(k, A, B, obsdim = obsdim)
|
19 | 34 | @test kernelmatrix!(K[obsdim], k, A, obsdim = obsdim) == kernelmatrix(k, A, obsdim = obsdim)
|
20 | 35 | @test kerneldiagmatrix!(Kdiag[obsdim], k, A, obsdim = obsdim) == kerneldiagmatrix(k, A, obsdim = obsdim)
|
21 | 36 | @test_throws DimensionMismatch kernelmatrix!(K[obsdim], k, A, C, obsdim=obsdim)
|
22 | 37 | @test_throws DimensionMismatch kernelmatrix!(K[obsdim], k, C, obsdim=obsdim)
|
23 | 38 | @test_throws DimensionMismatch kerneldiagmatrix!(Kdiag[obsdim], k, C, obsdim=obsdim)
|
| 39 | + @test kernelmatrix!(K[obsdim], newk, A, B, obsdim = obsdim) ≈ kernelmatrix(k, A, B, obsdim = obsdim) |
| 40 | + @test kernelmatrix!(K[obsdim], newk, A, obsdim = obsdim) ≈ kernelmatrix(k, A, obsdim = obsdim) |
| 41 | + @test kerneldiagmatrix!(Kdiag[obsdim], newk, A, obsdim = obsdim) ≈ kerneldiagmatrix(k, A, obsdim = obsdim) |
24 | 42 | end
|
25 | 43 | end
|
26 | 44 | @testset "Kernel matrix" begin
|
| 45 | + @test kernelmatrix(k, x) ≈ kernelmatrix(k, X) |
| 46 | + @test kernelmatrix(k, x, y) ≈ kernelmatrix(k, X, Y) |
| 47 | + @test kernelmatrix(k, vA) ≈ kernelmatrix(k, A) atol = 1e-5 |
| 48 | + @test kernelmatrix(k, vA, vB) ≈ kernelmatrix(k, A, B) atol = 1e-5 |
27 | 49 | for obsdim in [1,2]
|
28 | 50 | @test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.(k,pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
|
29 | 51 | @test kernelmatrix(k,A,obsdim=obsdim) == kappa.(k,pairwise(KernelFunctions.metric(k),A,dims=obsdim))
|
|
32 | 54 | @test k(A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
|
33 | 55 | # @test KernelFunctions._kernel(k,1.0,2.0) == KernelFunctions._kernel(k,[1.0],[2.0])
|
34 | 56 | @test_throws DimensionMismatch kernelmatrix(k,A,C,obsdim=obsdim)
|
| 57 | + @test kernelmatrix!(K[obsdim], newk, A, B, obsdim = obsdim) ≈ kernelmatrix(k, A, B, obsdim = obsdim) |
| 58 | + @test kernelmatrix!(K[obsdim], newk, A, obsdim = obsdim) ≈ kernelmatrix(k, A, obsdim = obsdim) |
| 59 | + @test kerneldiagmatrix!(Kdiag[obsdim], newk, A, obsdim = obsdim) ≈ kerneldiagmatrix(k, A, obsdim = obsdim) |
35 | 60 | end
|
36 | 61 | end
|
37 | 62 | @testset "Transformed Kernel Matrix Operations" begin
|
|
0 commit comments