Skip to content

Commit fdd317f

Browse files
authored
Merge pull request #83 from theogf/general_kernelmatrix
[WIP] Rework on kernelmatrix to work with Vectors and more complex kernels
2 parents d52def6 + efa1479 commit fdd317f

26 files changed

+219
-273
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ export duplicate, set! # Helpers
1010

1111
export Kernel
1212
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel
13+
export CosineKernel
1314
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
1415
export LaplacianKernel, ExponentialKernel, GammaExponentialKernel
1516
export ExponentiatedKernel
@@ -43,6 +44,7 @@ Abstract type defining a slice-wise transformation on an input matrix
4344
abstract type Transform end
4445
abstract type Kernel end
4546
abstract type BaseKernel <: Kernel end
47+
abstract type SimpleKernel <: BaseKernel end
4648

4749
include("utils.jl")
4850
include("distances/dotproduct.jl")

src/basekernels/constant.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Create a kernel that always returning zero
77
```
88
The output type depends of `x` and `y`
99
"""
10-
struct ZeroKernel <: BaseKernel end
10+
struct ZeroKernel <: SimpleKernel end
1111

1212
kappa::ZeroKernel, d::T) where {T<:Real} = zero(T)
1313

@@ -24,7 +24,7 @@ Base.show(io::IO, ::ZeroKernel) = print(io, "Zero Kernel")
2424
```
2525
Kernel function working as an equivalent to add white noise. Can also be called via `EyeKernel()`
2626
"""
27-
struct WhiteKernel <: BaseKernel end
27+
struct WhiteKernel <: SimpleKernel end
2828

2929
"""
3030
EyeKernel()
@@ -48,7 +48,7 @@ Kernel function always returning a constant value `c`
4848
κ(x,y) = c
4949
```
5050
"""
51-
struct ConstantKernel{Tc<:Real} <: BaseKernel
51+
struct ConstantKernel{Tc<:Real} <: SimpleKernel
5252
c::Vector{Tc}
5353
function ConstantKernel(;c::T=1.0) where {T<:Real}
5454
new{T}([c])

src/basekernels/cosine.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The cosine kernel is a stationary kernel for a sinusoidal given by
66
κ(x,y) = cos( π * (x-y) )
77
```
88
"""
9-
struct CosineKernel <: BaseKernel end
9+
struct CosineKernel <: SimpleKernel end
1010

1111
kappa::CosineKernel, d::Real) = cospi(d)
1212
metric(::CosineKernel) = Euclidean()

src/basekernels/exponential.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ Can also be called via `SEKernel`, `GaussianKernel` or `SEKernel`.
99
See also [`ExponentialKernel`](@ref) for a
1010
related form of the kernel or [`GammaExponentialKernel`](@ref) for a generalization.
1111
"""
12-
struct SqExponentialKernel <: BaseKernel end
12+
struct SqExponentialKernel <: SimpleKernel end
1313

