
Commit 909dfd8

Merge remote-tracking branch 'origin/master' into syntacticsugarallkernels
2 parents 58dc597 + c69bfe1 commit 909dfd8

File tree: 6 files changed, +53 -1 lines changed

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ makedocs(
          "Transform"=>"transform.md",
          "Metrics"=>"metrics.md",
          "Theory"=>"theory.md",
+         "Custom Kernels"=>"create_kernel.md",
          "API"=>"api.md"]
 )

docs/src/create_kernel.md

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
## Creating your own kernel

KernelFunctions.jl already contains the most popular kernels, but you might want to make your own!

Here is, for example, how one can define the Squared Exponential Kernel again:

```julia
struct MyKernel <: Kernel end

KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
KernelFunctions.metric(::MyKernel) = SqEuclidean()
```

For a "Base" kernel, where the kernel function is simply a function applied to some metric between two vectors of reals, you only need to:
- Define your struct inheriting from `Kernel`.
- Define a `kappa` function.
- Define the metric used: `SqEuclidean`, `DotProduct`, etc. Note that the term "metric" is used loosely here.
- Optional: define any parameters of your kernel as `trainable` by Flux.jl if you want to optimize them. We recommend wrapping all parameters in arrays so that they remain mutable.

Once these functions are defined, you can use all the wrapping functions of KernelFunctions.jl.
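As an illustration of the recipe above, here is a minimal sketch of a custom kernel with one tunable parameter. It is not part of this commit: the name `MyScaledSEKernel`, its field `γ`, and the explicit `Distances` import are assumptions made for the example.

```julia
using KernelFunctions
using Distances: SqEuclidean   # assumption: may be unnecessary if KernelFunctions re-exports it

# Squared-exponential kernel with an inverse lengthscale γ, following the
# steps above. γ is wrapped in a vector so it stays mutable.
struct MyScaledSEKernel{T<:Real} <: Kernel
    γ::Vector{T}
end

# kappa acts on the value produced by the metric (here a squared Euclidean distance)
KernelFunctions.kappa(k::MyScaledSEKernel, d2::Real) = exp(-first(k.γ) * d2)

# the "metric" that turns a pair of inputs into that value
KernelFunctions.metric(::MyScaledSEKernel) = SqEuclidean()
```

Because `γ` lives in an array, the optional `trainable` step can later expose it to Flux.jl; a matching sketch follows the src/trainable.jl diff below.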

src/KernelFunctions.jl

Lines changed: 2 additions & 1 deletion
@@ -14,6 +14,7 @@ export ExponentiatedKernel
 export MaternKernel, Matern32Kernel, Matern52Kernel
 export LinearKernel, PolynomialKernel
 export RationalQuadraticKernel, GammaRationalQuadraticKernel
+export MahalanobisKernel
 export KernelSum, KernelProduct
 export TransformedKernel, ScaledKernel

@@ -44,7 +45,7 @@ include("distances/dotproduct.jl")
 include("distances/delta.jl")
 include("transform/transform.jl")

-for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine"]
+for k in ["exponential","matern","polynomial","constant","rationalquad","exponentiated","cosine","maha"]
     include(joinpath("kernels",k*".jl"))
 end
 include("kernels/transformedkernel.jl")

src/kernels/maha.jl

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
"""
    MahalanobisKernel(P::AbstractMatrix)

Mahalanobis distance-based kernel given by
```math
    κ(x,y) = exp(-r^2), r^2 = maha(x,P,y) = (x-y)'*P*(x-y)
```
where the matrix P is the metric.

"""
struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: BaseKernel
    P::A
    function MahalanobisKernel(P::AbstractMatrix{T}) where {T<:Real}
        LinearAlgebra.checksquare(P)
        new{T,typeof(P)}(P)
    end
end

kappa(κ::MahalanobisKernel, d::T) where {T<:Real} = exp(-d)

metric(κ::MahalanobisKernel) = SqMahalanobis(κ.P)
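A rough usage sketch of the new kernel, not part of the diff; the identity matrix and random vectors below are placeholder inputs chosen so the result is easy to check.

```julia
using KernelFunctions, LinearAlgebra

# With P = I the kernel reduces to exp(-‖x - y‖²), i.e. plain squared-exponential behaviour.
P = Matrix{Float64}(I, 3, 3)
k = MahalanobisKernel(P)

x, y = rand(3), rand(3)
k(x, y)   # ≈ exp(-(x - y)' * P * (x - y))
```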

src/trainable.jl

Lines changed: 2 additions & 0 deletions
@@ -16,6 +16,8 @@ trainable(k::PolynomialKernel) = (k.d, k.c)

 trainable(k::RationalQuadraticKernel) = (k.α,)

+trainable(k::MahalanobisKernel) = (k.P,)
+
 #### Composite kernels

 trainable(κ::KernelProduct) = κ.kernels
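A hedged sketch of the same pattern for a user-defined kernel such as the hypothetical `MyScaledSEKernel` from the custom-kernel example above, assuming the `trainable` being extended here is Flux.jl's.

```julia
import Flux: trainable   # assumption: src/trainable.jl extends Flux's `trainable`

# Expose the kernel's parameter array to Flux optimisers, mirroring the
# MahalanobisKernel line added above. MyScaledSEKernel is illustrative only.
trainable(k::MyScaledSEKernel) = (k.γ,)
```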

test/test_kernels.jl

Lines changed: 7 additions & 0 deletions
@@ -113,6 +113,13 @@ x = rand()*2; v1 = rand(3); v2 = rand(3); id = IdentityTransform()
         @test kappa(PolynomialKernel(d=1.0,c=c),x) ≈ kappa(LinearKernel(c=c),x)
     end
 end
+@testset "Mahalanobis" begin
+    P = rand(3,3)
+    k = MahalanobisKernel(P)
+    @test kappa(k,x) == exp(-x)
+    @test k(v1,v2) ≈ exp(-sqmahalanobis(v1,v2, k.P))
+    @test kappa(ExponentialKernel(),x) == kappa(k,x)
+end
 @testset "RationalQuadratic" begin
     @testset "RationalQuadraticKernel" begin
         α = 2.0
