Skip to content

Commit 4a16377

Browse files
authored
Merge pull request #29 from devmotion/transform
Ease implementation of custom kernels
2 parents 58d42f5 + 2ffd7e4 commit 4a16377

File tree

5 files changed

+38
-7
lines changed

5 files changed

+38
-7
lines changed

src/generic.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,29 @@ Base.length(::Kernel) = 1
55
Base.iterate(k::Kernel) = (k,nothing)
66
Base.iterate(k::Kernel, ::Any) = nothing
77

8+
# default fallback for evaluating a kernel with two arguments (such as vectors etc)
9+
kappa::Kernel, x, y) = kappa(κ, evaluate(metric(κ), transform(κ, x), transform(κ, y)))
10+
811
### Syntactic sugar for creating matrices and using kernel functions
912
for k in [:ExponentialKernel,:SqExponentialKernel,:GammaExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel,:LinearKernel,:PolynomialKernel,:ExponentiatedKernel,:ZeroKernel,:WhiteKernel,:ConstantKernel,:RationalQuadraticKernel,:GammaRationalQuadraticKernel]
1013
@eval begin
1114
@inline::$k)(d::Real) = kappa(κ,d) #TODO Add test
12-
@inline::$k)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate(metric(κ),transform(κ,x),transform(κ,y)))
15+
@inline::$k)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
1316
@inline::$k)(X::AbstractMatrix{T},Y::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,Y,obsdim=obsdim)
1417
@inline::$k)(X::AbstractMatrix{T};obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,X,obsdim=obsdim)
1518
end
1619
end
1720

1821
### Transform generics
1922
@inline transform::Kernel) = κ.transform
20-
@inline transform::Kernel,x::AbstractVecOrMat) = transform(κ.transform,x)
21-
@inline transform::Kernel,x::AbstractVecOrMat,obsdim::Int) = transform(κ.transform,x,obsdim)
23+
@inline transform::Kernel, x) = transform(transform(κ), x)
24+
@inline transform::Kernel, x, obsdim::Int) = transform(transform(κ), x, obsdim)
2225

2326
## Constructors for kernels without parameters
2427
for kernel in [:ExponentialKernel,:SqExponentialKernel,:Matern32Kernel,:Matern52Kernel,:ExponentiatedKernel]
2528
@eval begin
26-
$kernel::Real=1.0) = $kernel(ScaleTransform(ρ))
29+
$kernel() = $kernel(IdentityTransform())
30+
$kernel::Real) = $kernel(ScaleTransform(ρ))
2731
$kernel::AbstractVector{<:Real}) = $kernel(ARDTransform(ρ))
2832
end
2933
end

src/transform/transform.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct IdentityTransform <: Transform end
2626
params(t::IdentityTransform) = nothing
2727
duplicate(t::IdentityTransform,θ) = t
2828

29-
transform(t::IdentityTransform,x::AbstractArray,obsdim::Int=defaultobs) = x #TODO add test
29+
transform(t::IdentityTransform, x, obsdim::Int=defaultobs) = x #TODO add test
3030

3131
### TODO Maybe defining adjoints could help but so far it's not working
3232

src/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ feature_dim(obsdim::Int) = obsdim == 1 ? 2 : 1
2828

2929
base_kernel(k::Kernel) = eval(nameof(typeof(k)))
3030

31-
base_transform(k::Kernel) = base_transform(k.transform)
31+
base_transform(k::Kernel) = base_transform(transform(k))
3232
base_transform(t::Transform) = eval(nameof(typeof(t)))
3333
_tail(v::AbstractVector) = view(v,2:length(v))
3434

@@ -56,4 +56,4 @@ dim(k::Kernel) = length(params(k))
5656
For a kernel return a tuple with parameters of the transform followed by the specific parameters of the kernel
5757
For a transform return its parameters, for a `ChainTransform` return a vector of `params(t)`.
5858
"""
59-
params(k::Kernel) = (params(k.transform),)
59+
params(k::Kernel) = (params(transform(k)),)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@ include("test_distances.jl")
1212
include("test_kernels.jl")
1313
include("test_generic.jl")
1414
include("test_adjoints.jl")
15+
include("test_custom.jl")
1516
#include("types.jl")
1617
end

test/test_custom.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using KernelFunctions
2+
using Test
3+
4+
# minimal definition of a custom kernel
5+
struct MyKernel <: Kernel{IdentityTransform} end
6+
7+
KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
8+
KernelFunctions.metric(::MyKernel) = SqEuclidean()
9+
KernelFunctions.transform(::MyKernel) = IdentityTransform()
10+
11+
@test kappa(MyKernel(), 3) == kappa(SqExponentialKernel(), 3)
12+
@test kappa(MyKernel(), 1, 3) == kappa(SqExponentialKernel(), 1, 3)
13+
@test kappa(MyKernel(), [1, 2], [3, 4]) == kappa(SqExponentialKernel(), [1, 2], [3, 4])
14+
@test kernelmatrix(MyKernel(), [1 2; 3 4], [5 6; 7 8]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4], [5 6; 7 8])
15+
@test kernelmatrix(MyKernel(), [1 2; 3 4]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4])
16+
17+
# some syntactic sugar
18+
::MyKernel)(d::Real) = kappa(κ, d)
19+
::MyKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
20+
::MyKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X, Y; obsdim = obsdim)
21+
::MyKernel)(X::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X; obsdim = obsdim)
22+
23+
@test MyKernel()(3) == SqExponentialKernel()(3)
24+
@test MyKernel()([1, 2], [3, 4]) == SqExponentialKernel()([1, 2], [3, 4])
25+
@test MyKernel()([1 2; 3 4], [5 6; 7 8]) == SqExponentialKernel()([1 2; 3 4], [5 6; 7 8])
26+
@test MyKernel()([1 2; 3 4]) == SqExponentialKernel()([1 2; 3 4])

0 commit comments

Comments
 (0)