1414
kappa::SqExponentialKernel, d²::Real) = exp(-d²)
1515
iskroncompatible(::SqExponentialKernel) = true
@@ -30,7 +30,7 @@ The exponential kernel is a Mercer kernel given by the formula:
3030
κ(x,y) = exp(-‖x-y‖)
3131
```
3232
"""
33-
struct ExponentialKernel <: BaseKernel end
33+
struct ExponentialKernel <: SimpleKernel end
3434

3535
kappa::ExponentialKernel, d::Real) = exp(-d)
3636
iskroncompatible(::ExponentialKernel) = true
@@ -51,7 +51,7 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
5151
Where `γ > 0`, (the keyword `γ` can be replaced by `gamma`)
5252
For `γ = 1`, see `SqExponentialKernel` and `γ = 0.5`, see `ExponentialKernel`
5353
"""
54-
struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
54+
struct GammaExponentialKernel{Tγ<:Real} <: SimpleKernel
5555
γ::Vector{Tγ}
5656
function GammaExponentialKernel(; gamma::T=2.0, γ::T=gamma) where {T<:Real}
5757
@check_args(GammaExponentialKernel, γ, γ >= zero(T), "γ > 0")

src/basekernels/exponentiated.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ The exponentiated kernel is a Mercer kernel given by:
66
κ(x,y) = exp(xᵀy)
77
```
88
"""
9-
struct ExponentiatedKernel <: BaseKernel end
9+
struct ExponentiatedKernel <: SimpleKernel end
1010

1111
kappa::ExponentiatedKernel, xᵀy::Real) = exp(xᵀy)
1212
metric(::ExponentiatedKernel) = DotProduct()

src/basekernels/fbm.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,6 @@ function kernelmatrix!(
6666
return K
6767
end
6868

69-
## Apply kernel on two vectors ##
70-
function _kernel(
71-
κ::FBMKernel,
72-
x::AbstractVector,
73-
y::AbstractVector;
74-
obsdim::Int = defaultobs
75-
)
76-
@assert length(x) == length(y) "x and y don't have the same dimension!"
77-
return kappa(κ, x, y)
78-
end
79-
8069
function kappa::FBMKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
8170
modX = sum(abs2, x)
8271
modY = sum(abs2, y)

src/basekernels/maha.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Mahalanobis distance-based kernel given by
88
where the matrix P is the metric.
99
1010
"""
11-
struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: BaseKernel
11+
struct MahalanobisKernel{T<:Real, A<:AbstractMatrix{T}} <: SimpleKernel
1212
P::A
1313
function MahalanobisKernel(P::AbstractMatrix{T}) where {T<:Real}
1414
LinearAlgebra.checksquare(P)

src/basekernels/matern.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ The matern kernel is a Mercer kernel given by the formula:
77
```
88
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
99
"""
10-
struct MaternKernel{Tν<:Real} <: BaseKernel
10+
struct MaternKernel{Tν<:Real} <: SimpleKernel
1111
ν::Vector{Tν}
1212
function MaternKernel(;nu::T=1.5, ν::T=nu) where {T<:Real}
1313
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
@@ -37,7 +37,7 @@ The matern 3/2 kernel is a Mercer kernel given by the formula:
3737
κ(x,y) = (1+√(3)‖x-y‖)exp(-√(3)‖x-y‖)
3838
```
3939
"""
40-
struct Matern32Kernel <: BaseKernel end
40+
struct Matern32Kernel <: SimpleKernel end
4141

4242
kappa::Matern32Kernel, d::Real) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d)
4343
metric(::Matern32Kernel) = Euclidean()
@@ -52,7 +52,7 @@ The matern 5/2 kernel is a Mercer kernel given by the formula:
5252
κ(x,y) = (1+√(5)‖x-y‖ + 5/3‖x-y‖^2)exp(-√(5)‖x-y‖)
5353
```
5454
"""
55-
struct Matern52Kernel <: BaseKernel end
55+
struct Matern52Kernel <: SimpleKernel end
5656

5757
kappa::Matern52Kernel, d::Real) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)
5858
metric(::Matern52Kernel) = Euclidean()

