Skip to content

Commit a0daa6c

Browse files
authored
Merge pull request #41 from theogf/params
params(), Flux/Zygote style
2 parents 716285c + 3e47144 commit a0daa6c

22 files changed

+133
-74
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,12 @@ julia = "1.0"
2525

2626
[extras]
2727
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
28+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
2829
Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e"
2930
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
3031
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
3132
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3233
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3334

3435
[targets]
35-
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker"]
36+
test = ["Random", "Test", "FiniteDifferences", "Zygote", "PDMats", "Kronecker", "Flux"]

src/KernelFunctions.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ include("zygote_adjoints.jl")
5858
function __init__()
5959
@require Kronecker="2c470bb0-bcc8-11e8-3dad-c9649493f05e" include("matrix/kernelkroneckermat.jl")
6060
@require PDMats="90014a1f-27ba-587c-ab20-58faa44d9150" include("matrix/kernelpdmat.jl")
61+
@require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" include("trainable.jl")
6162
end
6263

6364
end

src/kernels/constant.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,12 @@ metric(::WhiteKernel) = Delta()
3535
Kernel function always returning a constant value `c`
3636
"""
3737
struct ConstantKernel{Tc<:Real} <: BaseKernel
38-
c::Tc
38+
c::Vector{Tc}
3939
function ConstantKernel(;c::T=1.0) where {T<:Real}
40-
new{T}(c)
40+
new{T}([c])
4141
end
4242
end
4343

44-
params(k::ConstantKernel) = (k.c,)
45-
opt_params(k::ConstantKernel) = (k.c,)
46-
47-
kappa::ConstantKernel,x::Real) = κ.c*one(x)
44+
kappa::ConstantKernel,x::Real) = first.c)*one(x)
4845

4946
metric(::ConstantKernel) = Delta()

src/kernels/exponential.jl

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,16 +47,13 @@ The γ-exponential kernel is an isotropic Mercer kernel given by the formula:
4747
```
4848
"""
4949
struct GammaExponentialKernel{Tγ<:Real} <: BaseKernel
50-
γ::Tγ
50+
γ::Vector{Tγ}
5151
function GammaExponentialKernel(;γ::T=2.0) where {T<:Real}
5252
@check_args(GammaExponentialKernel, γ, γ >= zero(T), "γ > 0")
53-
return new{T}(γ)
53+
return new{T}([γ])
5454
end
5555
end
5656

57-
params(k::GammaExponentialKernel) = (γ,)
58-
opt_params(k::GammaExponentialKernel) = (γ,)
59-
60-
kappa::GammaExponentialKernel, d²::Real) = exp(-^κ.γ)
57+
kappa::GammaExponentialKernel, d²::Real) = exp(-^first.γ))
6158
iskroncompatible(::GammaExponentialKernel) = true
6259
metric(::GammaExponentialKernel) = SqEuclidean()

src/kernels/kernelproduct.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@ struct KernelProduct <: Kernel
1414
kernels::Vector{Kernel}
1515
end
1616

17-
params(k::KernelProduct) = params.(k.kernels)
18-
opt_params(k::KernelProduct) = opt_params.(k.kernels)
19-
2017
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
2118
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
2219
Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))

src/kernels/kernelsum.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ function KernelSum(
2525
return KernelSum(kernels, weights)
2626
end
2727

28-
params(k::KernelSum) = (k.weights, params.(k.kernels))
29-
opt_params(k::KernelSum) = (k.weights, opt_params.(k.kernels))
30-
3128
Base.:+(k1::Kernel, k2::Kernel) = KernelSum([k1, k2], weights = [1.0, 1.0])
3229
Base.:+(k1::ScaledKernel, k2::ScaledKernel) = KernelSum([kernel(k1), kernel(k2)], weights = [first(k1.σ²), first(k2.σ²)])
3330
Base.:+(k1::KernelSum, k2::KernelSum) =

src/kernels/matern.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,17 @@ The matern kernel is an isotropic Mercer kernel given by the formula:
77
For `ν=n+1/2, n=0,1,2,...` it can be simplified and you should instead use [`ExponentialKernel`](@ref) for `n=0`, [`Matern32Kernel`](@ref), for `n=1`, [`Matern52Kernel`](@ref) for `n=2` and [`SqExponentialKernel`](@ref) for `n=∞`.
88
"""
99
struct MaternKernel{Tν<:Real} <: BaseKernel
10-
ν::Tν
10+
ν::Vector{Tν}
1111
function MaternKernel(;ν::T=1.5) where {T<:Real}
1212
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
13-
return new{T}(ν)
13+
return new{T}([ν])
1414
end
1515
end
1616

17-
params(k::MaternKernel) = (k.ν,)
18-
opt_params(k::MaternKernel) = (k.ν,)
19-
20-
@inline kappa::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((one(d)-κ.ν)*logtwo-logabsgamma.ν)[1] + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk.ν,sqrt(2κ.ν)*d)))
17+
@inline function kappa::MaternKernel, d::Real)
18+
ν = first.ν)
19+
iszero(d) ? one(d) : exp((one(d)-ν)*logtwo-logabsgamma(ν)[1] + ν*log(sqrt(2ν)*d)+log(besselk(ν,sqrt(2ν)*d)))
20+
end
2121

2222
metric(::MaternKernel) = Euclidean()
2323

src/kernels/polynomial.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,13 @@ The linear kernel is a Mercer kernel given by
77
Where `c` is a real number
88
"""
99
struct LinearKernel{Tc<:Real} <: BaseKernel
10-
c::Tc
10+
c::Vector{Tc}
1111
function LinearKernel(;c::T=0.0) where {T}
12-
new{T}(c)
12+
new{T}([c])
1313
end
1414
end
1515

16-
params(k::LinearKernel) = (k.c,)
17-
opt_params(k::LinearKernel) = (k.c,)
18-
19-
kappa::LinearKernel, xᵀy::Real) = xᵀy + κ.c
16+
kappa::LinearKernel, xᵀy::Real) = xᵀy + first.c)
2017

