diff --git a/Project.toml b/Project.toml index 3f3948cae..0b1a085d5 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Compat = "2.2, 3" +MacroTools = "0.5" Distances = "0.9" Requires = "1.0.1" SpecialFunctions = "0.8, 0.9, 0.10" diff --git a/docs/src/userguide.md b/docs/src/userguide.md index b29ff7a0d..4be23360b 100644 --- a/docs/src/userguide.md +++ b/docs/src/userguide.md @@ -4,18 +4,25 @@ To create a kernel chose one of the kernels proposed, see [Kernels](@ref), or create your own, see [Creating Kernels](@ref) For example to create a square exponential kernel + ```julia k = SqExponentialKernel() ``` -Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [Transform](@ref)) which are directly going to act on the inputs before passing them to the kernel. -For example to premultiply the input by 2.0 we create the kernel the following options are possible + +Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [`Transform`](@ref)). The transformations are applied to the inputs before the kernel is evaluated. +For example, the [`ScaleTransform`](@ref) multiplies every sample by a scalar. A `SqExponentialKernel` with a `ScaleTransform(ρ)` is therefore equivalent to a squared exponential kernel with lengthscale `1/ρ`. 
+Here are some examples of how to use these transformations that are all equivalent: ```julia - k = transform(SqExponentialKernel(),ScaleTransform(2.0)) # returns a TransformedKernel - k = @kernel SqExponentialKernel() l=2.0 # Will be available soon - k = TransformedKernel(SqExponentialKernel(),ScaleTransform(2.0)) + k = TransformedKernel(SqExponentialKernel(), ScaleTransform(2.0)) # Constructor + k = transform(SqExponentialKernel(), ScaleTransform(2.0)) # wrapper for the constructor + k = transform(SqExponentialKernel(), 2.0) # Syntactic sugar + k = @kernel SqExponentialKernel() l=2.0 # Convenience macro ``` + Check the [`Transform`](@ref) page to see the other options. -To premultiply the kernel by a variance, you can use `*` or create a `ScaledKernel` +--- +To pre-multiply the kernel by a variance parameter, you can use `*` or create a `ScaledKernel` + ```julia k = 3.0*SqExponentialKernel() k = ScaledKernel(SqExponentialKernel(),3.0) @@ -29,7 +36,7 @@ To compute the kernel function on two vectors you can call k = SqExponentialKernel() x1 = rand(3) x2 = rand(3) - k(x1,x2) + k(x1, x2) ``` ## Creating a kernel matrix @@ -40,9 +47,8 @@ For example: ```julia k = SqExponentialKernel() A = rand(10,5) - kernelmatrix(k,A,obsdim=1) # Return a 10x10 matrix - kernelmatrix(k,A,obsdim=2) # Return a 5x5 matrix - k(A,obsdim=1) # Syntactic sugar + kernelmatrix(k, A, obsdim = 1) # Return a 10x10 matrix + kernelmatrix(k, A, obsdim = 2) # Return a 5x5 matrix ``` We also support specific kernel matrices outputs: diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 6844e7ae1..a7d87ee50 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -8,7 +8,7 @@ export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix! export transform export duplicate, set! 
# Helpers -export Kernel +export Kernel, BaseKernel, @kernel export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel, WienerKernel export CosineKernel export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel @@ -38,6 +38,7 @@ using SpecialFunctions: loggamma, besselk, polygamma using ZygoteRules: @adjoint, pullback using StatsFuns: logtwo using InteractiveUtils: subtypes +using MacroTools: @capture using StatsBase """ @@ -61,6 +62,7 @@ end include("kernels/transformedkernel.jl") include("kernels/scaledkernel.jl") +include("kernels/kernel_macro.jl") include("matrix/kernelmatrix.jl") include("kernels/kernelsum.jl") include("kernels/kernelproduct.jl") diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl new file mode 100644 index 000000000..125fa41d7 --- /dev/null +++ b/src/kernels/kernel_macro.jl @@ -0,0 +1,55 @@ +""" + @kernel [variance *] kernel + @kernel [variance *] kernel l=Real/Vector + @kernel [variance *] kernel t=transform + +The `@kernel` macro is a convenience wrapper around the [`transform`](@ref) function. +The first argument should be a kernel, optionally multiplied by a scalar (the variance of the kernel). +The second argument (optional) must be one of the following keywords: + - `l=ρ` where `ρ` is a positive scalar or a vector of scalars + - `t=transform` where `transform` is a [`Transform`](@ref) object +Passing a bare `Transform` object without a keyword is not supported yet and raises an error. 
+ +# Examples +```jldoctest +julia> k = @kernel SqExponentialKernel() l=3.0 +Squared Exponential Kernel + - Scale Transform (s = 3.0) + +julia> k == transform(SqExponentialKernel(), ScaleTransform(3.0)) +true + +julia> k = @kernel (MaternKernel(ν=3.0) + LinearKernel()) t=LinearTransform(ones(4,3)) +Sum of 2 kernels: + - (w = 1.0) Matern Kernel (ν = 3.0) + - (w = 1.0) Linear Kernel (c = 0.0) + - Linear transform (size(A) = (4, 3)) + +julia> k == transform(KernelSum(MaternKernel(ν=3.0), LinearKernel()), LinearTransform(ones(4,3))) +true + +julia> k = @kernel 4.0*ExponentiatedKernel() l=3.0 +Exponentiated Kernel + - Scale Transform (s = 3.0) + - σ² = 4.0 +julia> k == ScaledKernel(transform(ExponentiatedKernel(), ScaleTransform(3.0)), 4.0) +true +``` +""" +macro kernel(expr::Expr, arg = nothing) + @capture(expr, ((scale_ * k_) | (k_))) # `scale` stays `nothing` when only the bare-kernel alternative matches + if arg === nothing + t = nothing + else + if @capture(arg, ((l = val_) | (t = val_))) # `l=` and `t=` are treated alike; `transform` dispatches on the value + t = val + else + return :(error("The additional argument of `@kernel` is incorrect")) # thrown when the expansion is evaluated, not while expanding + end + end + if scale === nothing + return :(transform($(esc(k)), $(esc(t)))) + else + return :($(esc(scale)) * transform($(esc(k)), $(esc(t)))) + end +end diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index b96208308..39bfb3e94 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -9,6 +9,10 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel transform::Tr end +function TransformedKernel(k::TransformedKernel, t::Transform) + return TransformedKernel(kernel(k), k.transform ∘ t) # flatten the nesting so the result evaluates k(t(x), t(y)): `t` first, then `k.transform` (NOTE(review): assumes `∘` on transforms is mathematical composition — confirm) +end + (k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y)) # Optimizations for scale transforms of simple kernels to save allocations: @@ -41,11 +45,15 @@ _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y)) """ transform -transform(k::BaseKernel, t::Transform) = TransformedKernel(k, t) +transform(k::Kernel, t::Transform) = TransformedKernel(k, t) + +transform(k::Kernel, 
ρ::Real) = TransformedKernel(k, ScaleTransform(ρ)) + +transform(k::Kernel, ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) -transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ)) +transform(k::Kernel, ρ::AbstractMatrix) = TransformedKernel(k, LinearTransform(ρ)) -transform(k::BaseKernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) +transform(k::Kernel, ::Nothing) = k kernel(κ) = κ.kernel diff --git a/test/kernels/kernel_macro.jl b/test/kernels/kernel_macro.jl new file mode 100644 index 000000000..8bd2b8959 --- /dev/null +++ b/test/kernels/kernel_macro.jl @@ -0,0 +1,10 @@ +@testset "Kernel Macro" begin + @test (@kernel SqExponentialKernel()) isa SqExponentialKernel + @test (@kernel 3.0 * SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64} + @test (@kernel 3.0 * SqExponentialKernel() l = 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} + @test_broken (@kernel 3.0 * SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} # bare (keyword-less) argument is not supported yet + @test (@kernel 3.0 * SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Vector{Float64}}},Float64} + @test_broken (@kernel 3.0 * SqExponentialKernel() LinearTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}},Float64} # bare `Transform` argument is not supported yet + @test (@kernel 3.0 * SqExponentialKernel() + 5.0 * Matern32Kernel() l = 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}} + @test_throws ErrorException (@kernel SqExponentialKernel() w = 2.0) +end diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index cf49dde2d..af1b30715 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -7,8 +7,8 @@ s = rand(rng) v = rand(rng, 3) k = SqExponentialKernel() - kt = TransformedKernel(k,ScaleTransform(s)) - ktard = TransformedKernel(k,ARDTransform(v)) + kt = 
TransformedKernel(k, ScaleTransform(s)) + ktard = TransformedKernel(k, ARDTransform(v)) @test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2) @test kt(v1, v2) == transform(k, s)(v1,v2) @test kt(v1, v2) ≈ k(s * v1, s * v2) atol=1e-5 @@ -16,6 +16,12 @@ @test ktard(v1, v2) == transform(k,v)(v1, v2) @test ktard(v1, v2) == k(v .* v1, v .* v2) + @test transform(kt, s) isa TransformedKernel{SqExponentialKernel,ChainTransform{Array{ScaleTransform{Float64},1}}} + + @test transform(k, s) isa TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}} + @test transform(k, v) isa TransformedKernel{SqExponentialKernel,ARDTransform{Array{Float64,1}}} + @test transform(k, rand(3, 2)) isa TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}} + @testset "kernelmatrix" begin rng = MersenneTwister(123456) diff --git a/test/runtests.jl b/test/runtests.jl index d0ea3e3c5..9f311c73b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -87,6 +87,7 @@ using KernelFunctions: metric, kappa, ColVecs, RowVecs @testset "kernels" begin include(joinpath("kernels", "kernelproduct.jl")) include(joinpath("kernels", "kernelsum.jl")) + include(joinpath("kernels", "kernel_macro.jl")) include(joinpath("kernels", "scaledkernel.jl")) include(joinpath("kernels", "tensorproduct.jl")) include(joinpath("kernels", "transformedkernel.jl"))