Skip to content

Commit d08f329

Browse files
committed
Vectorized parameters of each kernel and add trainable function
1 parent 310a885 commit d08f329

File tree

7 files changed

+36
-41
lines changed

7 files changed

+36
-41
lines changed

src/kernels/constant.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,14 @@ metric(::WhiteKernel) = Delta()
3535
Kernel function always returning a constant value `c`
3636
"""
3737
struct ConstantKernel{Tc<:Real} <: BaseKernel
38-
c::Tc
38+
c::Vector{Tc}
3939
function ConstantKernel(;c::T=1.0) where {T<:Real}
40-
new{T}(c)
40+
new{T}([c])
4141
end
4242
end
4343

44-
params(k::ConstantKernel) = (k.c,)
45-
opt_params(k::ConstantKernel) = (k.c,)
44+
trainable(k::ConstantKernel) = (k.c,)
4645

47-
kappa::ConstantKernel,x::Real) = κ.c*one(x)
46+
kappa::ConstantKernel,x::Real) = first(κ.c)*one(x)
4847

4948
metric(::ConstantKernel) = Delta()

src/kernels/exponential.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,15 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
4747
```
4848
"""
4949
struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
50-
γ::Tγ
50+
γ::Vector{Tγ}
5151
function GammaExponentialKernel(;γ::T=2.0) where {T<:Real}
5252
@check_args(GammaExponentialKernel, γ, γ >= zero(T), "γ > 0")
53-
return new{T}(γ)
53+
return new{T}([γ])
5454
end
5555
end
5656

57-
params(k::GammaExponentialKernel) = (γ,)
58-
opt_params(k::GammaExponentialKernel) = (γ,)
57+
trainable(k::GammaExponentialKernel) = (γ,)
5958

60-
kappa::GammaExponentialKernel, d²::Real) = exp(-^κ.γ)
59+
kappa::GammaExponentialKernel, d²::Real) = exp(-^first(κ.γ))
6160
iskroncompatible(::GammaExponentialKernel) = true
6261
metric(::GammaExponentialKernel) = SqEuclidean()

src/kernels/matern.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,19 @@ The matern kernel is an isotropic Mercer kernel given by the formula:
77
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
88
"""
99
struct MaternKernel{Tν<:Real} <: BaseKernel
10-
ν::Tν
10+
ν::Vector{Tν}
1111
function MaternKernel(;ν::T=1.5) where {T<:Real}
1212
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
13-
return new{T}(ν)
13+
return new{T}([ν])
1414
end
1515
end
1616

17-
params(k::MaternKernel) = (k.ν,)
18-
opt_params(k::MaternKernel) = (k.ν,)
17+
trainable(k::MaternKernel) = (k.ν,)
1918

20-
@inline kappa::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((one(d)-κ.ν)*logtwo-logabsgamma.ν)[1] + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk.ν,sqrt(2κ.ν)*d)))
19+
@inline function kappa::MaternKernel, d::Real)
20+
ν = first.ν)
21+
iszero(d) ? one(d) : exp((one(d)-ν)*logtwo-logabsgamma(ν)[1] + ν*log(sqrt(2ν)*d)+log(besselk(ν,sqrt(2ν)*d)))
22+
end
2123

2224
metric(::MaternKernel) = Euclidean()
2325

src/kernels/polynomial.jl

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,15 @@ The linear kernel is a Mercer kernel given by
77
Where `c` is a real number
88
"""
99
struct LinearKernel{Tc<:Real} <: BaseKernel
10-
c::Tc
10+
c::Vector{Tc}
1111
function LinearKernel(;c::T=0.0) where {T}
12-
new{T}(c)
12+
new{T}([c])
1313
end
1414
end
1515

16-
params(k::LinearKernel) = (k.c,)
17-
opt_params(k::LinearKernel) = (k.c,)
16+
trainable(k::LinearKernel) = (k.c,)
1817

19-
kappa::LinearKernel, xᵀy::Real) = xᵀy + κ.c
18+
kappa::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)
2019

2120
metric(::LinearKernel) = DotProduct()
2221