2118
metric(::LinearKernel) = DotProduct()
2219

@@ -28,18 +25,15 @@ The polynomial kernel is a Mercer kernel given by
2825
```
2926
Where `c` is a real number, and `d` is a shape parameter bigger than 1
3027
"""
31-
struct PolynomialKernel{Td<:Real,Tc<:Real} <: BaseKernel
32-
d::Td
33-
c::Tc
28+
struct PolynomialKernel{Td<:Real, Tc<:Real} <: BaseKernel
29+
d::Vector{Td}
30+
c::Vector{Tc}
3431
function PolynomialKernel(; d::Td=2.0, c::Tc=0.0) where {Td<:Real, Tc<:Real}
3532
@check_args(PolynomialKernel, d, d >= one(Td), "d >= 1")
36-
return new{Td, Tc}(d, c)
33+
return new{Td, Tc}([d], [c])
3734
end
3835
end
3936

40-
params(k::PolynomialKernel) = (k.d,k.c)
41-
opt_params(k::PolynomialKernel) = (k.d,k.c)
42-
43-
kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + κ.c)^.d)
37+
kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + first.c))^(first.d))
4438

4539
metric(::PolynomialKernel) = DotProduct()

src/kernels/rationalquad.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,14 @@ The rational-quadratic kernel is an isotropic Mercer kernel given by the formula
77
where `α` is a shape parameter of the Euclidean distance. Check [`GammaRationalQuadraticKernel`](@ref) for a generalization.
88
"""
99
struct RationalQuadraticKernel{Tα<:Real} <: BaseKernel
10-
α::Tα
10+
α::Vector{Tα}
1111
function RationalQuadraticKernel(;α::T=2.0) where {T}
1212
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
13-
return new{T}(α)
13+
return new{T}([α])
1414
end
1515
end
1616

17-
params(k::RationalQuadraticKernel) = (k.α,)
18-
opt_params(k::RationalQuadraticKernel) = (k.α,)
19-
20-
kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/κ.α)^(-κ.α)
17+
kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/first.α))^(-first.α))
2118

2219
metric(::RationalQuadraticKernel) = SqEuclidean()
2320

@@ -30,18 +27,15 @@ The Gamma-rational-quadratic kernel is an isotropic Mercer kernel given by the f
3027
where `α` is a shape parameter of the Euclidean distance and `γ` is another shape parameter.
3128
"""
3229
struct GammaRationalQuadraticKernel{Tα<:Real, Tγ<:Real} <: BaseKernel
33-
α::Tα
34-
γ::Tγ
30+
α::Vector{Tα}
31+
γ::Vector{Tγ}
3532
function GammaRationalQuadraticKernel(;α::Tα=2.0, γ::Tγ=2.0) where {Tα<:Real, Tγ<:Real}
3633
@check_args(GammaRationalQuadraticKernel, α, α > one(Tα), "α > 1")
3734
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(Tγ), "γ >= 1")
38-
return new{Tα, Tγ}(α, γ)
35+
return new{Tα, Tγ}([α], [γ])
3936
end
4037
end
4138

42-
params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
43-
opt_params(k::GammaRationalQuadraticKernel) = (k.α,k.γ)
44-
45-
kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^κ.γ/κ.α)^(-κ.α)
39+
kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^first.γ)/first.α))^(-first.α))
4640

4741
metric(::GammaRationalQuadraticKernel) = SqEuclidean()

src/kernels/scaledkernel.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ kappa(k::ScaledKernel, x) = first(k.σ²) * kappa(k.kernel, x)
1212

1313
metric(k::ScaledKernel) = metric(k.kernel)
1414

15-
params(k::ScaledKernel) = (k.σ², params(k.kernel))
16-
opt_params(k::ScaledKernel) = (k.σ², opt_params(k.kernel))
17-
1815
Base.:*(w::Real, k::Kernel) = ScaledKernel(k, w)
1916

2017
Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)

0 commit comments

Comments
 (0)