Skip to content

Commit 5b9c395

Browse files
committed
Merge master-dev
1 parent f6cb413 commit 5b9c395

17 files changed

+162
-117
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ export KernelSum, KernelProduct
1515
export SelectTransform, ChainTransform, ScaleTransform, LowRankTransform, IdentityTransform, FunctionTransform
1616

1717

18-
using Distances, LinearAlgebra
18+
using Distances, LinearAlgebra, StaticArrays
1919
using SpecialFunctions: lgamma, besselk
2020
using StatsFuns: logtwo
2121
using PDMats: PDMat
@@ -42,6 +42,6 @@ include("kernels/kernelsum.jl")
4242
include("kernels/kernelproduct.jl")
4343

4444
include("generic.jl")
45-
45+
include("squeeze.jl")
4646

4747
end

src/generic.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,16 @@ end
2323
## Constructors for kernels without parameters
2424
for kernel in [:ExponentialKernel,:SqExponentialKernel,:Matern32Kernel,:Matern52Kernel,:ExponentiatedKernel]
2525
@eval begin
26-
$kernel::T=1.0) where {T<:Real} = $kernel{T,ScaleTransform{Base.RefValue{T}}}(ScaleTransform(ρ))
27-
$kernel::A) where {A<:AbstractVector{<:Real}} = $kernel{eltype(A),ScaleTransform{A}}(ScaleTransform(ρ))
26+
$kernel::T=1.0) where {T<:Real} = $kernel{T,ScaleTransform{T}}(ScaleTransform(ρ))
27+
$kernel::AbstractVector{T}) where {T<:Real} = $kernel{T,ARDTransform{T,length(ρ)}}(ARDTransform(ρ))
2828
$kernel(t::Tr) where {Tr<:Transform} = $kernel{eltype(t),Tr}(t)
2929
end
3030
end
3131

3232
function set_params!(k::Kernel,x)
33-
@error "Setting parameters to this kernel is either not possible or has not been implemented"
33+
set!(k.transform,first(x))
3434
end
3535

36+
3637
params(k::Kernel) = (params(k.transform),)
38+
opt_params(k::Kernel) = (opt_params(k.transform),)

src/kernels/constant.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ struct ConstantKernel{T,Tr,Tc<:Real} <: Kernel{T,Tr}
5656
end
5757

5858
params(k::ConstantKernel) = (params(k.transform),k.c)
59+
opt_params(k::ConstantKernel) = (opt_params(k.transform),k.c)
5960

6061
function ConstantKernel(c::Tc=1.0) where {Tc<:Real}
6162
ConstantKernel{Float64,IdentityTransform,Tc}(IdentityTransform(),c)

src/kernels/exponential.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,16 @@ struct GammaExponentialKernel{T,Tr,Tᵧ<:Real} <: Kernel{T,Tr}
6161
end
6262

6363
params(k::GammaExponentialKernel) = (params(transform),γ)
64+
opt_params(k::GammaExponentialKernel) = (opt_params(transform),γ)
6465

6566
function GammaExponentialKernel::T₁=1.0,gamma::T₂=2.0) where {T₁<:Real,T₂<:Real}
6667
@check_args(GammaExponentialKernel, gamma, gamma >= zero(T₂), "gamma > 0")
67-
GammaExponentialKernel{T₁,ScaleTransform{Base.RefValue{T₁}},T₂}(ScaleTransform(ρ),gamma)
68+
GammaExponentialKernel{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),gamma)
6869
end
6970

70-
function GammaExponentialKernel::A,gamma::T=2.0) where {A<:AbstractVector{<:Real},T₁<:Real}
71-
@check_args(GammaExponentialKernel, gamma, gamma >= zero(T), "gamma > 0")
72-
GammaExponentialKernel{eltype(A),ScaleTransform{A},T}(ScaleTransform(ρ),gamma)
71+
function GammaExponentialKernel::AbstractVector{T₁},gamma::T=2.0) where {T₁<:Real,T₂<:Real}
72+
@check_args(GammaExponentialKernel, gamma, gamma >= zero(T), "gamma > 0")
73+
GammaExponentialKernel{T₁,ARDTransform{T₁,length(ρ)},T}(ScaleTransform(ρ),gamma)
7374
end
7475