src/basekernels/periodic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ Periodic Kernel as described in http://www.inference.org.uk/mackay/gpB.pdf eq. 4
88
κ(x,y) = exp( - 0.5 sum_i(sin (π(x_i - y_i))/r_i))
99
```
1010
"""
11-
struct PeriodicKernel{T} <: BaseKernel
11+
struct PeriodicKernel{T} <: SimpleKernel
1212
r::Vector{T}
1313
function PeriodicKernel(; r::AbstractVector{T} = ones(Float64, 1)) where {T<:Real}
1414
@assert all(r .> 0)

src/basekernels/piecewisepolynomial.jl

Lines changed: 5 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@ processes are hence v times mean-square differentiable. The kernel function is:
1010
where `r` is the Mahalanobis distance mahalanobis(x,y) with `maha` as the metric.
1111
1212
"""
13-
struct PiecewisePolynomialKernel{V, A<:AbstractMatrix{<:Real}} <: BaseKernel
13+
struct PiecewisePolynomialKernel{V, A<:AbstractMatrix{<:Real}} <: SimpleKernel
1414
maha::A
15+
j::Int
1516
function PiecewisePolynomialKernel{V}(maha::AbstractMatrix{<:Real}) where V
1617
V in (0, 1, 2, 3) || error("Invalid paramter v=$(V). Should be 0, 1, 2 or 3.")
1718
LinearAlgebra.checksquare(maha)
18-
return new{V,typeof(maha)}(maha)
19+
j = div(size(maha, 1), 2) + V + 1
20+
return new{V,typeof(maha)}(maha, j)
1921
end
2022
end
2123

@@ -29,78 +31,7 @@ _f(κ::PiecewisePolynomialKernel{2}, r, j) = 1 + (j + 2) * r + (j^2 + 4 * j + 3)
2931
_f::PiecewisePolynomialKernel{3}, r, j) = 1 + (j + 3) * r +
3032
(6 * j^2 + 36j + 45) / 15 * r.^2 + (j^3 + 9 * j^2 + 23j + 15) / 15 * r.^3
3133

32-
function _piecewisepolynomial::PiecewisePolynomialKernel{V}, r, j) where V
33-
return max(1 - r, 0)^(j + V) * _f(κ, r, j)
34-
end
35-
36-
function kappa(
37-
κ::PiecewisePolynomialKernel{V},
38-
x::AbstractVector{<:Real},
39-
y::AbstractVector{<:Real},
40-
) where {V}
41-
r = evaluate(metric(κ), x, y)
42-
j = div(size(x, 2), 1) + V + 1
43-
return _piecewisepolynomial(κ, r, j)
44-
end
45-
46-
function _kernel(
47-
κ::PiecewisePolynomialKernel,
48-
x::AbstractVector,
49-
y::AbstractVector;
50-
obsdim::Int = defaultobs,
51-
)
52-
@assert length(x) == length(y) "x and y don't have the same dimension!"
53-
return kappa(κ,x,y)
54-
end
55-
56-
function kernelmatrix(
57-
κ::PiecewisePolynomialKernel{V},
58-
X::AbstractMatrix;
59-
obsdim::Int = defaultobs
60-
) where {V}
61-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
62-
return map(r->_piecewisepolynomial(κ, r, j), pairwise(metric(κ), X; dims=obsdim))
63-
end
64-
65-
function _kernelmatrix::PiecewisePolynomialKernel{V}, X, Y, obsdim) where {V}
66-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
67-
return map(r->_piecewisepolynomial(κ, r, j), pairwise(metric(κ), X, Y; dims=obsdim))
68-
end
69-
70-
function kernelmatrix!(
71-
K::AbstractMatrix,
72-
κ::PiecewisePolynomialKernel{V},
73-
X::AbstractMatrix;
74-
obsdim::Int = defaultobs
75-
) where {V}
76-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
77-
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
78-
throw(DimensionMismatch(
79-
"Dimensions of the target array K $(size(K)) are not consistent with X " *
80-
"$(size(X))",
81-
))
82-
end
83-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
84-
return map!(r->_piecewisepolynomial(κ,r,j), K, pairwise(metric(κ), X; dims=obsdim))
85-
end
86-
87-
function kernelmatrix!(
88-
K::AbstractMatrix,
89-
κ::PiecewisePolynomialKernel{V},
90-
X::AbstractMatrix,
91-
Y::AbstractMatrix;
92-
obsdim::Int = defaultobs,
93-
) where {V}
94-
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
95-
if !check_dims(K, X, Y, feature_dim(obsdim), obsdim)
96-
throw(DimensionMismatch(
97-
"Dimensions $(size(K)) of the target array K are not consistent with X " *
98-
"($(size(X))) and Y ($(size(Y)))",
99-
))
100-
end
101-
j = div(size(X, feature_dim(obsdim)), 2) + V + 1
102-
return map!(r->_piecewisepolynomial(κ,r,j), K, pairwise(metric(κ), X, Y; dims=obsdim))
103-
end
34+
kappa::PiecewisePolynomialKernel{V}, r) where V = max(1 - r, 0)^.j + V) * _f(κ, r, κ.j)
10435

10536
metric::PiecewisePolynomialKernel) = Mahalanobis.maha)
10637

0 commit comments

Comments
 (0)