@@ -28,18 +27,17 @@ The polynomial kernel is a Mercer kernel given by
2827
```
2928
Where `c` is a real number, and `d` is a shape parameter bigger than 1
3029
"""
31-
struct PolynomialKernel{Td<:Real,Tc<:Real} <: BaseKernel
32-
d::Td
33-
c::Tc
30+
struct PolynomialKernel{Td<:Real, Tc<:Real} <: BaseKernel
31+
d::Vector{Td}
32+
c::Vector{Tc}
3433
function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real, Tc<:Real}
3534
@check_args(PolynomialKernel, d, d >= one(Td), "d >= 1")
36-
return new{Td, Tc}(d, c)
35+
return new{Td, Tc}([d], [c])
3736
end
3837
end
3938

40-
params(k::PolynomialKernel) = (k.d,k.c)
41-
opt_params(k::PolynomialKernel) = (k.d,k.c)
39+
trainable(k::PolynomialKernel) = (k.d,k.c)
4240

43-
kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + κ.c)^(κ.d)
41+
kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + first(κ.c))^(first(κ.d))
4442

4543
metric(::PolynomialKernel) = DotProduct()

src/kernels/rationalquad.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,16 @@ The rational-quadratic kernel is an isotropic Mercer kernel given by the formula
77
where `α` is a shape parameter of the Euclidean distance. Check [`GammaRationalQuadraticKernel`](@ref) for a generalization.
88
"""
99
struct RationalQuadraticKernel{Tα<:Real} <: BaseKernel
10-
α::Tα
10+
α::Vector{Tα}
1111
function RationalQuadraticKernel(;α::T=2.0) where {T}
1212
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
13-
return new{T}(α)
13+
return new{T}([α])
1414
end
1515
end
1616

17-
params(k::RationalQuadraticKernel) = (k.α,)
18-
opt_params(k::RationalQuadraticKernel) = (k.α,)
17+
trainable(k::RationalQuadraticKernel) = (k.α,)
1918

20-
kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/κ.α)^(-κ.α)
19+
kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/first(κ.α))^(-first(κ.α))
2120

2221
metric(::RationalQuadraticKernel) = SqEuclidean()
2322

@@ -30,18 +29,17 @@ The Gamma-rational-quadratic kernel is an isotropic Mercer kernel given by the f
3029
where `α` is a shape parameter of the Euclidean distance and `γ` is another shape parameter.
3130
"""
3231
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: BaseKernel
33-
α::Tα
34-
γ::Tγ
32+
α::Vector{Tα}
33+
γ::Vector{Tγ}
3534
function GammaRationalQuadraticKernel(;α::Tα=2.0, γ::Tγ=2.0) where {Tα<:Real, Tγ<:Real}
3635
@check_args(GammaRationalQuadraticKernel, α, α > one(Tα), "α > 1")
3736
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(Tγ), "γ >= 1")
38-
return new{Tα, Tγ}(α, γ)
37+
return new{Tα, Tγ}([α], [γ])
3938
end
4039
end
4140

42-
params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
43-
opt_params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
41+
trainable(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
4442

45-
kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^κ.γ/κ.α)^(-κ.α)
43+
kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^first(κ.γ)/first(κ.α))^(-first(κ.α))
4644

4745
metric(::GammaRationalQuadraticKernel) = SqEuclidean()

src/kernels/scaledkernel.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ kappa(k::ScaledKernel, x) = first(k.σ)*kappa(k.kernel, x)
1212

1313
metric(k::ScaledKernel) = metric(k.kernel)
1414

15-
params(k::ScaledKernel) = (k.σ,params(k.kernel))
16-
opt_params(k::ScaledKernel) = (k.σ,opt_params(k.kernel))
15+
trainable(k::ScaledKernel) = (k.σ,k.kernel)
1716

1817
Base.:*(w::Real,k::Kernel) = ScaledKernel(k,w)
1918

src/kernels/transformedkernel.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ kappa(κ::TransformedKernel, x) = kappa(κ.kernel, x)
2727

2828
metric::TransformedKernel) = metric.kernel)
2929

30-
params::TransformedKernel) = (params(κ.transform),params(κ.kernel))
30+
trainable::TransformedKernel) =.transform,κ.kernel)
3131

3232
Base.show(io::IO::TransformedKernel) = printshifted(io,κ,0)
3333

0 commit comments

Comments
 (0)