Skip to content

Commit 2b27baa

Browse files
committed
Renamed Matern3_2 and Matern5_2 and added k(X,Y)
1 parent c3554df commit 2b27baa

File tree

4 files changed

+68
-20
lines changed

4 files changed

+68
-20
lines changed

src/KernelFunctions.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module KernelFunctions
22

33
export kernelmatrix, kernelmatrix!, kappa
4-
export Kernel, SquaredExponentialKernel, MaternKernel, Matern3_2Kernel, Matern5_2Kernel
4+
export Kernel, SquaredExponentialKernel, MaternKernel, Matern32Kernel, Matern52Kernel
55

66
export Transform, ScaleTransform
77

@@ -15,7 +15,6 @@ abstract type Kernel{T,Tr} end
1515

1616
include("zygote_rules.jl")
1717
include("utils.jl")
18-
include("common.jl")
1918
include("transform/transform.jl")
2019
include("kernelmatrix.jl")
2120

@@ -24,4 +23,7 @@ for k in kernels
2423
include(joinpath("kernels",k*".jl"))
2524
end
2625

26+
include("generic.jl")
27+
28+
2729
end

src/generic.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11

22
@inline metric::Kernel) = κ.metric
3-
@inline::K)(d::Real) where {K<:Kernel} = kappa(κ,d)
4-
3+
kernels =
4+
for k in [:SquaredExponentialKernel,:MaternKernel,:Matern32Kernel,:Matern52Kernel]
5+
eval(quote
6+
@inline::$k)(d::Real) = kappa(κ,d)
7+
@inline::$k)(x::AbstractVector{T},y::AbstractVector{T}) where {T} = kernel(κ,evaluate(κ.(metric),x,y))
8+
@inline::$k)(x::AbstractMatrix{T},y::AbstractMatrix{T},obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ,x,y,obsdim=obsdim)
9+
end)
10+
end
511
### Transform generics
612

713
@inline transform::Kernel) = κ.transform
814
@inline transform::Kernel,x::AbstractVecOrMat) = transform.transform,x)
915
@inline transform::Kernel,x::AbstractVecOrMat,obsdim::Int) = transform.transform,x,obsdim)
10-
11-
@inline::Kernel)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kernel(κ,evaluate(κ.(metric),x,y))

src/kernels/matern.jl

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@ function MaternKernel(ρ::T₁=1.0,ν::T₂=1.5) where {T₁<:Real,T₂<:Real}
3737
if ν == 0.5
3838
ExponentialKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
3939
elseif ν == 1.5
40-
Matern3_2Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
40+
Matern32Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
4141
elseif ν == 2.5
42-
Matern5_2Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
42+
Matern52Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
43+
elseif ν == Inf
44+
SquaredExponentialKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
4345
else
4446
MaternKernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ),ν)
4547
end
@@ -50,9 +52,11 @@ function MaternKernel(ρ::A,ν::T=1.5) where {A<:AbstractVector{<:Real},T<:Real}
5052
if ν == 0.5
5153
ExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
5254
elseif ν == 1.5
53-
Matern3_2Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
55+
Matern32Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
5456
elseif ν == 2.5
55-
Matern5_2Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
57+
Matern52Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
58+
elseif ν == Inf
59+
SquaredExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
5660
else
5761
MaternKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ),ν)
5862
end
@@ -61,35 +65,61 @@ end
6165
function MaternKernel(t::T₁::T₂=1.5) where {T₁<:Transform,T₂<:Real}
6266
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
6367
if ν == 0.5
64-
ExponentialKernel{eltype(t),T₁}(ScaleTransform(ρ))
68+
ExponentialKernel{eltype(t),T₁}(t)
6569
elseif ν == 1.5
66-
Matern3_2Kernel{eltype(t),T₁}(ScaleTransform(ρ))
70+
Matern32Kernel{eltype(t),T₁}(t)
6771
elseif ν == 2.5
68-
Matern5_2Kernel{eltype(t),T₁}(ScaleTransform(ρ))
72+
Matern52Kernel{eltype(t),T₁}(t)
73+
elseif ν == Inf
74+
SquaredExponentialKernel{eltype(t),T₁}(t)
6975
else
70-
MaternKernel{eltype(t),T₁}(ScaleTransform(ρ),ν)
76+
MaternKernel{eltype(t),T₁}(t,ν)
7177
end
7278
end
7379

