Skip to content

Commit 4be32b9

Browse files
committed
Added matern kernel
1 parent b7ff456 commit 4be32b9

File tree

5 files changed

+122
-17
lines changed

5 files changed

+122
-17
lines changed

dev/debugAD.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,33 @@ using Zygote, ForwardDiff, Tracker
33
using Test
44

55
dims = [10,5]
6-
76
A = rand(dims...)
87
B = rand(dims...)
98
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
10-
kernels = [SquaredExponentialKernel]
11-
l = 2.0
9+
l = 0.1
1210
vl = l*ones(dims[1])
1311
testfunction(k,A,B) = sum(kernelmatrix(k,A,B))
1412
testfunction(k,A) = sum(kernelmatrix(k,A))
15-
13+
k = MaternKernel(vl)
14+
KernelFunctions.kappa(k,3)
1615
testfunction(SquaredExponentialKernel(vl),A)
16+
testfunction(MaternKernel(vl),A)
17+
@which kernelmatrix(MaternKernel(vl),A,B)
1718
#For debugging
1819
@info "Running Zygote gradients"
1920
Zygote.refresh()
2021
## Zygote
2122
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
23+
Zygote.gradient(x->testfunction(MaternKernel(x),A),vl)
2224
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)[1]
25+
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),vl)[1]
2326
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),l)
27+
Zygote.gradient(x->testfunction(MaternKernel(x),A,B),l)
2428
Zygote.gradient(x->testfunction(SquaredExponentialKernel(x),A),l)
29+
Zygote.gradient(x->testfunction(MaternKernel(x),A),l)
2530
@info "Running Tracker gradients"
2631
## Tracker
27-
Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)
32+
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)
2833
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(l),x[:,:]),A)
2934
# # Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl)
3035
# Tracker.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl)
@@ -34,6 +39,10 @@ Tracker.gradient(x->testfunction(SquaredExponentialKernel(vl),x,B),A)
3439
@info "Running ForwardDiff gradients"
3540
## ForwardDiff
3641
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A,B),vl) #
42+
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A,B),vl) #
3743
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x),A),vl) #
44+
ForwardDiff.gradient(x->testfunction(MaternKernel(x),A),vl) #
3845
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A,B),[l])
46+
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A,B),[l])
3947
ForwardDiff.gradient(x->testfunction(SquaredExponentialKernel(x[1]),A),[l])
48+
ForwardDiff.gradient(x->testfunction(MaternKernel(x[1]),A),[l])

src/KernelFunctions.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kappa
4-
export Kernel, SquaredExponentialKernel
4+
export Kernel, SquaredExponentialKernel, MaternKernel, Matern3_2Kernel, Matern5_2Kernel
55

66
export Transform, ScaleTransform
77

88
using Distances, LinearAlgebra
99
using Zygote: @adjoint
10+
using SpecialFunctions: lgamma, besselk
1011

1112
const defaultobs = 2
1213
abstract type Kernel{T,Tr} end
@@ -16,7 +17,7 @@ include("common.jl")
1617
include("transform/transform.jl")
1718
include("kernelmatrix.jl")
1819

19-
kernels = ["squaredexponential"]
20+
kernels = ["squaredexponential","matern"]
2021
for k in kernels
2122
include(joinpath("kernels",k*".jl"))
2223
end

