Skip to content

Commit 2ffd7e4

Browse files
committed
Add test of custom kernel
1 parent a31b097 commit 2ffd7e4

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ include("test_distances.jl")
1212
include("test_kernels.jl")
1313
include("test_generic.jl")
1414
include("test_adjoints.jl")
15+
include("test_custom.jl")
1516
#include("types.jl")
1617
end

test/test_custom.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using KernelFunctions
2+
using Test
3+
4+
# minimal definition of a custom kernel
5+
struct MyKernel <: Kernel{IdentityTransform} end
6+
7+
KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
8+
KernelFunctions.metric(::MyKernel) = SqEuclidean()
9+
KernelFunctions.transform(::MyKernel) = IdentityTransform()
10+
11+
@test kappa(MyKernel(), 3) == kappa(SqExponentialKernel(), 3)
12+
@test kappa(MyKernel(), 1, 3) == kappa(SqExponentialKernel(), 1, 3)
13+
@test kappa(MyKernel(), [1, 2], [3, 4]) == kappa(SqExponentialKernel(), [1, 2], [3, 4])
14+
@test kernelmatrix(MyKernel(), [1 2; 3 4], [5 6; 7 8]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4], [5 6; 7 8])
15+
@test kernelmatrix(MyKernel(), [1 2; 3 4]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4])
16+
17+
# some syntactic sugar
18+
::MyKernel)(d::Real) = kappa(κ, d)
19+
::MyKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
20+
::MyKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X, Y; obsdim = obsdim)
21+
::MyKernel)(X::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X; obsdim = obsdim)
22+
23+
@test MyKernel()(3) == SqExponentialKernel()(3)
24+
@test MyKernel()([1, 2], [3, 4]) == SqExponentialKernel()([1, 2], [3, 4])
25+
@test MyKernel()([1 2; 3 4], [5 6; 7 8]) == SqExponentialKernel()([1 2; 3 4], [5 6; 7 8])
26+
@test MyKernel()([1 2; 3 4]) == SqExponentialKernel()([1 2; 3 4])

0 commit comments

Comments
 (0)