From 2ebb65d6025cb06dbdd756cb9dde0112af6ee12f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Fri, 28 Feb 2020 17:53:49 +0100 Subject: [PATCH 01/14] First version of the macro --- src/kernels/kernel_macro.jl | 25 +++++++++++++++++++++++++ src/kernels/transformedkernel.jl | 2 ++ 2 files changed, 27 insertions(+) create mode 100644 src/kernels/kernel_macro.jl diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl new file mode 100644 index 000000000..4a55c8200 --- /dev/null +++ b/src/kernels/kernel_macro.jl @@ -0,0 +1,25 @@ +using MacroTools: @capture + +""" + +""" +macro kernel(expr::Expr,arg=nothing) + @capture(expr,(scale_*k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ*Kernel()` or `Kernel()`")) + @show kw + t = if @capture(arg,kw_=val_) + if kw == :l + val + elseif kw == :t + val + else + throw(error("The additional argument could not be intepreted. Please see documentation of `@kernel`")) + end + else + arg + end + if isnothing(scale) + return esc(:(transform($k,$t))) + else + return esc(:($scale*transform($k,$t))) + end +end diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index ef02be638..fa4b83c1f 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -21,6 +21,8 @@ transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ)) transform(k::BaseKernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) +transform(k::BaseKernel,::Nothing) = k + kernel(κ) = κ.kernel kappa(κ::TransformedKernel, x) = kappa(κ.kernel, x) From 1df722adc094bc6f2cea72c4607b03f013f38689 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Fri, 28 Feb 2020 17:57:54 +0100 Subject: [PATCH 02/14] Added MacroTools in dependency --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 3f0a9d45f..40d7c07de 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" @@ -16,6 +17,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Compat = "2.2, 3" Distances = "0.8" +MacroTools = "0.5" Requires = "1.0.1" SpecialFunctions = "0.8, 0.9, 0.10" StatsBase = "0.32" From 742428cb9498af7e88aecd6b8c5bcb2e752f236a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Fri, 28 Feb 2020 17:58:43 +0100 Subject: [PATCH 03/14] include file and export --- src/KernelFunctions.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 000c99034..975e0d480 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -4,7 +4,7 @@ export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!, kappa export transform export params, duplicate, set! # Helpers -export Kernel +export Kernel, BaseKernel, @kernel export ConstantKernel, WhiteKernel, ZeroKernel export SqExponentialKernel, ExponentialKernel, GammaExponentialKernel export ExponentiatedKernel @@ -46,6 +46,7 @@ for k in ["exponential","matern","polynomial","constant","rationalquad","exponen end include("kernels/transformedkernel.jl") include("kernels/scaledkernel.jl") +include("kernels/kernel_macro.jl") include("matrix/kernelmatrix.jl") include("kernels/kernelsum.jl") include("kernels/kernelproduct.jl") From c80423536532e0ba0f6d4d694bbf47f81f8b91ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 11 Mar 2020 16:19:04 +0100 Subject: [PATCH 04/14] Correction unwanted kw --- src/kernels/kernel_macro.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl index 4a55c8200..a26b38c5e 100644 --- a/src/kernels/kernel_macro.jl +++ b/src/kernels/kernel_macro.jl @@ -5,7 +5,6 @@ using MacroTools: @capture """ macro kernel(expr::Expr,arg=nothing) @capture(expr,(scale_*k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ*Kernel()` or `Kernel()`")) - @show kw t = if @capture(arg,kw_=val_) if kw == :l val From 80f057bbf6e7634ce5b862c473a1c92a4f0d511f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 11 Mar 2020 17:33:31 +0100 Subject: [PATCH 05/14] Adding tests for the macro --- test/runtests.jl | 1 + test/test_macro.jl | 12 ++++++++++++ 2 files changed, 13 insertions(+) create mode 100644 test/test_macro.jl diff --git a/test/runtests.jl b/test/runtests.jl index 45f9b3909..787b7aec5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ using Random include("test_kernelmatrix.jl") include("test_approximations.jl") include("test_constructors.jl") +include("test_macro.jl") # include("test_AD.jl") include("test_transform.jl") include("test_distances.jl") diff --git a/test/test_macro.jl b/test/test_macro.jl new file mode 100644 index 000000000..26628af7e --- /dev/null +++ b/test/test_macro.jl @@ -0,0 +1,12 @@ +using KernelFunctions +using Test + +@testset "Kernel Macro" begin + @test (@kernel SqExponentialKernel()) isa SqExponentialKernel + @test (@kernel 3.0*SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64} + @test (@kernel 3.0*SqExponentialKernel() l=3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} + @test (@kernel 3.0*SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} + @test (@kernel 3.0*SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Float64,1}},Float64} + @test (@kernel 3.0*SqExponentialKernel() LowRankTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LowRankTransform{Array{Float64,2}}},Float64} + @test (@kernel (3.0*SqExponentialKernel()+5.0*Matern32Kernel()) 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}} +end From 2982a18cd305be896f9703784568c7b6ede60981 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 11 Mar 2020 17:34:00 +0100 Subject: [PATCH 06/14] Added print for a few transforms --- src/generic.jl | 4 ++-- src/transform/ardtransform.jl | 2 ++ src/transform/chaintransform.jl | 13 ++++++++++--- src/transform/scaletransform.jl | 2 +- src/transform/selecttransform.jl | 2 ++ 5 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/generic.jl b/src/generic.jl index da2e0ad03..705a77c0a 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -11,8 +11,8 @@ _scale(t::ScaleTransform, metric::Euclidean, x, y) = first(t.s) * evaluate(metr _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y) = first(t.s)^2 * evaluate(metric, x, y) _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, apply(t, x), apply(t, y)) -printshifted(io::IO,κ::Kernel,shift::Int) = print(io,"$κ") -Base.show(io::IO,κ::Kernel) = print(io,nameof(typeof(κ))) +printshifted(io::IO, κ::Kernel, shift::Int) = print(io, "$κ") +Base.show(io::IO, κ::Kernel) = print(io, nameof(typeof(κ))) ### Syntactic sugar for creating matrices and using kernel functions for k in subtypes(BaseKernel) diff --git a/src/transform/ardtransform.jl b/src/transform/ardtransform.jl index 0f4b23e9b..a3d7bb092 100644 --- a/src/transform/ardtransform.jl +++ b/src/transform/ardtransform.jl @@ -38,3 +38,5 @@ apply(t::ARDTransform,x::AbstractVector{<:Real};obsdim::Int=defaultobs) = t.v .* _transform(t::ARDTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 1 ? t.v'.*X : t.v .* X Base.isequal(t::ARDTransform,t2::ARDTransform) = isequal(t.v,t2.v) + +Base.show(io::IO, t::ARDTransform) = print(io,"ARD Transform, ρ = $(t.v)") diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index 91d8880b8..34bcec372 100644 --- a/src/transform/chaintransform.jl +++ b/src/transform/chaintransform.jl @@ -36,6 +36,13 @@ params(t::ChainTransform) = (params.(t.transforms)) duplicate(t::ChainTransform,θ) = ChainTransform(duplicate.(t.transforms,θ)) -Base.:∘(t₁::Transform,t₂::Transform) = ChainTransform([t₂,t₁]) -Base.:∘(t::Transform,tc::ChainTransform) = ChainTransform(vcat(tc.transforms,t)) #TODO add test -Base.:∘(tc::ChainTransform,t::Transform) = ChainTransform(vcat(t,tc.transforms)) +Base.:∘(t₁::Transform, t₂::Transform) = ChainTransform([t₂, t₁]) +Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform(vcat(tc.transforms, t)) #TODO add test +Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transforms)) + +function Base.show(io::IO, tc::ChainTransform) + print(io,"Chain Transform : $(first(tc.transforms))") + for t in tc.transforms[2:end] + print(io, " |> $t") + end +end diff --git a/src/transform/scaletransform.jl b/src/transform/scaletransform.jl index c7d57835b..9444967cb 100644 --- a/src/transform/scaletransform.jl +++ b/src/transform/scaletransform.jl @@ -23,4 +23,4 @@ apply(t::ScaleTransform,x::AbstractVecOrMat;obsdim::Int=defaultobs) = first(t.s) Base.isequal(t::ScaleTransform,t2::ScaleTransform) = isequal(first(t.s),first(t2.s)) -Base.show(io::IO,t::ScaleTransform) = print(io,"Scale Transform s=$(first(t.s))") +Base.show(io::IO,t::ScaleTransform) = print(io,"Scale Transform, s = $(first(t.s))") diff --git a/src/transform/selecttransform.jl b/src/transform/selecttransform.jl index 9e8770139..19e51e728 100644 --- a/src/transform/selecttransform.jl +++ b/src/transform/selecttransform.jl @@ -41,3 +41,5 @@ function apply(t::SelectTransform, x::AbstractVector{<:Real}; obsdim::Int = defa end _transform(t::SelectTransform,X::AbstractMatrix{<:Real},obsdim::Int=defaultobs) = obsdim == 2 ? view(X,t.select,:) : view(X,:,t.select) + +Base.show(io::IO, t::SelectTransform) = print(io, "Selected Dimensions : $(t.select)") From 1d5b4fe86bfe13151cfa9a587f45634931fd2e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Wed, 11 Mar 2020 17:34:24 +0100 Subject: [PATCH 07/14] Applying transform on a transformed kernel results in a chain transform --- src/kernels/transformedkernel.jl | 9 ++++++--- src/transform/transform.jl | 4 ++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index fa4b83c1f..5d13787ba 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -3,6 +3,9 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel transform::Tr end +function TransformedKernel(k::TransformedKernel,t::Transform) + TransformedKernel(kernel(k),t∘k.transform) +end """ ```julia transform(k::BaseKernel, t::Transform) (1) @@ -15,11 +18,11 @@ end """ transform -transform(k::BaseKernel, t::Transform) = TransformedKernel(k, t) +transform(k::Kernel, t::Transform) = TransformedKernel(k, t) -transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ)) +transform(k::Kernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ)) -transform(k::BaseKernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) +transform(k::Kernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) transform(k::BaseKernel,::Nothing) = k diff --git a/src/transform/transform.jl b/src/transform/transform.jl index 6b2c7b988..d56dfda3f 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -17,6 +17,10 @@ params(t::IdentityTransform) = nothing apply(t::IdentityTransform, x; obsdim::Int=defaultobs) = x #TODO add test +Transform(ρ::Real) = ScaleTransform(ρ) +Transform(ρ::AbstractVector) = ARDTransform(ρ) +Transform(t::Transform) = t + ### TODO Maybe defining adjoints could help but so far it's not working From 39463132b9f0aa0c5d2f0e3bd4dae7f4bd78ba69 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Mon, 30 Mar 2020 17:56:39 +0200 Subject: [PATCH 08/14] Added tests and description for the macro --- src/KernelFunctions.jl | 1 + src/kernels/kernel_macro.jl | 29 ++++++++++++++----- .../kernel_macro.jl} | 0 test/runtests.jl | 1 + 4 files changed, 24 insertions(+), 7 deletions(-) rename test/{test_macro.jl => kernels/kernel_macro.jl} (100%) diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index b545c2d0c..f882428a9 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -30,6 +30,7 @@ using SpecialFunctions: logabsgamma, besselk using ZygoteRules: @adjoint using StatsFuns: logtwo using InteractiveUtils: subtypes +using MacroTools: @capture using StatsBase const defaultobs = 2 diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl index a26b38c5e..22a7ec3d0 100644 --- a/src/kernels/kernel_macro.jl +++ b/src/kernels/kernel_macro.jl @@ -1,11 +1,26 @@ -using MacroTools: @capture - """ + @kernel [variance *]kernel::Kernel [l=Real/Vector / t=transform::Transform / transform::Transform] + +The `@kernel` macro is an helping alias to the [`transform`](@ref) function. +The first argument should be a kernel multiplied (or not) by a scalar (variance of the kernel). +The second argument (optional) can be a keyword : + - `l=ρ` where `ρ` is a positive scalar or a vector of scalar + - `t=transform` where `transform` is a [`Transform`](@ref) object +One can also directly use a `Transform` object without a keyword. +Here are some examples : +```julia + k = @kernel SqExponentialKernel() l=3.0 + k == transform(SqExponentialKernel(), ScaleTransform(3.0)) + + k = @kernel (MaternKernel(ν=3.0) + LinearKernel()) t=LowRankTransform(rand(4,3)) + k == transform(KernelSum(MaternKernel(ν=3.0), LinearKernel()), LowRankTransform(rand(4,3))) + k = @kernel 4.0*ExponentiatedKernel() ScaleTransform(3.0) + k == ScaleTransform(transform(ExponentiatedKernel(), ScaleTransform(3.0)), 4.0) """ -macro kernel(expr::Expr,arg=nothing) - @capture(expr,(scale_*k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ*Kernel()` or `Kernel()`")) - t = if @capture(arg,kw_=val_) +macro kernel(expr::Expr, arg = nothing) + @capture(expr, (scale_ * k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ*kernel` or `kernel`")) + t = if @capture(arg, kw_ = val_) if kw == :l val elseif kw == :t @@ -17,8 +32,8 @@ macro kernel(expr::Expr,arg=nothing) arg end if isnothing(scale) - return esc(:(transform($k,$t))) + return esc(:(transform($k, $t))) else - return esc(:($scale*transform($k,$t))) + return esc(:($scale*transform($k, $t))) end end diff --git a/test/test_macro.jl b/test/kernels/kernel_macro.jl similarity index 100% rename from test/test_macro.jl rename to test/kernels/kernel_macro.jl diff --git a/test/runtests.jl b/test/runtests.jl index 698751ce1..4bf4d310a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -69,6 +69,7 @@ using KernelFunctions: metric include(joinpath("kernels", "fbm.jl")) include(joinpath("kernels", "kernelproduct.jl")) include(joinpath("kernels", "kernelsum.jl")) + include(joinpath("kernels", "kernel_macro.jl")) include(joinpath("kernels", "matern.jl")) include(joinpath("kernels", "polynomial.jl")) include(joinpath("kernels", "rationalquad.jl")) From 63219ff8a69aa5007658be3236e57d6f8440e8dd Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Mon, 11 May 2020 17:50:42 +0200 Subject: [PATCH 09/14] Added some docs and updated tests --- docs/src/userguide.md | 18 ++++++++++++------ src/kernels/kernel_macro.jl | 4 ++-- test/kernels/kernel_macro.jl | 15 ++++++--------- 3 files changed, 20 insertions(+), 17 deletions(-) diff --git a/docs/src/userguide.md b/docs/src/userguide.md index b29ff7a0d..243342203 100644 --- a/docs/src/userguide.md +++ b/docs/src/userguide.md @@ -4,18 +4,24 @@ To create a kernel chose one of the kernels proposed, see [Kernels](@ref), or create your own, see [Creating Kernels](@ref) For example to create a square exponential kernel + ```julia k = SqExponentialKernel() ``` -Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [Transform](@ref)) which are directly going to act on the inputs before passing them to the kernel. -For example to premultiply the input by 2.0 we create the kernel the following options are possible + +Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [Transform](@ref)). The transform operations are going to be applied on the inputs before they are passed to the kernel. +For example, the [ScaleTransform](@ref) multiply every sample by a scalar $\rho$. A `SqExponentialKernel` with a `ScaleTransform(ρ)`, is therefore equivalent to have a `SqExponentialKernel` with lengthscale `1/ρ`. +Here are some examples on how to use these transformations and are all equivalent: ```julia - k = transform(SqExponentialKernel(),ScaleTransform(2.0)) # returns a TransformedKernel - k = @kernel SqExponentialKernel() l=2.0 # Will be available soon - k = TransformedKernel(SqExponentialKernel(),ScaleTransform(2.0)) + k = TransformedKernel(SqExponentialKernel(), ScaleTransform(2.0)) # Constructor + k = transform(SqExponentialKernel(), ScaleTransform(2.0)) # wrapper for a constructor + k = @kernel SqExponentialKernel() l=2.0 # Convenience macro ``` + Check the [`Transform`](@ref) page to see the other options. -To premultiply the kernel by a variance, you can use `*` or create a `ScaledKernel` +--- +To pre-multiply the kernel by a variance parameter, you can use `*` or create a `ScaledKernel` + ```julia k = 3.0*SqExponentialKernel() k = ScaledKernel(SqExponentialKernel(),3.0) diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl index 22a7ec3d0..d23540671 100644 --- a/src/kernels/kernel_macro.jl +++ b/src/kernels/kernel_macro.jl @@ -19,7 +19,7 @@ Here are some examples : k == ScaleTransform(transform(ExponentiatedKernel(), ScaleTransform(3.0)), 4.0) """ macro kernel(expr::Expr, arg = nothing) - @capture(expr, (scale_ * k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ*kernel` or `kernel`")) + @capture(expr, (scale_ * k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ * kernel` or `kernel`")) t = if @capture(arg, kw_ = val_) if kw == :l val @@ -34,6 +34,6 @@ macro kernel(expr::Expr, arg = nothing) if isnothing(scale) return esc(:(transform($k, $t))) else - return esc(:($scale*transform($k, $t))) + return esc(:($scale * transform($k, $t))) end end diff --git a/test/kernels/kernel_macro.jl b/test/kernels/kernel_macro.jl index 26628af7e..97f70cd72 100644 --- a/test/kernels/kernel_macro.jl +++ b/test/kernels/kernel_macro.jl @@ -1,12 +1,9 @@ -using KernelFunctions -using Test - @testset "Kernel Macro" begin @test (@kernel SqExponentialKernel()) isa SqExponentialKernel - @test (@kernel 3.0*SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64} - @test (@kernel 3.0*SqExponentialKernel() l=3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} - @test (@kernel 3.0*SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} - @test (@kernel 3.0*SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Float64,1}},Float64} - @test (@kernel 3.0*SqExponentialKernel() LowRankTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LowRankTransform{Array{Float64,2}}},Float64} - @test (@kernel (3.0*SqExponentialKernel()+5.0*Matern32Kernel()) 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}} + @test (@kernel 3.0 * SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64} + @test (@kernel 3.0 * SqExponentialKernel() l=3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} + @test (@kernel 3.0 * SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} + @test (@kernel 3.0 * SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Vector{Float64}}},Float64} + @test (@kernel 3.0 * SqExponentialKernel() LinearTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}},Float64} + @test (@kernel (3.0 * SqExponentialKernel() + 5.0 * Matern32Kernel()) 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}} end From 6ebbbcfd8be17aa12396e05678cd8518a7215a09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Galy-Fajou?= Date: Mon, 11 May 2020 22:28:13 +0200 Subject: [PATCH 10/14] Update src/kernels/transformedkernel.jl Co-authored-by: David Widmann --- src/kernels/transformedkernel.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index fab74c96a..9357c6770 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -50,7 +50,7 @@ transform(k::Kernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ)) transform(k::Kernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) -transform(k::BaseKernel,::Nothing) = k +transform(k::Kernel, ::Nothing) = k kernel(κ) = κ.kernel From 87d621197b25ed94dc10c62c6676ca193d2acf05 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 12 May 2020 14:37:42 +0200 Subject: [PATCH 11/14] Correction on docs --- docs/src/userguide.md | 16 ++++++++-------- src/kernels/kernel_macro.jl | 36 +++++++++++++++++++++++++++--------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/docs/src/userguide.md b/docs/src/userguide.md index 243342203..4be23360b 100644 --- a/docs/src/userguide.md +++ b/docs/src/userguide.md @@ -9,12 +9,13 @@ For example to create a square exponential kernel k = SqExponentialKernel() ``` -Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [Transform](@ref)). The transform operations are going to be applied on the inputs before they are passed to the kernel. -For example, the [ScaleTransform](@ref) multiply every sample by a scalar $\rho$. A `SqExponentialKernel` with a `ScaleTransform(ρ)`, is therefore equivalent to have a `SqExponentialKernel` with lengthscale `1/ρ`. -Here are some examples on how to use these transformations and are all equivalent: +Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [`Transform`](@ref)). The transformations are going to be applied on the inputs before the kernel is evaluated. +For example, the [`ScaleTransform`](@ref) multiplies every sample with a scalar. A `SqExponentialKernel` with a `ScaleTransform(ρ)`, is therefore equivalent to have a Squared Exponential Kernel with lengthscale `1/ρ`. +Here are some examples of how to use these transformations that are all equivalent: ```julia k = TransformedKernel(SqExponentialKernel(), ScaleTransform(2.0)) # Constructor - k = transform(SqExponentialKernel(), ScaleTransform(2.0)) # wrapper for a constructor + k = transform(SqExponentialKernel(), ScaleTransform(2.0)) # wrapper for the constructor + k = transform(SqExponentialKernel(), 2.0) # Syntactic sugar k = @kernel SqExponentialKernel() l=2.0 # Convenience macro ``` @@ -35,7 +36,7 @@ To compute the kernel function on two vectors you can call k = SqExponentialKernel() x1 = rand(3) x2 = rand(3) - k(x1,x2) + k(x1, x2) ``` ## Creating a kernel matrix @@ -46,9 +47,8 @@ For example: ```julia k = SqExponentialKernel() A = rand(10,5) - kernelmatrix(k,A,obsdim=1) # Return a 10x10 matrix - kernelmatrix(k,A,obsdim=2) # Return a 5x5 matrix - k(A,obsdim=1) # Syntactic sugar + kernelmatrix(k, A, obsdim = 1) # Return a 10x10 matrix + kernelmatrix(k, A, obsdim = 2) # Return a 5x5 matrix ``` We also support specific kernel matrices outputs: diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl index d23540671..c68b63dfa 100644 --- a/src/kernels/kernel_macro.jl +++ b/src/kernels/kernel_macro.jl @@ -1,5 +1,7 @@ """ - @kernel [variance *]kernel::Kernel [l=Real/Vector / t=transform::Transform / transform::Transform] + @kernel [variance *] kernel + @kernel [variance *] kernel l=Real/Vector + @kernel [variance *] kernel t=transform The `@kernel` macro is an helping alias to the [`transform`](@ref) function. The first argument should be a kernel multiplied (or not) by a scalar (variance of the kernel). @@ -7,16 +9,32 @@ The second argument (optional) can be a keyword : - `l=ρ` where `ρ` is a positive scalar or a vector of scalar - `t=transform` where `transform` is a [`Transform`](@ref) object One can also directly use a `Transform` object without a keyword. -Here are some examples : -```julia - k = @kernel SqExponentialKernel() l=3.0 - k == transform(SqExponentialKernel(), ScaleTransform(3.0)) - k = @kernel (MaternKernel(ν=3.0) + LinearKernel()) t=LowRankTransform(rand(4,3)) - k == transform(KernelSum(MaternKernel(ν=3.0), LinearKernel()), LowRankTransform(rand(4,3))) +# Examples +```jldoctest +julia> k = @kernel SqExponentialKernel() l=3.0 +Squared Exponential Kernel + - Scale Transform (s = 3.0) - k = @kernel 4.0*ExponentiatedKernel() ScaleTransform(3.0) - k == ScaleTransform(transform(ExponentiatedKernel(), ScaleTransform(3.0)), 4.0) +julia> k == transform(SqExponentialKernel(), ScaleTransform(3.0)) +true + +julia> k = @kernel (MaternKernel(ν=3.0) + LinearKernel()) t=LinearTransform(rand(4,3)) +Sum of 2 kernels: + - (w = 1.0) Matern Kernel (ν = 3.0) + - (w = 1.0) Linear Kernel (c = 0.0) + - Linear transform (size(A) = (4, 3)) + +julia> k == transform(KernelSum(MaternKernel(ν=3.0), LinearKernel()), LinearTransform(rand(4,3))) +true + +julia> k = @kernel 4.0*ExponentiatedKernel() l=3.0 +Exponentiated Kernel + - Scale Transform (s = 3.0) + - σ² = 4.0 +julia> k == ScaleTransform(transform(ExponentiatedKernel(), ScaleTransform(3.0)), 4.0) +true +``` """ macro kernel(expr::Expr, arg = nothing) @capture(expr, (scale_ * k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ * kernel` or `kernel`")) From d7df138b4a37cc6904ab91a5be123557e7fb9f27 Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 12 May 2020 14:38:24 +0200 Subject: [PATCH 12/14] Applied suggestions on @kernel and adapted tests --- src/kernels/kernel_macro.jl | 22 ++++++++++------------ test/kernels/kernel_macro.jl | 9 +++++---- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl index c68b63dfa..fbbe40065 100644 --- a/src/kernels/kernel_macro.jl +++ b/src/kernels/kernel_macro.jl @@ -37,21 +37,19 @@ true ``` """ macro kernel(expr::Expr, arg = nothing) - @capture(expr, (scale_ * k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ * kernel` or `kernel`")) - t = if @capture(arg, kw_ = val_) - if kw == :l - val - elseif kw == :t - val + @capture(expr, (scale_ * k_ | k_)) || error("@kernel first arguments should be of the form `σ * kernel` or `kernel`") + if arg === nothing + t = nothing + else + if @capture(arg, ((l = val_) | (t = val_))) + t = val else - throw(error("The additional argument could not be intepreted. Please see documentation of `@kernel`")) + error("The additional argument of `@kernel` is incorrect") end - else - arg end - if isnothing(scale) - return esc(:(transform($k, $t))) + if scale === nothing + return :(transform($(esc(k)), $(esc(t)))) else - return esc(:($scale * transform($k, $t))) + return :($(esc(scale)) * transform($(esc(k)), $(esc(t)))) end end diff --git a/test/kernels/kernel_macro.jl b/test/kernels/kernel_macro.jl index 97f70cd72..d27582357 100644 --- a/test/kernels/kernel_macro.jl +++ b/test/kernels/kernel_macro.jl @@ -1,9 +1,10 @@ @testset "Kernel Macro" begin @test (@kernel SqExponentialKernel()) isa SqExponentialKernel + @test_throws ErrorException @kernel sqrt(SqExponentialKernel) @test (@kernel 3.0 * SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64} - @test (@kernel 3.0 * SqExponentialKernel() l=3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} - @test (@kernel 3.0 * SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} + @test (@kernel 3.0 * SqExponentialKernel() l = 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} + # @test (@kernel 3.0 * SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} @test (@kernel 3.0 * SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Vector{Float64}}},Float64} - @test (@kernel 3.0 * SqExponentialKernel() LinearTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}},Float64} - @test (@kernel (3.0 * SqExponentialKernel() + 5.0 * Matern32Kernel()) 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}} + # @test (@kernel 3.0 * SqExponentialKernel() LinearTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}},Float64} + # @test (@kernel (3.0 * SqExponen Date: Tue, 12 May 2020 14:44:01 +0200 Subject: [PATCH 13/14] Added some more transform options and tests --- src/kernels/transformedkernel.jl | 9 ++++++--- src/transform/transform.jl | 4 ---- test/kernels/transformedkernel.jl | 10 ++++++++-- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 9357c6770..39bfb3e94 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -9,9 +9,10 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel transform::Tr end -function TransformedKernel(k::TransformedKernel,t::Transform) - TransformedKernel(kernel(k),t∘k.transform) +function TransformedKernel(k::TransformedKernel, t::Transform) + TransformedKernel(kernel(k), t ∘ k.transform) end + (k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y)) # Optimizations for scale transforms of simple kernels to save allocations: @@ -48,7 +49,9 @@ transform(k::Kernel, t::Transform) = TransformedKernel(k, t) transform(k::Kernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ)) -transform(k::Kernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) +transform(k::Kernel, ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ)) + +transform(k::Kernel, ρ::AbstractMatrix) = TransformedKernel(k, LinearTransform(ρ)) transform(k::Kernel, ::Nothing) = k diff --git a/src/transform/transform.jl b/src/transform/transform.jl index 10f7551c6..7d2bbe22c 100644 --- a/src/transform/transform.jl +++ b/src/transform/transform.jl @@ -22,10 +22,6 @@ struct IdentityTransform <: Transform end (t::IdentityTransform)(x) = x Base.map(::IdentityTransform, x::AbstractVector) = x -Transform(ρ::Real) = ScaleTransform(ρ) -Transform(ρ::AbstractVector) = ARDTransform(ρ) -Transform(t::Transform) = t - ### TODO Maybe defining adjoints could help but so far it's not working diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index cabbe0008..7fca38bb9 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -7,8 +7,8 @@ s = rand(rng) v = rand(rng, 3) k = SqExponentialKernel() - kt = TransformedKernel(k,ScaleTransform(s)) - ktard = TransformedKernel(k,ARDTransform(v)) + kt = TransformedKernel(k, ScaleTransform(s)) + ktard = TransformedKernel(k, ARDTransform(v)) @test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2) @test kt(v1, v2) == transform(k, s)(v1,v2) @test kt(v1, v2) ≈ k(s * v1, s * v2) atol=1e-5 @@ -16,6 +16,12 @@ @test ktard(v1, v2) == transform(k,v)(v1, v2) @test ktard(v1, v2) == k(v .* v1, v .* v2) + @test transform(kt, s) isa TransformedKernel{SqExponentialKernel,ChainTransform{Array{ScaleTransform{Float64},1}}} + + @test transform(k, s) isa TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}} + @test transform(k, v) isa TransformedKernel{SqExponentialKernel,ARDTransform{Array{Float64,1}}} + @test transform(k, rand(3, 2)) isa TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}} + @testset "kernelmatrix" begin rng = MersenneTwister(123456) From 18c2fec1e6a0d1aa1540247ba8c229e5d33a7f9e Mon Sep 17 00:00:00 2001 From: Theo Galy-Fajou Date: Tue, 12 May 2020 15:17:14 +0200 Subject: [PATCH 14/14] Corrected error behavior --- src/kernels/kernel_macro.jl | 4 ++-- test/kernels/kernel_macro.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kernels/kernel_macro.jl b/src/kernels/kernel_macro.jl index fbbe40065..125fa41d7 100644 --- a/src/kernels/kernel_macro.jl +++ b/src/kernels/kernel_macro.jl @@ -37,14 +37,14 @@ true ``` """ macro kernel(expr::Expr, arg = nothing) - @capture(expr, (scale_ * k_ | k_)) || error("@kernel first arguments should be of the form `σ * kernel` or `kernel`") + @capture(expr, ((scale_ * k_) | (k_))) if arg === nothing t = nothing else if @capture(arg, ((l = val_) | (t = val_))) t = val else - error("The additional argument of `@kernel` is incorrect") + return :(error("The additional argument of `@kernel` is incorrect")) end end if scale === nothing diff --git a/test/kernels/kernel_macro.jl b/test/kernels/kernel_macro.jl index d27582357..8bd2b8959 100644 --- a/test/kernels/kernel_macro.jl +++ b/test/kernels/kernel_macro.jl @@ -1,10 +1,10 @@ @testset "Kernel Macro" begin @test (@kernel SqExponentialKernel()) isa SqExponentialKernel - @test_throws ErrorException @kernel sqrt(SqExponentialKernel) @test (@kernel 3.0 * SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64} @test (@kernel 3.0 * SqExponentialKernel() l = 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} # @test (@kernel 3.0 * SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64} @test (@kernel 3.0 * SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Vector{Float64}}},Float64} # @test (@kernel 3.0 * SqExponentialKernel() LinearTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}},Float64} - # @test (@kernel (3.0 * SqExponen