7576
function GammaExponentialKernel(t::Tr,gamma::T₁=2.0) where {Tr<:Transform,T₁<:Real}

src/kernels/kernelproduct.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ function KernelProduct(kernels::AbstractVector{<:Kernel})
1919
end
2020

2121
params(k::KernelProduct) = params.(k.kernels)
22+
opt_params(k::KernelProduct) = opt_params.(k.kernels)
2223

2324
Base.:*(k1::Kernel,k2::Kernel) = KernelProduct([k1,k2])
2425
Base.:*(k1::KernelProduct,k2::KernelProduct) = KernelProduct(vcat(k1.kernels,k2.kernels)) #TODO Add test

src/kernels/kernelsum.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ function KernelSum(kernels::AbstractVector{<:Kernel}; weights::AbstractVector{<:
2626
end
2727

2828
params(k::KernelSum) = (k.weights,params.(k.kernels))
29+
opt_params(k::KernelSum) = (k.weights,opt_params.(k.kernels))
2930

3031
Base.:+(k1::Kernel,k2::Kernel) = KernelSum([k1,k2],weights=[1.0,1.0])
3132
Base.:+(k1::KernelSum,k2::KernelSum) = KernelSum(vcat(k1.kernels,k2.kernels),weights=vcat(k1.weights,k2.weights))

src/kernels/matern.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ end
1717

1818
function MaternKernel::T₁=1.0::T₂=1.5) where {T₁<:Real,T₂<:Real}
1919
@check_args(MaternKernel, ν, ν > zero(T₂), "ν > 0")
20-
MaternKernel{T₁,ScaleTransform{Base.RefValue{T₁}},T₂}(ScaleTransform(ρ),ν)
20+
MaternKernel{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),ν)
2121
end
2222

23-
function MaternKernel::A::T=1.5) where {A<:AbstractVector{<:Real},T<:Real}
24-
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
25-
MaternKernel{eltype(A),ScaleTransform{A},T}(ScaleTransform(ρ),ν)
23+
function MaternKernel::AbstractVector{T₁}::T=1.5) where {T₁<:Real,T₂<:Real}
24+
@check_args(MaternKernel, ν, ν > zero(T), "ν > 0")
25+
MaternKernel{T₁,ARDTransform{T₁,length(ρ)},T₂}(ARDTransform(ρ),ν)
2626
end
2727

2828
function MaternKernel(t::Tr::T=1.5) where {Tr<:Transform,T<:Real}
@@ -31,6 +31,7 @@ function MaternKernel(t::Tr,ν::T=1.5) where {Tr<:Transform,T<:Real}
3131
end
3232

3333
params(k::MaternKernel) = (params(transform(k)),k.ν)
34+
opt_params(k::MaternKernel) = (opt_params(transform(k)),k.ν)
3435

3536
@inline kappa::MaternKernel, d::Real) = iszero(d) ? one(d) : exp((1.0-κ.ν)*logtwo-lgamma.ν) + κ.ν*log(sqrt(2κ.ν)*d)+log(besselk.ν,sqrt(2κ.ν)*d)))
3637

src/kernels/polynomial.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,19 @@ struct LinearKernel{T,Tr,Tc<:Real} <: Kernel{T,Tr}
1616
end
1717

1818
function LinearKernel::T₁=1.0,c::T₂=zero(T₁)) where {T₁<:Real,T₂<:Real}
19-
LinearKernel{T₁,ScaleTransform{Base.RefValue{T₁}},T₂}(ScaleTransform(ρ),c)
19+
LinearKernel{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),c)
2020
end
2121

22-
function LinearKernel::A,c::T=zero(eltype(ρ))) where {A<:AbstractVector{<:Real},T<:Real}
23-
LinearKernel{eltype(A),ScaleTransform{A},T}(ScaleTransform(ρ),c)
22+
function LinearKernel::AbstractVector{T₁},c::T=zero(T₁)) where {T₁<:Real,T₂<:Real}
23+
LinearKernel{T₁,ARDTransform{T₁,length(ρ)},T₂}(ARDTransform(ρ),c)
2424
end
2525

