diff --git a/Project.toml b/Project.toml index 19247efdf..27432b0c2 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -30,6 +31,7 @@ FillArrays = "0.10, 0.11, 0.12, 0.13, 1" Functors = "0.1, 0.2, 0.3, 0.4" IrrationalConstants = "0.1, 0.2" LogExpFunctions = "0.2.1, 0.3" +ParameterHandling = "0.4" Requires = "1.0.1" SpecialFunctions = "0.8, 0.9, 0.10, 1, 2" Statistics = "1" diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 63205b5bf..421127577 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -43,6 +43,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou export IndependentMOKernel, LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel +export ParameterKernel + # Reexports export tensor, ⊗, compose @@ -53,11 +55,12 @@ using CompositionsBase using Distances using FillArrays using Functors +using ParameterHandling using LinearAlgebra using Requires using SpecialFunctions: loggamma, besselk, polygamma using IrrationalConstants: logtwo, twoπ, invsqrt2 -using LogExpFunctions: softplus +using LogExpFunctions: logit, logistic, softplus using StatsBase using TensorCore using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield @@ -111,6 +114,7 @@ include("kernels/kernelproduct.jl") include("kernels/kerneltensorproduct.jl") include("kernels/overloads.jl") include("kernels/neuralkernelnetwork.jl") +include("kernels/parameterkernel.jl") include("approximations/nystrom.jl") include("generic.jl") diff --git a/src/TestUtils.jl b/src/TestUtils.jl index cd14ec718..567906d38 100644 --- a/src/TestUtils.jl +++ b/src/TestUtils.jl @@ -3,6 +3,7 @@ module TestUtils using Distances using LinearAlgebra using KernelFunctions +using ParameterHandling using Random using Test @@ -84,6 +85,11 @@ function test_interface( tmp_diag = Vector{Float64}(undef, length(x0)) @test kernelmatrix_diag!(tmp_diag, k, x0) ≈ kernelmatrix_diag(k, x0) @test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1) + + # Check flatten/unflatten + ParameterHandling.TestUtils.test_flatten_interface(k) + + return nothing end """ diff --git a/src/basekernels/constant.jl b/src/basekernels/constant.jl index e3f5245fb..d9fe9d070 100644 --- a/src/basekernels/constant.jl +++ b/src/basekernels/constant.jl @@ -15,8 +15,11 @@ See also: [`ConstantKernel`](@ref) """ struct ZeroKernel <: SimpleKernel end +@noparams ZeroKernel + # SimpleKernel interface kappa(::ZeroKernel, ::Real) = false + metric(::ZeroKernel) = Delta() # Optimizations @@ -68,6 +71,8 @@ k(x, x') = \\delta(x, x'). """ struct WhiteKernel <: SimpleKernel end +@noparams WhiteKernel + """ EyeKernel() @@ -95,52 +100,59 @@ k(x, x') = c. 
See also: [`ZeroKernel`](@ref) """ -struct ConstantKernel{Tc<:Real} <: SimpleKernel - c::Vector{Tc} +struct ConstantKernel{T<:Real} <: SimpleKernel + c::T - function ConstantKernel(; c::Real=1.0) + function ConstantKernel(c::Real) @check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0") - return new{typeof(c)}([c]) + return new{typeof(c)}(c) end end -@functor ConstantKernel +ConstantKernel(; c::Real=1.0) = ConstantKernel(c) + +function ParameterHandling.flatten(::Type{T}, k::ConstantKernel{S}) where {T<:Real,S} + function unflatten_to_constantkernel(v::Vector{T}) + return ConstantKernel(; c=S(exp(only(v)))) + end + return T[log(k.c)], unflatten_to_constantkernel +end # SimpleKernel interface -kappa(κ::ConstantKernel, ::Real) = only(κ.c) +kappa(κ::ConstantKernel, ::Real) = κ.c metric(::ConstantKernel) = Delta() # Optimizations -(k::ConstantKernel)(x, y) = only(k.c) -kernelmatrix(k::ConstantKernel, x::AbstractVector) = Fill(only(k.c), length(x), length(x)) +(k::ConstantKernel)(x, y) = k.c +kernelmatrix(k::ConstantKernel, x::AbstractVector) = Fill(k.c, length(x), length(x)) function kernelmatrix(k::ConstantKernel, x::AbstractVector, y::AbstractVector) validate_inputs(x, y) - return Fill(only(k.c), length(x), length(y)) + return Fill(k.c, length(x), length(y)) end function kernelmatrix!(K::AbstractMatrix, k::ConstantKernel, x::AbstractVector) validate_inplace_dims(K, x) - return fill!(K, only(k.c)) + return fill!(K, k.c) end function kernelmatrix!( K::AbstractMatrix, k::ConstantKernel, x::AbstractVector, y::AbstractVector ) validate_inplace_dims(K, x, y) - return fill!(K, only(k.c)) + return fill!(K, k.c) end -kernelmatrix_diag(k::ConstantKernel, x::AbstractVector) = Fill(only(k.c), length(x)) +kernelmatrix_diag(k::ConstantKernel, x::AbstractVector) = Fill(k.c, length(x)) function kernelmatrix_diag(k::ConstantKernel, x::AbstractVector, y::AbstractVector) validate_inputs(x, y) - return Fill(only(k.c), length(x)) + return Fill(k.c, length(x)) end function kernelmatrix_diag!(K::AbstractVector, k::ConstantKernel, x::AbstractVector) validate_inplace_dims(K, x) - return fill!(K, only(k.c)) + return fill!(K, k.c) end function kernelmatrix_diag!( K::AbstractVector, k::ConstantKernel, x::AbstractVector, y::AbstractVector ) validate_inplace_dims(K, x, y) - return fill!(K, only(k.c)) + return fill!(K, k.c) end -Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", only(κ.c), ")") +Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", κ.c, ")") diff --git a/src/basekernels/cosine.jl b/src/basekernels/cosine.jl index 50cc6fdf3..4c1040822 100644 --- a/src/basekernels/cosine.jl +++ b/src/basekernels/cosine.jl @@ -17,6 +17,8 @@ end CosineKernel(; metric=Euclidean()) = CosineKernel(metric) +@noparams CosineKernel + kappa(::CosineKernel, d::Real) = cospi(d) metric(k::CosineKernel) = k.metric diff --git a/src/basekernels/exponential.jl b/src/basekernels/exponential.jl index 2061d40f9..ab6232201 100644 --- a/src/basekernels/exponential.jl +++ b/src/basekernels/exponential.jl @@ -20,6 +20,8 @@ end SqExponentialKernel(; metric=Euclidean()) = SqExponentialKernel(metric) +@noparams SqExponentialKernel + kappa(::SqExponentialKernel, d::Real) = exp(-d^2 / 2) kappa(::SqExponentialKernel{<:Euclidean}, d²::Real) = exp(-d² / 2) @@ -76,6 +78,8 @@ end ExponentialKernel(; metric=Euclidean()) = ExponentialKernel(metric) +@noparams ExponentialKernel + kappa(::ExponentialKernel, d::Real) = exp(-d) metric(k::ExponentialKernel) = k.metric @@ -121,13 +125,13 @@ See also: 
[`ExponentialKernel`](@ref), [`SqExponentialKernel`](@ref) [^RW]: C. E. Rasmussen & C. K. I. Williams (2006). Gaussian Processes for Machine Learning. """ -struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel - γ::Vector{Tγ} +struct GammaExponentialKernel{T<:Real,M} <: SimpleKernel + γ::T metric::M function GammaExponentialKernel(γ::Real, metric) @check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") - return new{typeof(γ),typeof(metric)}([γ], metric) + return new{typeof(γ),typeof(metric)}(γ, metric) end end @@ -135,16 +139,23 @@ function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclid return GammaExponentialKernel(γ, metric) end -@functor GammaExponentialKernel +function ParameterHandling.flatten( + ::Type{T}, k::GammaExponentialKernel{S} +) where {T<:Real,S<:Real} + metric = k.metric + function unflatten_to_gammaexponentialkernel(v::Vector{T}) + γ = S(2 * logistic(only(v))) + return GammaExponentialKernel(; γ=γ, metric=metric) + end + return T[logit(k.γ / 2)], unflatten_to_gammaexponentialkernel +end -kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^only(κ.γ)) +kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^κ.γ) metric(k::GammaExponentialKernel) = k.metric iskroncompatible(::GammaExponentialKernel) = true function Base.show(io::IO, κ::GammaExponentialKernel) - return print( - io, "Gamma Exponential Kernel (γ = ", only(κ.γ), ", metric = ", κ.metric, ")" - ) + return print(io, "Gamma Exponential Kernel (γ = ", κ.γ, ", metric = ", κ.metric, ")") end diff --git a/src/basekernels/exponentiated.jl b/src/basekernels/exponentiated.jl index 0b360ceb6..66888f750 100644 --- a/src/basekernels/exponentiated.jl +++ b/src/basekernels/exponentiated.jl @@ -12,6 +12,8 @@ k(x, x') = \\exp(x^\\top x'). """ struct ExponentiatedKernel <: SimpleKernel end +@noparams ExponentiatedKernel + kappa(::ExponentiatedKernel, xᵀy::Real) = exp(xᵀy) metric(::ExponentiatedKernel) = DotProduct() diff --git a/src/basekernels/fbm.jl b/src/basekernels/fbm.jl index 18dbdea14..c947e3551 100644 --- a/src/basekernels/fbm.jl +++ b/src/basekernels/fbm.jl @@ -13,16 +13,23 @@ k(x, x'; h) = \\frac{\\|x\\|_2^{2h} + \\|x'\\|_2^{2h} - \\|x - x'\\|^{2h}}{2}. ``` """ struct FBMKernel{T<:Real} <: Kernel - h::Vector{T} + h::T + function FBMKernel(h::Real) @check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]") - return new{typeof(h)}([h]) + return new{typeof(h)}(h) end end FBMKernel(; h::Real=0.5) = FBMKernel(h) -@functor FBMKernel +function ParameterHandling.flatten(::Type{T}, k::FBMKernel{S}) where {T<:Real,S<:Real} + function unflatten_to_fbmkernel(v::Vector{T}) + h = S(logistic(only(v))) + return FBMKernel(h) + end + return T[logit(k.h)], unflatten_to_fbmkernel +end function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) modX = sum(abs2, x) diff --git a/src/basekernels/matern.jl b/src/basekernels/matern.jl index c009bf596..e4d0bdb46 100644 --- a/src/basekernels/matern.jl +++ b/src/basekernels/matern.jl @@ -23,19 +23,23 @@ differentiable in the mean-square sense. 
See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref) """ -struct MaternKernel{Tν<:Real,M} <: SimpleKernel - ν::Vector{Tν} +struct MaternKernel{T<:Real,M} <: SimpleKernel + ν::T metric::M function MaternKernel(ν::Real, metric) @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0") - return new{typeof(ν),typeof(metric)}([ν], metric) + return new{typeof(ν),typeof(metric)}(ν, metric) end end MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, metric) -@functor MaternKernel +function ParameterHandling.flatten(::Type{T}, k::MaternKernel{S}) where {T<:Real,S<:Real} + metric = k.metric + unflatten_to_maternkernel(v::Vector{T}) = MaternKernel(S(exp(first(v))), metric) + return T[log(k.ν)], unflatten_to_maternkernel +end @inline kappa(k::MaternKernel, d::Real) = _matern(only(k.ν), d) @@ -80,6 +84,8 @@ end Matern32Kernel(; metric=Euclidean()) = Matern32Kernel(metric) +@noparams Matern32Kernel + kappa(::Matern32Kernel, d::Real) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d) metric(k::Matern32Kernel) = k.metric @@ -111,6 +117,8 @@ end Matern52Kernel(; metric=Euclidean()) = Matern52Kernel(metric) +@noparams Matern52Kernel + kappa(::Matern52Kernel, d::Real) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d) metric(k::Matern52Kernel) = k.metric diff --git a/src/basekernels/nn.jl b/src/basekernels/nn.jl index 08bf233ba..86946f66b 100644 --- a/src/basekernels/nn.jl +++ b/src/basekernels/nn.jl @@ -33,6 +33,8 @@ for inputs ``x, x' \\in \\mathbb{R}^d``.[^CW] """ struct NeuralNetworkKernel <: Kernel end +@noparams NeuralNetworkKernel + function (κ::NeuralNetworkKernel)(x, y) return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y)))) end diff --git a/src/basekernels/periodic.jl b/src/basekernels/periodic.jl index 2758d7f94..a0464f530 100644 --- a/src/basekernels/periodic.jl +++ b/src/basekernels/periodic.jl @@ -21,16 +21,25 @@ struct PeriodicKernel{T} <: SimpleKernel end end +""" + PeriodicKernel(dims::Int) + +Create a [`PeriodicKernel`](@ref) with parameter `r=ones(Float64, dims)`. +""" PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims) """ - PeriodicKernel([T=Float64, dims::Int=1]) + PeriodicKernel(T, dims::Int=1) Create a [`PeriodicKernel`](@ref) with parameter `r=ones(T, dims)`. """ -PeriodicKernel(T::DataType, dims::Int=1) = PeriodicKernel(; r=ones(T, dims)) +PeriodicKernel(::Type{T}, dims::Int=1) where {T} = PeriodicKernel(; r=ones(T, dims)) -@functor PeriodicKernel +function ParameterHandling.flatten(::Type{T}, k::PeriodicKernel{S}) where {T<:Real,S} + vec = T[log(ri) for ri in k.r] + unflatten_to_periodickernel(v::Vector{T}) = PeriodicKernel(; r=S[exp(vi) for vi in v]) + return vec, unflatten_to_periodickernel +end metric(κ::PeriodicKernel) = Sinus(κ.r) diff --git a/src/basekernels/piecewisepolynomial.jl b/src/basekernels/piecewisepolynomial.jl index 07b3638dd..39d8f7cf3 100644 --- a/src/basekernels/piecewisepolynomial.jl +++ b/src/basekernels/piecewisepolynomial.jl @@ -46,6 +46,8 @@ function PiecewisePolynomialKernel(; degree::Int=0, kwargs...) return PiecewisePolynomialKernel{degree}(; kwargs...) 
end +@noparams PiecewisePolynomialKernel + piecewise_polynomial_coefficients(::Val{0}, ::Int) = (1,) piecewise_polynomial_coefficients(::Val{1}, j::Int) = (1, j + 1) piecewise_polynomial_coefficients(::Val{2}, j::Int) = (1, j + 2, (j^2 + 4 * j)//3 + 1) diff --git a/src/basekernels/polynomial.jl b/src/basekernels/polynomial.jl index 0dbd2d52b..e97679384 100644 --- a/src/basekernels/polynomial.jl +++ b/src/basekernels/polynomial.jl @@ -13,42 +13,47 @@ k(x, x'; c) = x^\\top x' + c. See also: [`PolynomialKernel`](@ref) """ -struct LinearKernel{Tc<:Real} <: SimpleKernel - c::Vector{Tc} +struct LinearKernel{T<:Real} <: SimpleKernel + c::T function LinearKernel(c::Real) @check_args(LinearKernel, c, c >= zero(c), "c ≥ 0") - return new{typeof(c)}([c]) + return new{typeof(c)}(c) end end LinearKernel(; c::Real=0.0) = LinearKernel(c) -@functor LinearKernel +function ParameterHandling.flatten(::Type{T}, k::LinearKernel{S}) where {T<:Real,S<:Real} + function unflatten_to_linearkernel(v::Vector{T}) + return LinearKernel(S(exp(only(v)))) + end + return T[log(k.c)], unflatten_to_linearkernel +end __linear_kappa(c::Real, xᵀy::Real) = xᵀy + c -kappa(κ::LinearKernel, xᵀy::Real) = __linear_kappa(only(κ.c), xᵀy) +kappa(κ::LinearKernel, xᵀy::Real) = __linear_kappa(κ.c, xᵀy) metric(::LinearKernel) = DotProduct() function kernelmatrix(k::LinearKernel, x::AbstractVector, y::AbstractVector) - return __linear_kappa.(only(k.c), pairwise(metric(k), x, y)) + return __linear_kappa.(k.c, pairwise(metric(k), x, y)) end function kernelmatrix(k::LinearKernel, x::AbstractVector) - return __linear_kappa.(only(k.c), pairwise(metric(k), x)) + return __linear_kappa.(k.c, pairwise(metric(k), x)) end function kernelmatrix_diag(k::LinearKernel, x::AbstractVector, y::AbstractVector) - return __linear_kappa.(only(k.c), colwise(metric(k), x, y)) + return __linear_kappa.(k.c, colwise(metric(k), x, y)) end function kernelmatrix_diag(k::LinearKernel, x::AbstractVector) - return __linear_kappa.(only(k.c), colwise(metric(k), x)) + return __linear_kappa.(k.c, colwise(metric(k), x)) end -Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", only(κ.c), ")") +Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", κ.c, ")") """ PolynomialKernel(; degree::Int=2, c::Real=0.0) @@ -65,19 +70,27 @@ k(x, x'; c, \\nu) = (x^\\top x' + c)^\\nu. 
See also: [`LinearKernel`](@ref) """ -struct PolynomialKernel{Tc<:Real} <: SimpleKernel +struct PolynomialKernel{T<:Real} <: SimpleKernel degree::Int - c::Vector{Tc} + c::T - function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc} + function PolynomialKernel(degree::Int, c::Real) @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1") - @check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0") - return new{Tc}(degree, c) + @check_args(PolynomialKernel, c, c >= zero(c), "c ≥ 0") + return new{typeof(c)}(degree, c) end end -function PolynomialKernel(; degree::Int=2, c::Real=0.0) - return PolynomialKernel{typeof(c)}(degree, [c]) +PolynomialKernel(; degree::Int=2, c::Real=0.0) = PolynomialKernel(degree, c) + +function ParameterHandling.flatten( + ::Type{T}, k::PolynomialKernel{S} +) where {T<:Real,S<:Real} + degree = k.degree + function unflatten_to_polynomialkernel(v::Vector{T}) + return PolynomialKernel(degree, S(exp(only(v)))) + end + return T[log(k.c)], unflatten_to_polynomialkernel end # The degree of the polynomial kernel is a fixed discrete parameter @@ -92,26 +105,26 @@ end (κ::_PolynomialKappa)(c::Real, xᵀy::Real) = (xᵀy + c)^κ.degree -kappa(κ::PolynomialKernel, xᵀy::Real) = _PolynomialKappa(κ.degree)(only(κ.c), xᵀy) +kappa(κ::PolynomialKernel, xᵀy::Real) = _PolynomialKappa(κ.degree)(κ.c, xᵀy) metric(::PolynomialKernel) = DotProduct() function kernelmatrix(k::PolynomialKernel, x::AbstractVector, y::AbstractVector) - return _PolynomialKappa(k.degree).(only(k.c), pairwise(metric(k), x, y)) + return _PolynomialKappa(k.degree).(k.c, pairwise(metric(k), x, y)) end function kernelmatrix(k::PolynomialKernel, x::AbstractVector) - return _PolynomialKappa(k.degree).(only(k.c), pairwise(metric(k), x)) + return _PolynomialKappa(k.degree).(k.c, pairwise(metric(k), x)) end function kernelmatrix_diag(k::PolynomialKernel, x::AbstractVector, y::AbstractVector) - return _PolynomialKappa(k.degree).(only(k.c), colwise(metric(k), x, y)) + return _PolynomialKappa(k.degree).(k.c, colwise(metric(k), x, y)) end function kernelmatrix_diag(k::PolynomialKernel, x::AbstractVector) - return _PolynomialKappa(k.degree).(only(k.c), colwise(metric(k), x)) + return _PolynomialKappa(k.degree).(k.c, colwise(metric(k), x)) end function Base.show(io::IO, κ::PolynomialKernel) - return print(io, "Polynomial Kernel (c = ", only(κ.c), ", degree = ", κ.degree, ")") + return print(io, "Polynomial Kernel (c = ", κ.c, ", degree = ", κ.degree, ")") end diff --git a/src/basekernels/rational.jl b/src/basekernels/rational.jl index 835fe92ff..d84ac9941 100644 --- a/src/basekernels/rational.jl +++ b/src/basekernels/rational.jl @@ -15,13 +15,13 @@ The [`ExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\in See also: [`GammaRationalKernel`](@ref) """ -struct RationalKernel{Tα<:Real,M} <: SimpleKernel - α::Vector{Tα} +struct RationalKernel{T<:Real,M} <: SimpleKernel + α::T metric::M function RationalKernel(α::Real, metric) @check_args(RationalKernel, α, α > zero(α), "α > 0") - return new{typeof(α),typeof(metric)}([α], metric) + return new{typeof(α),typeof(metric)}(α, metric) end end @@ -29,36 +29,42 @@ function RationalKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) return RationalKernel(α, metric) end -@functor RationalKernel +function ParameterHandling.flatten(::Type{T}, k::RationalKernel{S}) where {T<:Real,S} + metric = k.metric + function unflatten_to_rationalkernel(v::Vector{T}) + return RationalKernel(S(exp(only(v))), metric) + end + return T[log(k.α)], 
unflatten_to_rationalkernel +end __rational_kappa(α::Real, d::Real) = (one(d) + d / α)^(-α) -kappa(κ::RationalKernel, d::Real) = __rational_kappa(only(κ.α), d) +kappa(κ::RationalKernel, d::Real) = __rational_kappa(κ.α, d) metric(k::RationalKernel) = k.metric # AD-performance optimisation. Is unit tested. function kernelmatrix(k::RationalKernel, x::AbstractVector, y::AbstractVector) - return __rational_kappa.(only(k.α), pairwise(metric(k), x, y)) + return __rational_kappa.(k.α, pairwise(metric(k), x, y)) end # AD-performance optimisation. Is unit tested. function kernelmatrix(k::RationalKernel, x::AbstractVector) - return __rational_kappa.(only(k.α), pairwise(metric(k), x)) + return __rational_kappa.(k.α, pairwise(metric(k), x)) end # AD-performance optimisation. Is unit tested. function kernelmatrix_diag(k::RationalKernel, x::AbstractVector, y::AbstractVector) - return __rational_kappa.(only(k.α), colwise(metric(k), x, y)) + return __rational_kappa.(k.α, colwise(metric(k), x, y)) end # AD-performance optimisation. Is unit tested. function kernelmatrix_diag(k::RationalKernel, x::AbstractVector) - return __rational_kappa.(only(k.α), colwise(metric(k), x)) + return __rational_kappa.(k.α, colwise(metric(k), x)) end function Base.show(io::IO, κ::RationalKernel) - return print(io, "Rational Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")") + return print(io, "Rational Kernel (α = ", κ.α, ", metric = ", κ.metric, ")") end """ @@ -79,14 +85,24 @@ The [`SqExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to \\ See also: [`GammaRationalKernel`](@ref) """ -struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel - α::Vector{Tα} +struct RationalQuadraticKernel{T<:Real,M} <: SimpleKernel + α::T metric::M function RationalQuadraticKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean()) @check_args(RationalQuadraticKernel, α, α > zero(α), "α > 0") - return new{typeof(α),typeof(metric)}([α], metric) + return new{typeof(α),typeof(metric)}(α, metric) + end +end + +function ParameterHandling.flatten( + ::Type{T}, k::RationalQuadraticKernel{S} +) where {T<:Real,S} + metric = k.metric + function unflatten_to_rationalquadratickernel(v::Vector{T}) + return RationalQuadraticKernel(; α=S(exp(only(v))), metric=metric) end + return T[log(k.α)], unflatten_to_rationalquadratickernel end const _RQ_Euclidean = RationalQuadraticKernel{<:Real,<:Euclidean} @@ -96,56 +112,54 @@ const _RQ_Euclidean = RationalQuadraticKernel{<:Real,<:Euclidean} __rq_kappa(α::Real, d::Real) = (one(d) + d^2 / (2 * α))^(-α) __rq_kappa_euclidean(α::Real, d²::Real) = (one(d²) + d² / (2 * α))^(-α) -kappa(κ::RationalQuadraticKernel, d::Real) = __rq_kappa(only(κ.α), d) -kappa(κ::_RQ_Euclidean, d²::Real) = __rq_kappa_euclidean(only(κ.α), d²) +kappa(κ::RationalQuadraticKernel, d::Real) = __rq_kappa(κ.α, d) +kappa(κ::_RQ_Euclidean, d²::Real) = __rq_kappa_euclidean(κ.α, d²) metric(k::RationalQuadraticKernel) = k.metric metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean() # AD-performance optimisation. Is unit tested. function kernelmatrix(k::RationalQuadraticKernel, x::AbstractVector, y::AbstractVector) - return __rq_kappa.(only(k.α), pairwise(metric(k), x, y)) + return __rq_kappa.(k.α, pairwise(metric(k), x, y)) end # AD-performance optimisation. Is unit tested. function kernelmatrix(k::RationalQuadraticKernel, x::AbstractVector) - return __rq_kappa.(only(k.α), pairwise(metric(k), x)) + return __rq_kappa.(k.α, pairwise(metric(k), x)) end # AD-performance optimisation. Is unit tested. 
function kernelmatrix_diag(k::RationalQuadraticKernel, x::AbstractVector, y::AbstractVector) - return __rq_kappa.(only(k.α), colwise(metric(k), x, y)) + return __rq_kappa.(k.α, colwise(metric(k), x, y)) end # AD-performance optimisation. Is unit tested. function kernelmatrix_diag(k::RationalQuadraticKernel, x::AbstractVector) - return __rq_kappa.(only(k.α), colwise(metric(k), x)) + return __rq_kappa.(k.α, colwise(metric(k), x)) end # AD-performance optimisation. Is unit tested. function kernelmatrix(k::_RQ_Euclidean, x::AbstractVector, y::AbstractVector) - return __rq_kappa_euclidean.(only(k.α), pairwise(SqEuclidean(), x, y)) + return __rq_kappa_euclidean.(k.α, pairwise(SqEuclidean(), x, y)) end # AD-performance optimisation. Is unit tested. function kernelmatrix(k::_RQ_Euclidean, x::AbstractVector) - return __rq_kappa_euclidean.(only(k.α), pairwise(SqEuclidean(), x)) + return __rq_kappa_euclidean.(k.α, pairwise(SqEuclidean(), x)) end # AD-performance optimisation. Is unit tested. function kernelmatrix_diag(k::_RQ_Euclidean, x::AbstractVector, y::AbstractVector) - return __rq_kappa_euclidean.(only(k.α), colwise(SqEuclidean(), x, y)) + return __rq_kappa_euclidean.(k.α, colwise(SqEuclidean(), x, y)) end # AD-performance optimisation. Is unit tested. function kernelmatrix_diag(k::_RQ_Euclidean, x::AbstractVector) - return __rq_kappa_euclidean.(only(k.α), colwise(SqEuclidean(), x)) + return __rq_kappa_euclidean.(k.α, colwise(SqEuclidean(), x)) end function Base.show(io::IO, κ::RationalQuadraticKernel) - return print( - io, "Rational Quadratic Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")" - ) + return print(io, "Rational Quadratic Kernel (α = ", κ.α, ", metric = ", κ.metric, ")") end """ @@ -167,8 +181,8 @@ The [`GammaExponentialKernel`](@ref) is recovered in the limit as ``\\alpha \\to See also: [`RationalKernel`](@ref), [`RationalQuadraticKernel`](@ref) """ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel - α::Vector{Tα} - γ::Vector{Tγ} + α::Tα + γ::Tγ metric::M function GammaRationalKernel(; @@ -176,47 +190,53 @@ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel ) @check_args(GammaRationalKernel, α, α > zero(α), "α > 0") @check_args(GammaRationalKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]") - return new{typeof(α),typeof(γ),typeof(metric)}([α], [γ], metric) + return new{typeof(α),typeof(γ),typeof(metric)}(α, γ, metric) end end -@functor GammaRationalKernel +function ParameterHandling.flatten( + ::Type{T}, k::GammaRationalKernel{Tα,Tγ} +) where {T<:Real,Tα,Tγ} + vec = T[log(k.α), logit(k.γ / 2)] + metric = k.metric + function unflatten_to_gammarationalkernel(v::Vector{T}) + length(v) == 2 || error("incorrect number of parameters") + logα, logitγ = v + α = Tα(exp(logα)) + γ = Tγ(2 * logistic(logitγ)) + return GammaRationalKernel(; α=α, γ=γ, metric=metric) + end + return vec, unflatten_to_gammarationalkernel +end __grk_kappa(α::Real, γ::Real, d::Real) = (one(d) + d^γ / α)^(-α) -kappa(κ::GammaRationalKernel, d::Real) = __grk_kappa(only(κ.α), only(κ.γ), d) +kappa(κ::GammaRationalKernel, d::Real) = __grk_kappa(κ.α, κ.γ, d) metric(k::GammaRationalKernel) = k.metric # AD-performance optimisation. Is unit tested. function kernelmatrix(k::GammaRationalKernel, x::AbstractVector, y::AbstractVector) - return __grk_kappa.(only(k.α), only(k.γ), pairwise(metric(k), x, y)) + return __grk_kappa.(k.α, k.γ, pairwise(metric(k), x, y)) end # AD-performance optimisation. Is unit tested. 
function kernelmatrix(k::GammaRationalKernel, x::AbstractVector) - return __grk_kappa.(only(k.α), only(k.γ), pairwise(metric(k), x)) + return __grk_kappa.(k.α, k.γ, pairwise(metric(k), x)) end # AD-performance optimisation. Is unit tested. function kernelmatrix_diag(k::GammaRationalKernel, x::AbstractVector, y::AbstractVector) - return __grk_kappa.(only(k.α), only(k.γ), colwise(metric(k), x, y)) + return __grk_kappa.(k.α, k.γ, colwise(metric(k), x, y)) end # AD-performance optimisation. Is unit tested. function kernelmatrix_diag(k::GammaRationalKernel, x::AbstractVector) - return __grk_kappa.(only(k.α), only(k.γ), colwise(metric(k), x)) + return __grk_kappa.(k.α, k.γ, colwise(metric(k), x)) end function Base.show(io::IO, κ::GammaRationalKernel) return print( - io, - "Gamma Rational Kernel (α = ", - only(κ.α), - ", γ = ", - only(κ.γ), - ", metric = ", - κ.metric, - ")", + io, "Gamma Rational Kernel (α = ", κ.α, ", γ = ", κ.γ, ", metric = ", κ.metric, ")" ) end diff --git a/src/basekernels/wiener.jl b/src/basekernels/wiener.jl index 14d330850..741a5960f 100644 --- a/src/basekernels/wiener.jl +++ b/src/basekernels/wiener.jl @@ -52,6 +52,8 @@ function WienerKernel(; i::Integer=0) return WienerKernel{i}() end +@noparams WienerKernel + function (::WienerKernel{0})(x, y) X = sqrt(sum(abs2, x)) Y = sqrt(sum(abs2, y)) diff --git a/src/kernels/gibbskernel.jl b/src/kernels/gibbskernel.jl index 46e14995d..d1bd9f4ec 100644 --- a/src/kernels/gibbskernel.jl +++ b/src/kernels/gibbskernel.jl @@ -36,6 +36,16 @@ end GibbsKernel(; lengthscale) = GibbsKernel(lengthscale) +@functor GibbsKernel + +# or just `@noparams GibbsKernel` - it would be safer since there is no +# default fallback for `flatten` +function ParameterHandling.flatten(::Type{T}, k::GibbsKernel) where {T<:Real} + vec, unflatten_to_lengthscale = flatten(T, k.lengthscale) + unflatten_to_gibbskernel(v::Vector{T}) = GibbsKernel(unflatten_to_lengthscale(v)) + return vec, unflatten_to_gibbskernel +end + function (k::GibbsKernel)(x, y) lengthscale = k.lengthscale lx = lengthscale(x) diff --git a/src/kernels/kernelproduct.jl b/src/kernels/kernelproduct.jl index 4a0bf932f..f4447fad9 100644 --- a/src/kernels/kernelproduct.jl +++ b/src/kernels/kernelproduct.jl @@ -41,6 +41,25 @@ end @functor KernelProduct +function ParameterHandling.flatten(::Type{T}, k::KernelProduct) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + function unflatten_to_kernelproduct(v::Vector{T}) + length(v) == length(flat_vecs) || error("incorrect number of parameters") + offset = Ref(0) + kernels = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return KernelProduct(kernels) + end + return flat_vecs, unflatten_to_kernelproduct +end + Base.length(k::KernelProduct) = length(k.kernels) function _hadamard(f, ks::Tuple, args...) 
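The `flatten` implementation above for `KernelProduct` (and the near-identical ones for `KernelSum`, `KernelTensorProduct`, and `ChainTransform` below) concatenates the component parameter vectors and splits them again by offset on reconstruction. As a minimal round-trip sketch of that pattern (not part of the diff; assumes this patch is loaded):

```julia
using KernelFunctions, ParameterHandling

k = LinearKernel(; c=1.0) * SEKernel()       # KernelProduct; only `c` is trainable
v, unflatten = ParameterHandling.flatten(k)  # v == [log(1.0)]; SEKernel contributes no parameters

k2 = unflatten([log(2.5)])                   # same structure, new parameter value
@assert k2.kernels[1].c ≈ 2.5
```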
diff --git a/src/kernels/kernelsum.jl b/src/kernels/kernelsum.jl index 77709502c..d127610d5 100644 --- a/src/kernels/kernelsum.jl +++ b/src/kernels/kernelsum.jl @@ -41,6 +41,26 @@ end @functor KernelSum +function ParameterHandling.flatten(::Type{T}, k::KernelSum) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + n = length(flat_vecs) + function unflatten_to_kernelsum(v::Vector{T}) + length(v) == n || error("incorrect number of parameters") + offset = Ref(0) + kernels = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return KernelSum(kernels) + end + return flat_vecs, unflatten_to_kernelsum +end + Base.length(k::KernelSum) = length(k.kernels) _sum(f, ks::Tuple, args...) = f(first(ks), args...) + _sum(f, Base.tail(ks), args...) diff --git a/src/kernels/kerneltensorproduct.jl b/src/kernels/kerneltensorproduct.jl index ea0044409..ab3f3486a 100644 --- a/src/kernels/kerneltensorproduct.jl +++ b/src/kernels/kerneltensorproduct.jl @@ -47,6 +47,25 @@ end @functor KernelTensorProduct +function ParameterHandling.flatten(::Type{T}, k::KernelTensorProduct) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + function unflatten_to_kerneltensorproduct(v::Vector{T}) + length(v) == length(flat_vecs) || error("incorrect number of parameters") + offset = Ref(0) + kernels = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return KernelTensorProduct(kernels) + end + return flat_vecs, unflatten_to_kerneltensorproduct +end + Base.length(kernel::KernelTensorProduct) = length(kernel.kernels) function (kernel::KernelTensorProduct)(x, y) diff --git a/src/kernels/normalizedkernel.jl b/src/kernels/normalizedkernel.jl index e9fc570f2..19407a3cf 100644 --- a/src/kernels/normalizedkernel.jl +++ b/src/kernels/normalizedkernel.jl @@ -17,6 +17,12 @@ end @functor NormalizedKernel +function ParameterHandling.flatten(::Type{T}, k::NormalizedKernel) where {T<:Real} + vec, back = flatten(T, k.kernel) + unflatten_to_normalizedkernel(v::Vector{T}) = NormalizedKernel(back(v)) + return vec, unflatten_to_normalizedkernel +end + (κ::NormalizedKernel)(x, y) = κ.kernel(x, y) / sqrt(κ.kernel(x, x) * κ.kernel(y, y)) function kernelmatrix(κ::NormalizedKernel, x::AbstractVector, y::AbstractVector) diff --git a/src/kernels/parameterkernel.jl b/src/kernels/parameterkernel.jl new file mode 100644 index 000000000..5d14d16c8 --- /dev/null +++ b/src/kernels/parameterkernel.jl @@ -0,0 +1,152 @@ +""" + ParameterKernel(params, kernel) + +Kernel with parameters `params` that can be instantiated by calling `kernel(params)`. + +This kernel is particularly useful if you want to optimize a vector of, +usually unconstrained, kernel parameters `params` with e.g. +[Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl) or +[Flux.jl](https://github.com/FluxML/Flux.jl). + +# Examples + +There are two different approaches for obtaining the parameters `params` and the function +`kernel` from which a `ParameterKernel` can be constructed. 
+ +## Extracting parameters from an existing kernel + +You can extract the parameters `params` and the function `kernel` from an existing kernel +`k` with `ParameterHandling.flatten`: +```jldoctest parameterkernel1 +julia> k = 2.0 * (RationalQuadraticKernel(; α=1.0) + ConstantKernel(; c=2.5)); + +julia> params, kernel = ParameterHandling.flatten(k); +``` + +Here `params` is a vector of the three parameters of kernel `k`. In this example, all these +parameters must be positive (otherwise `k` would not be a positive-definite kernel). To +simplify unconstrained optimization with e.g. Optim.jl or Flux.jl, +`ParameterHandling.flatten` automatically transforms the parameters to unconstrained values: +```jldoctest parameterkernel1 +julia> params ≈ map(log, [1.0, 2.5, 2.0]) +true +``` + +Kernel `k` can be reconstructed with the `kernel` function: +```jldoctest parameterkernel1 +julia> kernel(params) +Sum of 2 kernels: + Rational Quadratic Kernel (α = 1.0, metric = Distances.Euclidean(0.0)) + Constant Kernel (c = 2.5) + - σ² = 2.0 +``` + +As expected, different parameter values yield a kernel of the same structure with different +parameters: +```jldoctest parameterkernel1 +julia> kernel([log(0.25), log(0.5), log(2.0)]) +Sum of 2 kernels: + Rational Quadratic Kernel (α = 0.25, metric = Distances.Euclidean(0.0)) + Constant Kernel (c = 0.5) + - σ² = 2.0 +``` + +## Defining a function that constructs the kernel + +Instead of extracting parameters and a reconstruction function from an existing kernel you +can explicitly define a function that constructs the kernel of interest and a set of +parameters. + +```jldoctest parameterkernel2 +julia> using LogExpFunctions + +julia> function kernel(params) + length(params) == 1 || throw(ArgumentError("incorrect number of parameters")) + p = first(params) + return 2 * (RationalQuadraticKernel(; α=log1pexp(p)) + ConstantKernel(; c=exp(p))) + end; +``` + +With the function `kernel` kernels of the same structure as in the example above can be +constructed: +```jldoctest parameterkernel2 +julia> kernel([log(0.5)]) +Sum of 2 kernels: + Rational Quadratic Kernel (α = 0.4054651081081644, metric = Distances.Euclidean(0.0)) + Constant Kernel (c = 0.5) + - σ² = 2 +``` + +This example shows that defining `kernel` manually has some advantages over using +`ParameterHandling.flatten`: +- Kernel parameters can be fixed (scale parameter is always set to `2` in this example) +- Kernel parameters can be transformed from unconstrained to constrained space with + non-default mappings (shape parameter `α` is transformed with `log1pexp`) +- Kernel parameters can be linked (parameters `α` and `c` are computed from a single + parameter `p`) + +See also: [ParameterHandling.jl](https://github.com/invenia/ParameterHandling.jl) +""" +struct ParameterKernel{P,K} <: Kernel + params::P + kernel::K +end + +# convenience function +""" + ParameterKernel(kernel::Kernel) + +Construct a `ParameterKernel` from an existing `kernel`. + +The constructor is a short-hand for `ParameterKernel(ParameterHandling.flatten(kernel)...)`. +""" +ParameterKernel(kernel::Kernel) = ParameterKernel(flatten(kernel)...) 
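A brief usage sketch for the convenience constructor above (not part of the new file; assumes the rest of this patch is loaded, and `x`, `y` are placeholder inputs):

```julia
using KernelFunctions, ParameterHandling

k = (2.0 * SEKernel()) ∘ ScaleTransform(1.5)
pk = ParameterKernel(k)     # stores the flat parameter vector and the reconstruction function

x, y = rand(3), rand(3)
@assert pk(x, y) ≈ k(x, y)  # evaluation reconstructs the kernel via `kernel(params)`

# Flattening a `ParameterKernel` exposes exactly its own parameter vector:
v, unflatten = ParameterHandling.flatten(pk)
@assert v == pk.params
```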
+ +Functors.@functor ParameterKernel (params,) + +function ParameterHandling.flatten(::Type{T}, kernel::ParameterKernel) where {T<:Real} + params_vec, unflatten_to_params = flatten(T, kernel.params) + k = kernel.kernel + function unflatten_to_parameterkernel(v::Vector{T}) + return ParameterKernel(unflatten_to_params(v), k) + end + return params_vec, unflatten_to_parameterkernel +end + +(k::ParameterKernel)(x, y) = k.kernel(k.params)(x, y) + +function kernelmatrix(k::ParameterKernel, x::AbstractVector) + return kernelmatrix(k.kernel(k.params), x) +end + +function kernelmatrix(k::ParameterKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix(k.kernel(k.params), x, y) +end + +function kernelmatrix!(K::AbstractMatrix, k::ParameterKernel, x::AbstractVector) + return kernelmatrix!(K, k.kernel(k.params), x) +end + +function kernelmatrix!( + K::AbstractMatrix, k::ParameterKernel, x::AbstractVector, y::AbstractVector +) + return kernelmatrix!(K, k.kernel(k.params), x, y) +end + +function kernelmatrix_diag(k::ParameterKernel, x::AbstractVector) + return kernelmatrix_diag(k.kernel(k.params), x) +end + +function kernelmatrix_diag(k::ParameterKernel, x::AbstractVector, y::AbstractVector) + return kernelmatrix_diag(k.kernel(k.params), x, y) +end + +function kernelmatrix_diag!(K::AbstractVector, k::ParameterKernel, x::AbstractVector) + return kernelmatrix_diag!(K, k.kernel(k.params), x) +end + +function kernelmatrix_diag!( + K::AbstractVector, k::ParameterKernel, x::AbstractVector, y::AbstractVector +) + return kernelmatrix_diag!(K, k.kernel(k.params), x, y) +end diff --git a/src/kernels/scaledkernel.jl b/src/kernels/scaledkernel.jl index 0a17943ac..79fb612b9 100644 --- a/src/kernels/scaledkernel.jl +++ b/src/kernels/scaledkernel.jl @@ -13,32 +13,47 @@ multiplication with variance ``\\sigma^2 > 0`` is defined as """ struct ScaledKernel{Tk<:Kernel,Tσ²<:Real} <: Kernel kernel::Tk - σ²::Vector{Tσ²} -end + σ²::Tσ² -function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real} - @check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0") - return ScaledKernel{Tk,Tσ²}(kernel, [σ²]) + function ScaledKernel(kernel::Kernel, σ²::Real) + @check_args(ScaledKernel, σ², σ² > zero(σ²), "σ² > 0") + return new{typeof(kernel),typeof(σ²)}(kernel, σ²) + end end -@functor ScaledKernel +ScaledKernel(kernel::Kernel) = ScaledKernel(kernel, 1.0) + +# σ² is a positive parameter (and a scalar!) 
but Functors does not handle +# parameter constraints +@functor ScaledKernel (kernel,) + +function ParameterHandling.flatten( + ::Type{T}, k::ScaledKernel{<:Kernel,S} +) where {T<:Real,S<:Real} + kernel_vec, kernel_back = flatten(T, k.kernel) + function unflatten_to_scaledkernel(v::Vector{T}) + kernel = kernel_back(v[1:(end - 1)]) + return ScaledKernel(kernel, S(exp(last(v)))) + end + return vcat(kernel_vec, T(log(k.σ²))), unflatten_to_scaledkernel +end -(k::ScaledKernel)(x, y) = only(k.σ²) * k.kernel(x, y) +(k::ScaledKernel)(x, y) = k.σ² * k.kernel(x, y) function kernelmatrix(κ::ScaledKernel, x::AbstractVector, y::AbstractVector) - return only(κ.σ²) * kernelmatrix(κ.kernel, x, y) + return κ.σ² * kernelmatrix(κ.kernel, x, y) end function kernelmatrix(κ::ScaledKernel, x::AbstractVector) - return only(κ.σ²) * kernelmatrix(κ.kernel, x) + return κ.σ² * kernelmatrix(κ.kernel, x) end function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector) - return only(κ.σ²) * kernelmatrix_diag(κ.kernel, x) + return κ.σ² * kernelmatrix_diag(κ.kernel, x) end function kernelmatrix_diag(κ::ScaledKernel, x::AbstractVector, y::AbstractVector) - return only(κ.σ²) * kernelmatrix_diag(κ.kernel, x, y) + return κ.σ² * kernelmatrix_diag(κ.kernel, x, y) end function kernelmatrix!( @@ -75,5 +90,9 @@ Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0) function printshifted(io::IO, κ::ScaledKernel, shift::Int) printshifted(io, κ.kernel, shift) - return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(only(κ.σ²))") + print(io, "\n") + for _ in 1:(shift + 1) + print(io, "\t") + end + return print(io, "- σ² = ", κ.σ²) end diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 88e719ef1..0de2b4f2b 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -16,6 +16,21 @@ end @functor TransformedKernel +function ParameterHandling.flatten(::Type{T}, k::TransformedKernel) where {T<:Real} + kernel_vec, kernel_back = flatten(T, k.kernel) + transform_vec, transform_back = flatten(T, k.transform) + v = vcat(kernel_vec, transform_vec) + n = length(v) + nkernel = length(kernel_vec) + function unflatten_to_transformedkernel(v::Vector{T}) + length(v) == n || error("incorrect number of parameters") + kernel = kernel_back(v[1:nkernel]) + transform = transform_back(v[(nkernel + 1):end]) + return TransformedKernel(kernel, transform) + end + return v, unflatten_to_transformedkernel +end + (k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y)) # Optimizations for scale transforms of simple kernels to save allocations: diff --git a/src/mokernels/independent.jl b/src/mokernels/independent.jl index 1f7811b14..10e8197fd 100644 --- a/src/mokernels/independent.jl +++ b/src/mokernels/independent.jl @@ -23,6 +23,16 @@ struct IndependentMOKernel{Tkernel<:Kernel} <: MOKernel kernel::Tkernel end +@functor IndependentMOKernel + +function ParameterHandling.flatten(::Type{T}, k::IndependentMOKernel) where {T<:Real} + vec, unflatten_to_kernel = flatten(T, k.kernel) + function unflatten_to_independentmokernel(v::Vector{T}) + return IndependentMOKernel(unflatten_to_kernel(v)) + end + return vec, unflatten_to_independentmokernel +end + function (κ::IndependentMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int}) return κ.kernel(x, y) * (px == py) end diff --git a/src/mokernels/intrinsiccoregion.jl b/src/mokernels/intrinsiccoregion.jl index 0a940796b..223485422 100644 --- a/src/mokernels/intrinsiccoregion.jl +++ b/src/mokernels/intrinsiccoregion.jl @@ -38,6 +38,22 @@ 
function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix)
     return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B)
 end
 
+@functor IntrinsicCoregionMOKernel (kernel,)
+
+function ParameterHandling.flatten(::Type{T}, k::IntrinsicCoregionMOKernel) where {T<:Real}
+    kernel_vec, unflatten_to_kernel = flatten(T, k.kernel)
+    B_vec, unflatten_to_B = value_flatten(T, positive_definite(k.B))
+    nkernel = length(kernel_vec)
+    ntotal = nkernel + length(B_vec)
+    function unflatten_to_intrinsiccoregionkernel(v::Vector{T})
+        length(v) == ntotal || error("incorrect number of parameters")
+        kernel = unflatten_to_kernel(v[1:nkernel])
+        B = unflatten_to_B(v[(nkernel + 1):end])
+        return IntrinsicCoregionMOKernel(kernel, B)
+    end
+    return vcat(kernel_vec, B_vec), unflatten_to_intrinsiccoregionkernel
+end
+
 function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
     return k.B[px, py] * k.kernel(x, y)
 end
diff --git a/src/mokernels/lmm.jl b/src/mokernels/lmm.jl
index 0045d520e..8e37114e7 100644
--- a/src/mokernels/lmm.jl
+++ b/src/mokernels/lmm.jl
@@ -31,6 +31,32 @@ function LinearMixingModelKernel(k::Kernel, H::AbstractMatrix)
     return LinearMixingModelKernel(Fill(k, size(H, 1)), H)
 end
 
+@functor LinearMixingModelKernel
+
+function ParameterHandling.flatten(::Type{T}, k::LinearMixingModelKernel) where {T<:Real}
+    kernel_vecs_and_backs = map(Base.Fix1(flatten, T), k.kernels)
+    kernel_vecs = map(first, kernel_vecs_and_backs)
+    length_kernel_vecs = map(length, kernel_vecs)
+    kernel_backs = map(last, kernel_vecs_and_backs)
+    H_vec, H_back = flatten(T, k.H)
+    flat_kernel_vecs = reduce(vcat, kernel_vecs)
+    nkernel = length(flat_kernel_vecs)
+    flat_vecs = vcat(flat_kernel_vecs, H_vec)
+    n = length(flat_vecs)
+    function unflatten_to_linearmixingmodelkernel(v::Vector{T})
+        length(v) == n || error("incorrect number of parameters")
+        offset = Ref(0)
+        kernels = map(kernel_backs, length_kernel_vecs) do back, length_vec
+            oldoffset = offset[]
+            newoffset = offset[] = oldoffset + length_vec
+            return back(v[(oldoffset + 1):newoffset])
+        end
+        H = H_back(v[(nkernel + 1):end])
+        return LinearMixingModelKernel(kernels, H)
+    end
+    return flat_vecs, unflatten_to_linearmixingmodelkernel
+end
+
 function (κ::LinearMixingModelKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
     (px > size(κ.H, 2) || py > size(κ.H, 2) || px < 1 || py < 1) &&
         error("`px` and `py` must be within the range of the number of outputs")
diff --git a/src/test_utils.jl b/src/test_utils.jl
new file mode 100644
index 000000000..344994588
--- /dev/null
+++ b/src/test_utils.jl
@@ -0,0 +1,178 @@
+module TestUtils
+
+const __ATOL = 1e-9
+const __RTOL = 1e-9
+
+using Distances
+using LinearAlgebra
+using KernelFunctions
+using ParameterHandling
+using Random
+using Test
+
+"""
+    test_interface(
+        k::Kernel,
+        x0::AbstractVector,
+        x1::AbstractVector,
+        x2::AbstractVector;
+        atol=__ATOL,
+        rtol=__RTOL,
+    )
+
+Run various consistency checks on `k` at the inputs `x0`, `x1`, and `x2`.
+`x0` and `x1` should be of the same length with different values, while `x0` and `x2` should
+be of different lengths.
+
+    test_interface([rng::AbstractRNG], k::Kernel, T::Type{<:AbstractVector}; atol=__ATOL)
+
+`test_interface` offers certain types of test data generation to make running these tests
+require less code for common input types. For example, `Vector{<:Real}`, `ColVecs{<:Real}`,
+and `RowVecs{<:Real}` are supported. For other input vector types, please provide the data
+manually.
+""" +function test_interface( + k::Kernel, + x0::AbstractVector, + x1::AbstractVector, + x2::AbstractVector; + atol=__ATOL, + rtol=__RTOL, +) + # Ensure that we have the required inputs. + @assert length(x0) == length(x1) + @assert length(x0) ≠ length(x2) + + # Check that kernelmatrix_diag basically works. + @test kernelmatrix_diag(k, x0, x1) isa AbstractVector + @test length(kernelmatrix_diag(k, x0, x1)) == length(x0) + + # Check that pairwise basically works. + @test kernelmatrix(k, x0, x2) isa AbstractMatrix + @test size(kernelmatrix(k, x0, x2)) == (length(x0), length(x2)) + + # Check that elementwise is consistent with pairwise. + @test kernelmatrix_diag(k, x0, x1) ≈ diag(kernelmatrix(k, x0, x1)) atol = atol + + # Check additional binary elementwise properties for kernels. + @test kernelmatrix_diag(k, x0, x1) ≈ kernelmatrix_diag(k, x1, x0) + @test kernelmatrix(k, x0, x2) ≈ kernelmatrix(k, x2, x0)' atol = atol + + # Check that unary elementwise basically works. + @test kernelmatrix_diag(k, x0) isa AbstractVector + @test length(kernelmatrix_diag(k, x0)) == length(x0) + + # Check that unary pairwise basically works. + @test kernelmatrix(k, x0) isa AbstractMatrix + @test size(kernelmatrix(k, x0)) == (length(x0), length(x0)) + @test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0)' atol = atol + + # Check that unary elementwise is consistent with unary pairwise. + @test kernelmatrix_diag(k, x0) ≈ diag(kernelmatrix(k, x0)) atol = atol + + # Check that unary pairwise produces a positive definite matrix (approximately). + @test eigmin(Matrix(kernelmatrix(k, x0))) > -atol + + # Check that unary elementwise / pairwise are consistent with the binary versions. + @test kernelmatrix_diag(k, x0) ≈ kernelmatrix_diag(k, x0, x0) atol = atol rtol = rtol + @test kernelmatrix(k, x0) ≈ kernelmatrix(k, x0, x0) atol = atol rtol = rtol + + # Check that basic kernel evaluation succeeds and is consistent with `kernelmatrix`. + @test k(first(x0), first(x1)) isa Real + @test kernelmatrix(k, x0, x2) ≈ [k(xl, xr) for xl in x0, xr in x2] + + tmp = Matrix{Float64}(undef, length(x0), length(x2)) + @test kernelmatrix!(tmp, k, x0, x2) ≈ kernelmatrix(k, x0, x2) + + tmp_square = Matrix{Float64}(undef, length(x0), length(x0)) + @test kernelmatrix!(tmp_square, k, x0) ≈ kernelmatrix(k, x0) + + tmp_diag = Vector{Float64}(undef, length(x0)) + @test kernelmatrix_diag!(tmp_diag, k, x0) ≈ kernelmatrix_diag(k, x0) + @test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1) + + # Check flatten/unflatten + ParameterHandling.TestUtils.test_flatten_interface(k) + + return nothing +end + +function test_interface( + rng::AbstractRNG, k::Kernel, ::Type{Vector{T}}; kwargs... +) where {T<:Real} + return test_interface( + k, randn(rng, T, 1001), randn(rng, T, 1001), randn(rng, T, 1000); kwargs... + ) +end + +function test_interface( + rng::AbstractRNG, k::MOKernel, ::Type{Vector{Tuple{T,Int}}}; dim_out=1, kwargs... +) where {T<:Real} + return test_interface( + k, + [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:51], + [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:51], + [(randn(rng, T), rand(rng, 1:dim_out)) for i in 1:50]; + kwargs..., + ) +end + +function test_interface( + rng::AbstractRNG, k::Kernel, ::Type{<:ColVecs{T}}; dim_in=2, kwargs... 
+) where {T<:Real} + return test_interface( + k, + ColVecs(randn(rng, T, dim_in, 1001)), + ColVecs(randn(rng, T, dim_in, 1001)), + ColVecs(randn(rng, T, dim_in, 1000)); + kwargs..., + ) +end + +function test_interface( + rng::AbstractRNG, k::Kernel, ::Type{<:RowVecs{T}}; dim_in=2, kwargs... +) where {T<:Real} + return test_interface( + k, + RowVecs(randn(rng, T, 1001, dim_in)), + RowVecs(randn(rng, T, 1001, dim_in)), + RowVecs(randn(rng, T, 1000, dim_in)); + kwargs..., + ) +end + +function test_interface( + rng::AbstractRNG, k::Kernel, ::Type{<:Vector{Vector{T}}}; dim_in=2, kwargs... +) where {T<:Real} + return test_interface( + k, + [randn(rng, T, dim_in) for _ in 1:1001], + [randn(rng, T, dim_in) for _ in 1:1001], + [randn(rng, T, dim_in) for _ in 1:1000]; + kwargs..., + ) +end + +function test_interface(k::Kernel, T::Type{<:AbstractVector}; kwargs...) + return test_interface(Random.GLOBAL_RNG, k, T; kwargs...) +end + +function test_interface(rng::AbstractRNG, k::Kernel, T::Type{<:Real}; kwargs...) + @testset "Vector{$T}" begin + test_interface(rng, k, Vector{T}; kwargs...) + end + @testset "ColVecs{$T}" begin + test_interface(rng, k, ColVecs{T}; kwargs...) + end + @testset "RowVecs{$T}" begin + test_interface(rng, k, RowVecs{T}; kwargs...) + end + @testset "Vector{Vector{T}}" begin + test_interface(rng, k, Vector{Vector{T}}; kwargs...) + end +end + +function test_interface(k::Kernel, T::Type{<:Real}=Float64; kwargs...) + return test_interface(Random.GLOBAL_RNG, k, T; kwargs...) +end + +end # module diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index 4cb40fe78..3e8aa2e12 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -23,7 +23,10 @@ Create an [`ARDTransform`](@ref) with vector `fill(s, dims)`. 
""" ARDTransform(s::Real, dims::Integer) = ARDTransform(fill(s, dims)) -@functor ARDTransform +function ParameterHandling.flatten(::Type{T}, t::ARDTransform{S}) where {T<:Real,S} + unflatten_to_ardtransform(v::Vector{T}) = ARDTransform(convert(S, map(exp, v))) + return convert(Vector, map(T ∘ log, t.v)), unflatten_to_ardtransform +end function set!(t::ARDTransform{<:AbstractVector{T}}, ρ::AbstractVector{T}) where {T<:Real} @assert length(ρ) == dim(t) "Trying to set a vector of size $(length(ρ)) to ARDTransform of dimension $(dim(t))" diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index 208b1f689..7060653e3 100644 --- a/src/transform/chaintransform.jl +++ b/src/transform/chaintransform.jl @@ -25,6 +25,25 @@ end @functor ChainTransform +function ParameterHandling.flatten(::Type{T}, t::ChainTransform) where {T<:Real} + vecs_and_backs = map(Base.Fix1(flatten, T), t.transforms) + vecs = map(first, vecs_and_backs) + length_vecs = map(length, vecs) + backs = map(last, vecs_and_backs) + flat_vecs = reduce(vcat, vecs) + function unflatten_to_chaintransform(v::Vector{T}) + length(v) == length(flat_vecs) || error("incorrect number of parameters") + offset = Ref(0) + transforms = map(backs, length_vecs) do back, length_vec + oldoffset = offset[] + newoffset = offset[] = oldoffset + length_vec + return back(v[(oldoffset + 1):newoffset]) + end + return ChainTransform(transforms) + end + return flat_vecs, unflatten_to_chaintransform +end + Base.length(t::ChainTransform) = length(t.transforms) # Constructor to create a chain transform with an array of parameters diff --git a/src/transform/lineartransform.jl b/src/transform/lineartransform.jl index b61ba6a94..620f99c02 100644 --- a/src/transform/lineartransform.jl +++ b/src/transform/lineartransform.jl @@ -18,7 +18,11 @@ struct LinearTransform{T<:AbstractMatrix{<:Real}} <: Transform A::T end -@functor LinearTransform +function ParameterHandling.flatten(::Type{T}, t::LinearTransform) where {T<:Real} + vec, back = flatten(T, t.A) + unflatten_to_lineartransform(v::Vector{T}) = LinearTransform(back(v)) + return vec, unflatten_to_lineartransform +end function set!(t::LinearTransform{<:AbstractMatrix{T}}, A::AbstractMatrix{T}) where {T<:Real} size(t.A) == size(A) || error( diff --git a/src/transform/periodic_transform.jl b/src/transform/periodic_transform.jl index 098262309..115625a61 100644 --- a/src/transform/periodic_transform.jl +++ b/src/transform/periodic_transform.jl @@ -15,26 +15,26 @@ julia> t(x) == [sinpi(2 * f * x), cospi(2 * f * x)] true ``` """ -struct PeriodicTransform{Tf<:AbstractVector{<:Real}} <: Transform - f::Tf +struct PeriodicTransform{T<:Real} <: Transform + f::T end -@functor PeriodicTransform - -PeriodicTransform(f::Real) = PeriodicTransform([f]) +function ParameterHandling.flatten(::Type{T}, t::PeriodicTransform) where {T<:Real} + f = t.f + unflatten_to_periodictransform(v::Vector{T}) = PeriodicTransform(oftype(f, only(v))) + return T[f], unflatten_to_periodictransform +end dim(t::PeriodicTransform) = 2 -(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)] +(t::PeriodicTransform)(x::Real) = [sinpi(2 * t.f * x), cospi(2 * t.f * x)] function _map(t::PeriodicTransform, x::AbstractVector{<:Real}) - return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x))) + return RowVecs(hcat(sinpi.((2 * t.f) .* x), cospi.((2 * t.f) .* x))) end -function Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform) - return isequal(only(t1.f), only(t2.f)) -end 
+Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform) = isequal(t1.f, t2.f) function Base.show(io::IO, t::PeriodicTransform) - return print(io, "Periodic Transform with frequency $(only(t.f))") + return print(io, "Periodic Transform with frequency ", t.f) end diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index 18923fcc4..3d013a67e 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -13,26 +13,30 @@ true ``` """ struct ScaleTransform{T<:Real} <: Transform - s::Vector{T} -end + s::T -function ScaleTransform(s::T=1.0) where {T<:Real} - return ScaleTransform{T}([s]) + function ScaleTransform(s::Real) + @check_args(ScaleTransform, s, s > zero(s), "s > 0") + return new{typeof(s)}(s) + end end -@functor ScaleTransform +ScaleTransform() = ScaleTransform(1.0) -set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ] +function ParameterHandling.flatten(::Type{T}, t::ScaleTransform{S}) where {T<:Real,S<:Real} + unflatten_to_scaletransform(v::Vector{T}) = ScaleTransform(S(exp(only(v)))) + return T[log(t.s)], unflatten_to_scaletransform +end -(t::ScaleTransform)(x) = only(t.s) * x +(t::ScaleTransform)(x) = t.s * x -_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x -_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X) -_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X) +_map(t::ScaleTransform, x::AbstractVector{<:Real}) = t.s .* x +_map(t::ScaleTransform, x::ColVecs) = ColVecs(t.s .* x.X) +_map(t::ScaleTransform, x::RowVecs) = RowVecs(t.s .* x.X) -Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s)) +Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(t.s, t2.s) -Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", only(t.s), ")") +Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", t.s, ")") # Helpers diff --git a/src/transform/transform.jl b/src/transform/transform.jl index 795a3498f..82b378e14 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -31,6 +31,8 @@ Transformation that returns exactly the input. """ struct IdentityTransform <: Transform end +@noparams IdentityTransform + (t::IdentityTransform)(x) = x # More efficient implementation than `map(IdentityTransform(), x)` diff --git a/src/utils.jl b/src/utils.jl index 29087259b..ff66ff8d4 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -241,3 +241,20 @@ end function validate_inplace_dims(K::AbstractVecOrMat, x::AbstractVector) return validate_inplace_dims(K, x, x) end + +# TODO: move to ParameterHandling? +""" + @noparams T + +Define `ParameterHandling.flatten` for a type `T` without parameters. +""" +macro noparams(T) + return quote + Base.@__doc__ function ParameterHandling.flatten( + ::Type{S}, x::$(esc(T)) + ) where {S<:Real} + unflatten(::Vector{S}) = x + return S[], unflatten + end + end +end diff --git a/test/basekernels/constant.jl b/test/basekernels/constant.jl index 435a1678a..c899e496b 100644 --- a/test/basekernels/constant.jl +++ b/test/basekernels/constant.jl @@ -9,6 +9,7 @@ # Standardised tests. TestUtils.test_interface(k, Float64) TestUtils.test_interface(k, Vector{String}) + test_params(k, (Float64[],)) test_ADs(ZeroKernel) test_interface_ad_perf(_ -> k, nothing, StableRNG(123456)) end @@ -24,6 +25,7 @@ # Standardised tests. 
TestUtils.test_interface(k, Float64) TestUtils.test_interface(k, Vector{String}) + test_params(k, (Float64[],)) test_ADs(WhiteKernel) test_interface_ad_perf(_ -> k, nothing, StableRNG(123456)) end @@ -36,11 +38,11 @@ @test metric(ConstantKernel()) == KernelFunctions.Delta() @test metric(ConstantKernel(; c=2.0)) == KernelFunctions.Delta() @test repr(k) == "Constant Kernel (c = $(c))" - test_params(k, ([c],)) # Standardised tests. TestUtils.test_interface(k, Float64) TestUtils.test_interface(k, Vector{String}) + test_params(k, ([log(c)],)) test_ADs(c -> ConstantKernel(; c=only(c)), [c]) test_interface_ad_perf(c -> ConstantKernel(; c=c), c, StableRNG(123456)) end diff --git a/test/basekernels/cosine.jl b/test/basekernels/cosine.jl index 2c083ae8d..ca923a90b 100644 --- a/test/basekernels/cosine.jl +++ b/test/basekernels/cosine.jl @@ -19,6 +19,7 @@ # Standardised tests. TestUtils.test_interface(k, Vector{Float64}) + test_params(k, (Float64[],)) test_ADs(CosineKernel) test_interface_ad_perf(_ -> CosineKernel(), nothing, StableRNG(123456)) end diff --git a/test/basekernels/exponential.jl b/test/basekernels/exponential.jl index 1b0bb5fc9..79e0f55fb 100644 --- a/test/basekernels/exponential.jl +++ b/test/basekernels/exponential.jl @@ -21,6 +21,7 @@ # Standardised tests. TestUtils.test_interface(k) + test_params(k, (Float64[],)) test_ADs(SEKernel) test_interface_ad_perf(_ -> SEKernel(), nothing, StableRNG(123456)) end @@ -40,6 +41,7 @@ # Standardised tests. TestUtils.test_interface(k) + test_params(k, (Float64[],)) test_ADs(ExponentialKernel) test_interface_ad_perf(_ -> ExponentialKernel(), nothing, StableRNG(123456)) end @@ -48,7 +50,7 @@ k = GammaExponentialKernel(; γ=γ) @test k(v1, v2) ≈ exp(-norm(v1 - v2)^γ) @test kappa(GammaExponentialKernel(), x) == kappa(k, x) - @test GammaExponentialKernel(; gamma=γ).γ == [γ] + @test GammaExponentialKernel(; gamma=γ).γ == γ @test metric(GammaExponentialKernel()) == Euclidean() @test metric(GammaExponentialKernel(; γ=2.0)) == Euclidean() @test repr(k) == "Gamma Exponential Kernel (γ = $(γ), metric = Euclidean(0.0))" @@ -59,7 +61,7 @@ @test k2(v1, v2) ≈ k(v1, v2) test_ADs(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()]) - test_params(k, ([γ],)) + test_params(k, ([logit(γ / 2)],)) TestUtils.test_interface(GammaExponentialKernel(; γ=1.36)) #Coherence : diff --git a/test/basekernels/exponentiated.jl b/test/basekernels/exponentiated.jl index 14cc6d0d6..65f964946 100644 --- a/test/basekernels/exponentiated.jl +++ b/test/basekernels/exponentiated.jl @@ -13,6 +13,7 @@ # Standardised tests. This kernel appears to be fairly numerically unstable. 
diff --git a/test/basekernels/exponentiated.jl b/test/basekernels/exponentiated.jl
index 14cc6d0d6..65f964946 100644
--- a/test/basekernels/exponentiated.jl
+++ b/test/basekernels/exponentiated.jl
@@ -13,6 +13,7 @@
 
     # Standardised tests. This kernel appears to be fairly numerically unstable.
     TestUtils.test_interface(k; atol=1e-3)
+    test_params(k, (Float64[],))
     test_ADs(ExponentiatedKernel)
     test_interface_ad_perf(_ -> ExponentiatedKernel(), nothing, StableRNG(123456))
 end
diff --git a/test/basekernels/fbm.jl b/test/basekernels/fbm.jl
index 669d9721d..0a08350b2 100644
--- a/test/basekernels/fbm.jl
+++ b/test/basekernels/fbm.jl
@@ -22,7 +22,7 @@
         Zygote.gradient((x, y) -> sum(f.(x, y)), zeros(1), fill(0.9, 1))[1][1]
     )
 
-    test_params(k, ([h],))
+    test_params(k, ([logit(h)],))
 
     test_interface_ad_perf(h -> FBMKernel(; h=h), h, StableRNG(123456))
 end
diff --git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl
index 69dee1b9a..b546f214f 100644
--- a/test/basekernels/gabor.jl
+++ b/test/basekernels/gabor.jl
@@ -13,8 +13,8 @@
             TransformedKernel{<:CosineKernel,<:ScaleTransform},
         },
     }
-    @test k.kernels[1].transform.s[1] == inv(ell)
-    @test k.kernels[2].transform.s[1] == inv(p)
+    @test k.kernels[1].transform.s == inv(ell)
+    @test k.kernels[2].transform.s == inv(p)
 
     k_manual = exp(-sqeuclidean(v1, v2) / (2 * ell^2)) * cospi(euclidean(v1, v2) / p)
     @test k_manual ≈ k(v1, v2) atol = 1e-5
diff --git a/test/basekernels/matern.jl b/test/basekernels/matern.jl
index e0be7054d..950fe407f 100644
--- a/test/basekernels/matern.jl
+++ b/test/basekernels/matern.jl
@@ -7,7 +7,7 @@
         ν = 2.1
         k = MaternKernel(; ν=ν)
         matern(x, ν) = 2^(1 - ν) / gamma(ν) * (sqrt(2ν) * x)^ν * besselk(ν, sqrt(2ν) * x)
-        @test MaternKernel(; nu=ν).ν == [ν]
+        @test MaternKernel(; nu=ν).ν == ν
         @test kappa(k, x) ≈ matern(x, ν)
         @test kappa(k, 0.0) == 1.0
         @test metric(MaternKernel()) == Euclidean()
@@ -20,10 +20,9 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
+        test_params(k, ([log(ν)],))
         test_ADs(() -> MaternKernel(; nu=ν))
-        test_params(k, ([ν],))
-
         # The performance of this kernel varies quite a lot from method to method, so
         # requires us to specify whether performance tests pass or not.
         @testset "performance ($T)" for T in [
@@ -59,6 +58,7 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
+        test_params(k, (Float64[],))
         test_ADs(Matern32Kernel)
         test_interface_ad_perf(_ -> Matern32Kernel(), nothing, StableRNG(123456))
     end
@@ -79,6 +79,7 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
+        test_params(k, (Float64[],))
         test_ADs(Matern52Kernel)
         test_interface_ad_perf(_ -> Matern52Kernel(), nothing, StableRNG(123456))
     end
diff --git a/test/basekernels/nn.jl b/test/basekernels/nn.jl
index ee4863356..d396c9eed 100644
--- a/test/basekernels/nn.jl
+++ b/test/basekernels/nn.jl
@@ -7,6 +7,7 @@
 
     # Standardised tests.
     TestUtils.test_interface(k, Float64)
+    test_params(k, (Float64[],))
     test_ADs(NeuralNetworkKernel)
     test_interface_ad_perf(_ -> NeuralNetworkKernel(), nothing, StableRNG(123456))
 end
diff --git a/test/basekernels/periodic.jl b/test/basekernels/periodic.jl
index fb149dff5..a97a6ec86 100644
--- a/test/basekernels/periodic.jl
+++ b/test/basekernels/periodic.jl
@@ -17,5 +17,5 @@
     # test_ADs(r->PeriodicKernel(r =exp.(r)), log.(r), ADs = [:ForwardDiff, :ReverseDiff])
     @test_broken "Undefined adjoint for Sinus metric, and failing randomly for ForwardDiff and ReverseDiff"
 
-    test_params(k, (r,))
+    test_params(k, (map(log, r),))
 end
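Note (not part of the patch): the pattern above is uniform; positive parameters such as the Matérn smoothness ν and the periodic kernel's inverse periods r are flattened as logs, while the unit-interval Hurst parameter h of `FBMKernel` goes through `logit`. A round-trip sketch for `MaternKernel`, assuming the flatten method these test expectations imply:

```julia
using KernelFunctions, ParameterHandling

k = MaternKernel(; nu=2.1)
v, unflatten = ParameterHandling.flatten(Float64, k)
@assert v ≈ [log(2.1)]        # ν > 0 is exposed as log(ν)
@assert unflatten(v).ν ≈ 2.1  # exp recovers the original smoothness parameter
```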
diff --git a/test/basekernels/piecewisepolynomial.jl b/test/basekernels/piecewisepolynomial.jl
index 4d9979e1a..5bfbed826 100644
--- a/test/basekernels/piecewisepolynomial.jl
+++ b/test/basekernels/piecewisepolynomial.jl
@@ -31,9 +31,9 @@
 
     # Standardised tests.
     TestUtils.test_interface(k, ColVecs{Float64}; dim_in=2)
     TestUtils.test_interface(k, RowVecs{Float64}; dim_in=2)
+    test_params(k, (Float64[],))
     test_ADs(() -> PiecewisePolynomialKernel{degree}(; dim=D))
-    test_params(k, ())
     if VERSION >= v"1.8.0"
         test_interface_ad_perf(nothing, StableRNG(123456)) do _
             PiecewisePolynomialKernel{degree}(; dim=D)
diff --git a/test/basekernels/polynomial.jl b/test/basekernels/polynomial.jl
index 80a3b1bb8..22e19cafd 100644
--- a/test/basekernels/polynomial.jl
+++ b/test/basekernels/polynomial.jl
@@ -18,8 +18,8 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
+        test_params(LinearKernel(; c=c), ([log(c)],))
         test_ADs(x -> LinearKernel(; c=x[1]), [c])
-        test_params(LinearKernel(; c=c), ([c],))
         test_interface_ad_perf(c -> LinearKernel(; c=c), c, StableRNG(123456))
     end
     @testset "PolynomialKernel" begin
@@ -41,8 +41,8 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
+        test_params(PolynomialKernel(; c=c), ([log(c)],))
         test_ADs(x -> PolynomialKernel(; c=x[1]), [c])
-        test_params(PolynomialKernel(; c=c), ([c],))
         test_interface_ad_perf(
             c -> PolynomialKernel(; degree=2, c=c), 0.3, StableRNG(123456)
         )
diff --git a/test/basekernels/rational.jl b/test/basekernels/rational.jl
index 956295e04..0e3b27fea 100644
--- a/test/basekernels/rational.jl
+++ b/test/basekernels/rational.jl
@@ -28,7 +28,7 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
         test_ADs(x -> RationalKernel(; alpha=exp(x[1])), [α])
-        test_params(k, ([α],))
+        test_params(k, ([log(α)],))
         test_interface_ad_perf(α -> RationalKernel(; alpha=α), α, StableRNG(123456))
     end
@@ -55,8 +55,8 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
-        # test_ADs(x -> RationalQuadraticKernel(; alpha=exp(x[1])), [α])
-        test_params(k, ([α],))
+        test_params(k, ([log(α)],))
+        test_ADs(x -> RationalQuadraticKernel(; alpha=exp(x[1])), [α])
         test_interface_ad_perf(α, StableRNG(123456)) do α
             RationalQuadraticKernel(; alpha=α)
         end
@@ -146,8 +146,8 @@
 
         # Standardised tests.
         TestUtils.test_interface(k, Float64)
         a = 1.0 + rand()
+        test_params(GammaRationalKernel(; α=a, γ=x), ([log(a), logit(x / 2)],))
         test_ADs(x -> GammaRationalKernel(; α=x[1], γ=x[2]), [a, 1 + 0.5 * rand()])
-        test_params(GammaRationalKernel(; α=a, γ=x), ([a], [x]))
         test_interface_ad_perf((2.0, 1.5), StableRNG(123456)) do θ
             GammaRationalKernel(; α=θ[1], γ=θ[2])
         end
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 8367fbd6d..c1432cf31 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -22,7 +22,7 @@ function params(m...)
 end
 
 function test_params(kernel, reference)
-    params_kernel = params(kernel)
+    params_kernel = params(ParameterKernel(kernel))
     params_reference = params(reference)
 
     @test length(params_kernel) == length(params_reference)
diff --git a/test/transform/scaletransform.jl b/test/transform/scaletransform.jl
index 1d0988441..575c845cd 100644
--- a/test/transform/scaletransform.jl
+++ b/test/transform/scaletransform.jl
@@ -14,11 +14,11 @@
         @test all([t(x[n]) ≈ x′[n] for n in eachindex(x)])
     end
 
-    s2 = 2.0
-    KernelFunctions.set!(t, s2)
-    @test t.s == [s2]
     @test isequal(ScaleTransform(s), ScaleTransform(s))
-    @test repr(t) == "Scale Transform (s = $(s2))"
+
+    s2 = 2.0
+    @test repr(ScaleTransform(s2)) == "Scale Transform (s = $(s2))"
+
     test_ADs(x -> SEKernel() ∘ ScaleTransform(exp(x[1])), randn(rng, 1))
     test_interface_ad_perf(0.3, StableRNG(123456)) do c
         SEKernel() ∘ ScaleTransform(c)
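Note (not part of the patch): with the mutable `set!` deleted and `ScaleTransform` now immutable, updating a parameter means constructing a new object, most naturally through the `unflatten` closure; a sketch of the replacement for the removed `set!(t, 2.0)`:

```julia
using KernelFunctions, ParameterHandling

t = ScaleTransform(1.0)
_, unflatten = ParameterHandling.flatten(Float64, t)
t2 = unflatten([log(2.0)])  # build a fresh transform instead of mutating t
@assert t2.s ≈ 2.0
```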