src/kernelmatrix.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -83,32 +83,29 @@ end
8383
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool=true)
8484
```
8585
Calculate the kernel matrix of `X` with respect to kernel `κ`.
86+
# USED
8687
"""
8788
function kernelmatrix(
88-
κ::Kernel{T,<:Transform{A}},
89+
κ::Kernel{T,<:Transform},
8990
X::AbstractMatrix;
9091
obsdim::Int = defaultobs,
9192
symmetrize::Bool = true
9293
) where {T,A}
9394
# Tₖ = typeof(zero(eltype(X))*zero(T))
9495
# m = size(X,obsdim)
95-
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
96-
# K = Matrix{Tₖ}(undef,m,m)
97-
# for i in 1:m
98-
# tx = transform(κ,@view X[i,:])
99-
# for j in 1:i
100-
# K[i,j] = kappa(κ,kernel(κ,tx,transform(@view X[j,:])))
101-
# end
102-
# end
96+
#WARNING TEMP FIX
97+
= transform(κ,X,obsdim)
98+
K = map(x->kappa(κ,x),pairwise(metric(κ),X̂,X̂,dims=obsdim))
99+
# K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
103100
return K
104-
# return kernelmatrix!(Matrix{Tₖ}(undef,m,m),κ,X,obsdim=obsdim,symmetrize=symmetrize)
105101
end
106102

107103
"""
108104
```
109105
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int=2)
110106
```
111107
Calculate the base matrix of `X` and `Y` with respect to kernel `κ`.
108+
# USED
112109
"""
113110
function kernelmatrix(
114111
κ::Kernel{T},

src/kernels/matern.jl

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
"""
2+
MaternKernel([[ρ=1],ν=3/2])
3+
4+
The matern kernel is an isotropic Mercer kernel given by the formula:
5+
6+
```
7+
κ(x,y) = 2^{1-ν}/Γ(ν)*(√(2ν)‖x-y‖/ρ)^ν K_ν(√(2ν)‖x-y‖/ρ)
8+
```
9+
10+
For `ν=n+1/2, n=0,1,2,...` it can be simplified (it will be converted automatically).
11+
`ρ` is a lengthscale parameter.
12+
13+
# Examples
14+
15+
```jldoctest; setup = :(using KernelFunctions)
16+
julia> MaternKernel()
17+
Matern3_2Kernel{Float64,Float64}(1.0)
18+
19+
julia> MaternKernel(2.0f0,3.0)
20+
MaternKernel{Float32,Float32}(2.0,3.0)
21+
22+
julia> MaternKernel([2.0,3.0],5/2)
23+
Matern5_2Kernel{Float64,Array{Float64}}([2.0,3.0])
24+
```
25+
"""
26+
struct MaternKernel{T,Tr<:Transform} <: Kernel{T,Tr}
27+
transform::Tr
28+
metric::SemiMetric
29+
ν::Real
30+
function MaternKernel{T,Tr}(transform::Tr::Real) where {T,Tr<:Transform}
31+
return new{T,Tr}(transform,SqEuclidean(),ν)
32+
end
33+
end
34+
35+
function MaternKernel::T₁=1.0::T₂=1.5) where {T₁<:Real,T₂<:Real}
36+
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
37+
if ν == 0.5
38+
ExponentialKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
39+
elseif ν == 1.5
40+
Matern3_2Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
41+
elseif ν == 2.5
42+
Matern5_2Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
43+
else
44+
MaternKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ),ν)
45+
end
46+
end
47+
48+
function MaternKernel::A::T=1.5) where {A<:AbstractVector{<:Real},T<:Real}
49+
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
50+
if ν == 0.5
51+
ExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
52+
elseif ν == 1.5
53+
Matern3_2Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
54+
elseif ν == 2.5
55+
Matern5_2Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
56+
else
57+
MaternKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ),ν)
58+
end
59+
end
60+
61+
function MaternKernel(t::T₁::T₂=1.5) where {T₁<:Transform,T₂<:Real}
62+
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
63+
if ν == 0.5
64+
ExponentialKernel{eltype(t),T₁}(ScaleTransform(ρ))
65+
elseif ν == 1.5
66+
Matern3_2Kernel{eltype(t),T₁}(ScaleTransform(ρ))
67+
elseif ν == 2.5
68+
Matern5_2Kernel{eltype(t),T₁}(ScaleTransform(ρ))
69+
else
70+
MaternKernel{eltype(t),T₁}(ScaleTransform(ρ),ν)
71+
end
72+
end
73+
74+
@inline kappa::MaternKernel, d::Real) where {T} = exp((1.0-κ.ν)*log2 - lgamma.ν) - κ.ν*log(sqrt(2κ.ν*d²)))*besselk.ν,sqrt(2κ.ν*d²))
75+
76+
77+
struct Matern3_2Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
78+
transform::Tr
79+
metric::SemiMetric
80+
function Matern3_2Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
81+
return new{T,Tr}(transform,SqEuclidean())
82+
end
83+
end
84+
85+
@inline kappa::Matern3_2Kernel, d²::T) where {T<:Real} = (1+sqrt(3*d²))*exp(-sqrt(3*d²))
86+
87+
struct Matern5_2Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
88+
transform::Tr
89+
metric::SemiMetric
90+
function Matern5_2Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
91+
return new{T,Tr}(transform,SqEuclidean())
92+
end
93+
end
94+
95+
@inline kappa::Matern5_2Kernel, d²::Real) where {T} = (1+sqrt(5*d²)+5*/3)*exp(-sqrt(5*d²))

test/kernelmatrix.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
using Distances
2+
13
dims = [10,5]
24

35
A = rand(dims...)
46
B = rand(dims...)
57
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
68
k = SquaredExponentialKernel()
9+
k = MaternKernel()
710

811
@testset "Inplace Kernel Matrix" begin
912
for obsdim in [1,2]

0 commit comments

Comments
 (0)