Skip to content

Commit e3dae23

Browse files
committed
Merge master-dev into master
1 parent 95d697c commit e3dae23

14 files changed

+45
-16
lines changed

src/KernelFunctions.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,15 @@ using PDMats: PDMat
2222

2323
const defaultobs = 2
2424

25-
include("utils.jl")
26-
include("distances/dotproduct.jl")
27-
include("distances/delta.jl")
28-
29-
3025
"""
3126
Abstract type defining a slice-wise transformation on an input matrix
3227
"""
3328
abstract type Transform end
3429
abstract type Kernel{T,Tr<:Transform} end
3530

31+
include("utils.jl")
32+
include("distances/dotproduct.jl")
33+
include("distances/delta.jl")
3634
include("transform/transform.jl")
3735
kernels = ["exponential","matern","polynomial","constant","rationalquad","exponentiated"]
3836
for k in kernels

src/generic.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,5 @@ end
3232
function set_params!(k::Kernel,x)
3333
@error "Setting parameters to this kernel is either not possible or has not been implemented"
3434
end
35+
36+
params(k::Kernel) = (params(k.transform),)

src/kernels/constant.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ struct ConstantKernel{T,Tr,Tc<:Real} <: Kernel{T,Tr}
5555
end
5656
end
5757

58+
params(k::ConstantKernel) = (params(k.transform),k.c)
59+
5860
function ConstantKernel(c::Tc=1.0) where {Tc<:Real}
5961
ConstantKernel{Float64,IdentityTransform,Tc}(IdentityTransform(),c)
6062
end

src/kernels/exponential.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ struct GammaExponentialKernel{T,Tr,Tᵧ<:Real} <: Kernel{T,Tr}
6060
end
6161
end
6262

63+
params(k::GammaExponentialKernel) = (params(transform),γ)
64+
6365
function GammaExponentialKernel::T₁=1.0,gamma::T₂=2.0) where {T₁<:Real,T₂<:Real}
6466
@check_args(GammaExponentialKernel, gamma, gamma >= zero(T₂), "gamma > 0")
6567
GammaExponentialKernel{T₁,ScaleTransform{Base.RefValue{T₁}},T₂}(ScaleTransform(ρ),gamma)

src/kernels/kernelproduct.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ function KernelProduct(kernels::AbstractVector{<:Kernel})
1818
KernelProduct{eltype(kernels),Transform}(kernels)
1919
end
2020

21+
params(k::KernelProduct) = params.(k.kernels)
22+
2123
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
2224
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test
2325
Base.:*(k::Kernel,kp::KernelProduct) = KernelProduct(vcat(k,kp.kernels))

src/kernels/kernelsum.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ struct KernelSum{T,Tr} <: Kernel{T,Tr}
1919
end
2020
end
2121

22-
2322
function KernelSum(kernels::AbstractVector{<:Kernel}; weights::AbstractVector{<:Real}=ones(Float64,length(kernels)))
2423
@assert length(kernels)==length(weights) "Weights and kernel vector should be of the same length"
2524
@assert all(weights.>=0) "All weights should be positive"
2625
KernelSum{eltype(kernels),Transform}(kernels,weights)
2726
end
2827

28+
params(k::KernelSum) = (k.weights,params.(k.kernels))
29+
2930
Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
3031
Base.:+(k1::KernelSum,k2::KernelSum) = KernelSum(vcat(k1.kernels,k2.kernels),weights=vcat(k1.weights,k2.weights))
3132
Base.:+(k::Kernel,ks::KernelSum) = KernelSum(vcat(k,ks.kernels),weights=vcat(1.0,ks.weights))

src/kernels/matern.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ function MaternKernel(t::Tr,ν::T=1.5) where {Tr<:Transform,T<:Real}
3030
MaternKernel{eltype(t),Tr,T}(t,ν)
3131
end
3232

33+
params(k::MaternKernel) = (params(transform(k)),k.ν)
34+
3335
@inline kappa::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((1.0-κ.ν)*logtwo-lgamma.ν) + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk.ν,sqrt(2κ.ν)*d)))
3436

3537
"""

src/kernels/polynomial.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ function LinearKernel(t::Tr,c::T=zero(Float64)) where {Tr<:Transform,T<:Real}
2727
LinearKernel{eltype(t),Tr,T}(t,c)
2828
end
2929

30+
params(k::LinearKernel) = (params(transform(k)),k.c)
31+
3032
@inline kappa::LinearKernel, xᵀy::T) where {T<:Real} = xᵀy + κ.c
3133

3234
"""
@@ -62,4 +64,6 @@ function PolynomialKernel(t::Tr,d::T₁=2.0,c::T₂=zero(eltype(T₁))) where {T
6264
PolynomialKernel{eltype(Tr),Tr,T₁,T₂}(t,c,d)
6365
end
6466

67+
params(k::PolynomialKernel) = (params(transform(k)),k.d,k.c)
68+
6569
@inline kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + κ.c)^.d)

src/kernels/rationalquad.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ function RationalQuadraticKernel(t::Tr,α::T=2.0) where {Tr<:Transform,T<:Real}
3030
RationalQuadraticKernel{eltype(t),Tr,T}(t,α)
3131
end
3232

33+
params(k::RationalQuadraticKernel) = (params(transform(k)),k.α)
34+
3335
@inline kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/κ.α)^(-κ.α)
3436

3537

@@ -69,4 +71,6 @@ function GammaRationalQuadraticKernel(t::Tr,α::T₁=2.0,γ::T₂=2.0) where {Tr
6971
GammaRationalQuadraticKernel{eltype(t),Tr,T₁,T₂}(t,α,γ)
7072
end
7173

74+
params(k::GammaRationalQuadraticKernel) = (params(k.transform),k.α,k.γ)
75+
7276
@inline kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^κ.γ/κ.α)^(-κ.α)

src/transform/functiontransform.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,5 @@ struct FunctionTransform{F} <: Transform
1313
end
1414

1515
transform(t::FunctionTransform,X::T,obsdim::Int=defaultobs) where {T} = mapslices(t.f,X,dims=obsdim)
16+
17+
params(t::FunctionTransform) = t.f

0 commit comments

Comments
 (0)