Skip to content
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Expand All @@ -16,6 +17,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
Compat = "2.2, 3"
Distances = "0.8.2"
MacroTools = "0.5"
Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32, 0.33"
Expand Down
18 changes: 12 additions & 6 deletions docs/src/userguide.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@

To create a kernel chose one of the kernels proposed, see [Kernels](@ref), or create your own, see [Creating Kernels](@ref)
For example to create a square exponential kernel

```julia
k = SqExponentialKernel()
```
Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [Transform](@ref)) which are directly going to act on the inputs before passing them to the kernel.
For example to premultiply the input by 2.0 we create the kernel the following options are possible

Instead of having lengthscale(s) for each kernel we use `Transform` objects (see [Transform](@ref)). The transform operations are going to be applied on the inputs before they are passed to the kernel.
For example, the [ScaleTransform](@ref) multiply every sample by a scalar $\rho$. A `SqExponentialKernel` with a `ScaleTransform(ρ)`, is therefore equivalent to have a `SqExponentialKernel` with lengthscale `1/ρ`.
Here are some examples on how to use these transformations and are all equivalent:
```julia
k = transform(SqExponentialKernel(),ScaleTransform(2.0)) # returns a TransformedKernel
k = @kernel SqExponentialKernel() l=2.0 # Will be available soon
k = TransformedKernel(SqExponentialKernel(),ScaleTransform(2.0))
k = TransformedKernel(SqExponentialKernel(), ScaleTransform(2.0)) # Constructor
k = transform(SqExponentialKernel(), ScaleTransform(2.0)) # wrapper for a constructor
k = @kernel SqExponentialKernel() l=2.0 # Convenience macro
```

Check the [`Transform`](@ref) page to see the other options.
To premultiply the kernel by a variance, you can use `*` or create a `ScaledKernel`
---
To pre-multiply the kernel by a variance parameter, you can use `*` or create a `ScaledKernel`

```julia
k = 3.0*SqExponentialKernel()
k = ScaledKernel(SqExponentialKernel(),3.0)
Expand Down
4 changes: 3 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ export kernelmatrix, kernelmatrix!, kerneldiagmatrix, kerneldiagmatrix!
export transform
export duplicate, set! # Helpers

export Kernel
export Kernel, BaseKernel, @kernel
export ConstantKernel, WhiteKernel, EyeKernel, ZeroKernel, WienerKernel
export CosineKernel
export SqExponentialKernel, RBFKernel, GaussianKernel, SEKernel
Expand Down Expand Up @@ -38,6 +38,7 @@ using SpecialFunctions: logabsgamma, besselk
using ZygoteRules: @adjoint
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using MacroTools: @capture
using StatsBase

"""
Expand All @@ -61,6 +62,7 @@ end

include("kernels/transformedkernel.jl")
include("kernels/scaledkernel.jl")
include("kernels/kernel_macro.jl")
include("matrix/kernelmatrix.jl")
include("kernels/kernelsum.jl")
include("kernels/kernelproduct.jl")
Expand Down
39 changes: 39 additions & 0 deletions src/kernels/kernel_macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""
@kernel [variance *]kernel::Kernel [l=Real/Vector / t=transform::Transform / transform::Transform]
The `@kernel` macro is an helping alias to the [`transform`](@ref) function.
The first argument should be a kernel multiplied (or not) by a scalar (variance of the kernel).
The second argument (optional) can be a keyword :
- `l=ρ` where `ρ` is a positive scalar or a vector of scalar
- `t=transform` where `transform` is a [`Transform`](@ref) object
One can also directly use a `Transform` object without a keyword.
Here are some examples :
```julia
k = @kernel SqExponentialKernel() l=3.0
k == transform(SqExponentialKernel(), ScaleTransform(3.0))
k = @kernel (MaternKernel(ν=3.0) + LinearKernel()) t=LowRankTransform(rand(4,3))
k == transform(KernelSum(MaternKernel(ν=3.0), LinearKernel()), LowRankTransform(rand(4,3)))
k = @kernel 4.0*ExponentiatedKernel() ScaleTransform(3.0)
k == ScaleTransform(transform(ExponentiatedKernel(), ScaleTransform(3.0)), 4.0)
"""
macro kernel(expr::Expr, arg = nothing)
@capture(expr, (scale_ * k_ | k_)) || throw(error("@kernel first arguments should be of the form `σ * kernel` or `kernel`"))
t = if @capture(arg, kw_ = val_)
if kw == :l
val
elseif kw == :t
val
else
throw(error("The additional argument could not be intepreted. Please see documentation of `@kernel`"))
end
else
arg
end
if isnothing(scale)
return esc(:(transform($k, $t)))
else
return esc(:($scale * transform($k, $t)))
end
end
11 changes: 8 additions & 3 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
transform::Tr
end

function TransformedKernel(k::TransformedKernel,t::Transform)
TransformedKernel(kernel(k),t∘k.transform)
end
(k::TransformedKernel)(x, y) = k.kernel(k.transform(x), k.transform(y))

# Optimizations for scale transforms of simple kernels to save allocations:
Expand Down Expand Up @@ -41,11 +44,13 @@ _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
"""
transform

transform(k::BaseKernel, t::Transform) = TransformedKernel(k, t)
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)

transform(k::Kernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))

transform(k::BaseKernel, ρ::Real) = TransformedKernel(k, ScaleTransform(ρ))
transform(k::Kernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))

transform(k::BaseKernel,ρ::AbstractVector) = TransformedKernel(k, ARDTransform(ρ))
transform(k::BaseKernel,::Nothing) = k

kernel(κ) = κ.kernel

Expand Down
4 changes: 4 additions & 0 deletions src/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ struct IdentityTransform <: Transform end
(t::IdentityTransform)(x) = x
Base.map(::IdentityTransform, x::AbstractVector) = x

Transform(ρ::Real) = ScaleTransform(ρ)
Transform(ρ::AbstractVector) = ARDTransform(ρ)
Transform(t::Transform) = t

### TODO Maybe defining adjoints could help but so far it's not working


Expand Down
9 changes: 9 additions & 0 deletions test/kernels/kernel_macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testset "Kernel Macro" begin
@test (@kernel SqExponentialKernel()) isa SqExponentialKernel
@test (@kernel 3.0 * SqExponentialKernel()) isa ScaledKernel{SqExponentialKernel,Float64}
@test (@kernel 3.0 * SqExponentialKernel() l=3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64}
@test (@kernel 3.0 * SqExponentialKernel() 3.0) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ScaleTransform{Float64}},Float64}
@test (@kernel 3.0 * SqExponentialKernel() l=[3.0]) isa ScaledKernel{TransformedKernel{SqExponentialKernel,ARDTransform{Vector{Float64}}},Float64}
@test (@kernel 3.0 * SqExponentialKernel() LinearTransform(rand(3,2))) isa ScaledKernel{TransformedKernel{SqExponentialKernel,LinearTransform{Array{Float64,2}}},Float64}
@test (@kernel (3.0 * SqExponentialKernel() + 5.0 * Matern32Kernel()) 3.0) isa TransformedKernel{KernelSum,ScaleTransform{Float64}}
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ using KernelFunctions: metric, kappa
@testset "kernels" begin
include(joinpath("kernels", "kernelproduct.jl"))
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "kernel_macro.jl"))
include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("kernels", "tensorproduct.jl"))
include(joinpath("kernels", "transformedkernel.jl"))
Expand Down