7480
@inline kappa::MaternKernel, d::Real) where {T} = exp((1.0-κ.ν)*logtwo - lgamma.ν) - κ.ν*log(sqrt(2κ.ν)*d))*besselk.ν,sqrt(2κ.ν)*d)
7581

7682

77-
struct Matern3_2Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
83+
struct Matern32Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
7884
transform::Tr
7985
metric::SemiMetric
80-
function Matern3_2Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
86+
function Matern32Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
8187
return new{T,Tr}(transform,Euclidean())
8288
end
8389
end
8490

85-
@inline kappa::Matern3_2Kernel, d::T) where {T<:Real} = (1+sqrt(3)*d)*exp(-sqrt(3)*d)
91+
function Matern32Kernel::T) where {T<:Real}
92+
Matern32Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
93+
end
94+
95+
function Matern32Kernel::A) where {A<:AbstractVector{<:Real}}
96+
Matern32Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
97+
end
98+
99+
function Matern32Kernel(t::Transform)
100+
Matern52Kernel{eltype(A),ScaleTransform{A}}(t)
101+
end
102+
103+
@inline kappa::Matern32Kernel, d::T) where {T<:Real} = (1+sqrt(3)*d)*exp(-sqrt(3)*d)
86104

87-
struct Matern5_2Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
105+
struct Matern52Kernel{T,Tr<:Transform} <: Kernel{T,Tr}
88106
transform::Tr
89107
metric::SemiMetric
90-
function Matern5_2Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
108+
function Matern52Kernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
91109
return new{T,Tr}(transform,Euclidean())
92110
end
93111
end
94112

95-
@inline kappa::Matern5_2Kernel, d::Real) where {T} = (1+sqrt(5)*d+5*d^2/3)*exp(-sqrt(5)*d)
113+
function Matern52Kernel::T) where {T<:Real}
114+
Matern52Kernel{T₁,ScaleTransform{T₁}}(ScaleTransform(ρ))
115+
end
116+
117+
function Matern52Kernel::A) where {A<:AbstractVector{<:Real}}
118+
Matern52Kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
119+
end
120+
121+
function Matern52Kernel(t::Transform)
122+
Matern52Kernel{eltype(A),ScaleTransform{A}}(t)
123+
end
124+
125+
@inline kappa::Matern52Kernel, d::Real) where {T} = (1+sqrt(5)*d+5*d^2/3)*exp(-sqrt(5)*d)

test/constructors.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,15 @@ vl = [l,l]
1010
@test KernelFunctions.transform(SquaredExponentialKernel(l)) == ScaleTransform(l)
1111
@test KernelFunctions.transform(SquaredExponentialKernel(vl)) == ScaleTransform(vl)
1212
end
13+
14+
@testset "MaternKernel" begin
15+
@test KernelFunctions.metric(MaternKernel(l)) == Euclidean()
16+
@test KernelFunctions.metric(MaternKernel(l,1.5)) == Euclidean()
17+
@test KernelFunctions.metric(MaternKernel(l,2.5)) == Euclidean()
18+
@test KernelFunctions.transform(MaternKernel(l)) == ScaleTransform(l)
19+
@test KernelFunctions.transform(MaternKernel(vl)) == ScaleTransform(vl)
20+
@test isa(MaternKernel(),Matern32Kernel)
21+
@test isa(MaternKernel(1.0,1.0),MaternKernel)
22+
@test isa(MaternKernel(1.0,1.5),Matern32Kernel)
23+
@test isa(MaternKernel(1.0,2.5),Matern52Kernel)
24+
end

0 commit comments

Comments
 (0)