2626
function LinearKernel(t::Tr,c::T=zero(Float64)) where {Tr<:Transform,T<:Real}
2727
LinearKernel{eltype(t),Tr,T}(t,c)
2828
end
2929

3030
params(k::LinearKernel) = (params(transform(k)),k.c)
31+
opt_params(k::LinearKernel) = (opt_params(transform(k)),k.c)
3132

3233
@inline kappa::LinearKernel, xᵀy::T) where {T<:Real} = xᵀy + κ.c
3334

@@ -51,12 +52,12 @@ end
5152

5253
function PolynomialKernel::T₁=1.0,d::T₂=2.0,c::T₃=zero(T₁)) where {T₁<:Real,T₂<:Real,T₃<:Real}
5354
@check_args(PolynomialKernel, d, d >= one(T₁), "d >= 1")
54-
PolynomialKernel{T₁,ScaleTransform{Base.RefValue{T₁}},T₂,T₃}(ScaleTransform(ρ),c,d)
55+
PolynomialKernel{T₁,ScaleTransform{T₁},T₂,T₃}(ScaleTransform(ρ),c,d)
5556
end
5657

57-
function PolynomialKernel::A,d::T=2.0,c::T=zero(eltype(ρ))) where {A<:AbstractVector{<:Real},T₁<:Real,T<:Real}
58-
@check_args(PolynomialKernel, d, d >= one(T), "d >= 1")
59-
PolynomialKernel{eltype(A),ScaleTransform{A},T₁,T₂}(ScaleTransform(ρ),c,d)
58+
function PolynomialKernel::AbstractVector{T₁},d::T=2.0,c::T=zero(T₁)) where {T₁<:Real,T₂<:Real,T<:Real}
59+
@check_args(PolynomialKernel, d, d >= one(T), "d >= 1")
60+
PolynomialKernel{T₁,ARDTransform{T₁,length(ρ)},T₂,T₃}(ARDTransform(ρ),c,d)
6061
end
6162

6263
function PolynomialKernel(t::Tr,d::T₁=2.0,c::T₂=zero(eltype(T₁))) where {Tr<:Transform,T₁<:Real,T₂<:Real}
@@ -65,5 +66,6 @@ function PolynomialKernel(t::Tr,d::T₁=2.0,c::T₂=zero(eltype(T₁))) where {T
6566
end
6667

6768
params(k::PolynomialKernel) = (params(transform(k)),k.d,k.c)
69+
opt_params(k::PolynomialKernel) = (opt_params(transform(k)),k.d,k.c)
6870

6971
@inline kappa::PolynomialKernel, xᵀy::T) where {T<:Real} = (xᵀy + κ.c)^.d)

src/kernels/rationalquad.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ end
1717

1818
function RationalQuadraticKernel::T₁=1.0::T₂=2.0) where {T₁<:Real,T₂<:Real}
1919
@check_args(RationalQuadraticKernel, α, α > zero(T₂), "α > 1")
20-
RationalQuadraticKernel{T₁,ScaleTransform{Base.RefValue{T₁}},T₂}(ScaleTransform(ρ),α)
20+
RationalQuadraticKernel{T₁,ScaleTransform{T₁},T₂}(ScaleTransform(ρ),α)
2121
end
2222

23-
function RationalQuadraticKernel::A::T=2.0) where {A<:AbstractVector{<:Real},T<:Real}
24-
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
25-
RationalQuadraticKernel{eltype(A),ScaleTransform{A},T}(ScaleTransform(ρ),α)
23+
function RationalQuadraticKernel::AbstractVector{T₁}::T=2.0) where {T₁<:Real,T₂<:Real}
24+
@check_args(RationalQuadraticKernel, α, α > zero(T), "α > 1")
25+
RationalQuadraticKernel{T₁,ARDTransform{T₁,length(ρ)},T₂}(ARDTransform(ρ),α)
2626
end
2727

2828
function RationalQuadraticKernel(t::Tr::T=2.0) where {Tr<:Transform,T<:Real}
@@ -31,6 +31,7 @@ function RationalQuadraticKernel(t::Tr,α::T=2.0) where {Tr<:Transform,T<:Real}
3131
end
3232

3333
params(k::RationalQuadraticKernel) = (params(transform(k)),k.α)
34+
opt_params(k::RationalQuadraticKernel) = (opt_params(transform(k)),k.α)
3435

3536
@inline kappa::RationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+/κ.α)^(-κ.α)
3637

@@ -56,13 +57,13 @@ end
5657
function GammaRationalQuadraticKernel::T₁=1.0::T₂=2.0::T₃=2.0) where {T₁<:Real,T₂<:Real,T₃<:Real}
5758
@check_args(GammaRationalQuadraticKernel, α, α > one(T₂), "α > 1")
5859
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(T₂), "γ >= 1")
59-
GammaRationalQuadraticKernel{T₁,ScaleTransform{Base.RefValue{T₁}},T₂,T₃}(ScaleTransform(ρ),α,γ)
60+
GammaRationalQuadraticKernel{T₁,ScaleTransform{T₁},T₂,T₃}(ScaleTransform(ρ),α,γ)
6061
end
6162

62-
function GammaRationalQuadraticKernel::A::T=2.0::T=2.0) where {A<:AbstractVector{<:Real},T₁<:Real,T<:Real}
63-
@check_args(GammaRationalQuadraticKernel, α, α > one(T), "α > 1")
64-
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(T), "γ >= 1")
65-
GammaRationalQuadraticKernel{eltype(A),ScaleTransform{A},T₁,T₂}(ScaleTransform(ρ),α,γ)
63+
function GammaRationalQuadraticKernel::AbstractVector{T₁}::T=2.0::T=2.0) where {T₁<:Real,T₂<:Real,T<:Real}
64+
@check_args(GammaRationalQuadraticKernel, α, α > one(T), "α > 1")
65+
@check_args(GammaRationalQuadraticKernel, γ, γ >= one(T), "γ >= 1")
66+
GammaRationalQuadraticKernel{T₁,ARDTransform{T₁,length(ρ)},T₂,T₃}(ARDTransform(ρ),α,γ)
6667
end
6768

6869
function GammaRationalQuadraticKernel(t::Tr::T₁=2.0::T₂=2.0) where {Tr<:Transform,T₁<:Real,T₂<:Real}
@@ -72,5 +73,6 @@ function GammaRationalQuadraticKernel(t::Tr,α::T₁=2.0,γ::T₂=2.0) where {Tr
7273
end
7374

7475
params(k::GammaRationalQuadraticKernel) = (params(k.transform),k.α,k.γ)
76+
opt_params(k::GammaRationalQuadraticKernel) = (opt_params(k.transform),k.α,k.γ)
7577

7678
@inline kappa::GammaRationalQuadraticKernel, d²::T) where {T<:Real} = (one(T)+^κ.γ/κ.α)^(-κ.α)

src/squeeze.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
base_kernel(k::Kernel) = eval(nameof(typeof(k)))
2+
3+
base_transform(k::Kernel) = base_transform(k.transform)
4+
base_transform(t::Transform) = eval(nameof(typeof(t)))
5+
tail(v::AbstractVector) = view(v,2:length(v))
6+
duplicate(k::Kernel::AbstractVector) = base_kernel(k)(duplicate(transform(k),first(θ)),tail(θ)...)
7+
duplicate(k::Kernel::Tuple) = base_kernel(k)(duplicate(transform(k),first(θ)),Base.tail(θ)...)
8+
Base.one(x::V) where {V<:AbstractArray{T}} where T = V(fill(one(T),size(x)))
9+
duplicate(t::Transform,θ) = base_transform(t)(θ)
10+
duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ))
11+
duplicate(t::FunctionTransform,θ) = t
12+
duplicate(t::IdentityTransform,θ) = t
13+
duplicate(t::SelectTransform,θ) = t
14+
15+
16+
function duplicate(k::KernelSum,θ)
17+
KernelSum(duplicate.(k.kernels,θ[2]),weights=first(θ))
18+
end
19+
20+
function duplicate(k::KernelProduct,θ)
21+
KernelProduct(duplicate.(k.kernels,θ))
22+
end
23+
24+
dim(k::Kernel) = length(params(k))

0 commit comments

Comments
 (0)