Skip to content

Commit 589daff

Browse files
Extension of #269: Use \circ and compose and deprecate transform (#276)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: kaandocal <>
1 parent 55f4909 commit 589daff

20 files changed

+122
-101
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "KernelFunctions"
22
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
3-
version = "0.9.1"
3+
version = "0.9.2"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
8+
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
89
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
910
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1011
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -20,6 +21,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2021
[compat]
2122
ChainRulesCore = "0.9"
2223
Compat = "3.7"
24+
CompositionsBase = "0.1"
2325
Distances = "0.10"
2426
Functors = "0.1"
2527
Requires = "1.0.1"

docs/create_kernel_plots.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ n_grid = 101
1313
fill(x₀, n_grid, 1)
1414
xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1)
1515

16-
k = transform(SqExponentialKernel(), 1.0)
16+
k = SqExponentialKernel() ScaleTransform(1.0)
1717
K1 = kernelmatrix(k, xrange; obsdim=1)
1818
p = heatmap(
1919
K1;
@@ -35,7 +35,7 @@ p = heatmap(
3535
)
3636
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png"))
3737

38-
k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1)))
38+
k = PolynomialKernel(; c=0.0, d=2.0) LinearTransform(randn(3, 1))
3939
K3 = kernelmatrix(k, xrange; obsdim=1)
4040
p = heatmap(
4141
K3;
@@ -47,7 +47,7 @@ p = heatmap(
4747
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png"))
4848

4949
k =
50-
0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) +
50+
0.5 * SqExponentialKernel() * (LinearKernel() ScaleTransform(0.5)) +
5151
0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x)))
5252
K4 = kernelmatrix(k, xrange; obsdim=1)
5353
p = heatmap(

docs/src/kernels.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,7 @@ of kernels together.
118118

119119
```@docs
120120
TransformedKernel
121-
transform(::Kernel, ::Transform)
122-
transform(::Kernel, ::Real)
123-
transform(::Kernel, ::AbstractVector)
121+
∘(::Kernel, ::Transform)
124122
ScaledKernel
125123
KernelSum
126124
KernelProduct

docs/src/transform.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ LowRankTransform(rand(10, 5)) ∘ ScaleTransform(2.0)
1818
A transformation `t` can be applied to a single input `x` with `t(x)` and to multiple inputs
1919
`xs` with `map(t, xs)`.
2020

21-
Kernels can be coupled with input transformations with
22-
[`transform`](@ref). It falls back to creating a [`TransformedKernel`](@ref) but allows more
21+
Kernels can be coupled with input transformations with [``](@ref) or its alias `compose`. It falls
22+
back to creating a [`TransformedKernel`](@ref) but allows more
2323
optimized implementations for specific kernels and transformations.
2424

2525
## List of Input Transforms

src/KernelFunctions.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ export MOInput
3939
export IndependentMOKernel, LatentFactorMOKernel
4040

4141
# Reexports
42-
export tensor,
42+
export tensor, , compose
4343

4444
using Compat
4545
using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS
4646
using ChainRulesCore: @thunk, InplaceableThunk
47+
using CompositionsBase
4748
using Requires
4849
using Distances, LinearAlgebra
4950
using Functors
@@ -106,6 +107,8 @@ include("zygoterules.jl")
106107

107108
include("test_utils.jl")
108109

110+
include("deprecations.jl")
111+
109112
function __init__()
110113
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
111114
include(joinpath("matrix", "kernelkroneckermat.jl"))

src/basekernels/gabor.jl

Lines changed: 15 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@ k(x, x'; l, p) = \\exp\\bigg(- \\cos\\bigg(\\pi\\sum_{i=1}^d \\frac{x_i - x'_i}{
1414
"""
1515
struct GaborKernel{K<:Kernel} <: Kernel
1616
kernel::K
17+
1718
function GaborKernel(; ell=nothing, p=nothing)
18-
k = _gabor(; ell=ell, p=p)
19+
ell_transform = _lengthscale_transform(ell)
20+
p_transform = _lengthscale_transform(p)
21+
k = (SqExponentialKernel() ell_transform) * (CosineKernel() p_transform)
1922
return new{typeof(k)}(k)
2023
end
2124
end
@@ -24,38 +27,23 @@ end
2427

2528
::GaborKernel)(x, y) = κ.kernel(x, y)
2629

27-
function _gabor(; ell=nothing, p=nothing)
28-
if ell === nothing
29-
if p === nothing
30-
return SqExponentialKernel() * CosineKernel()
31-
else
32-
return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p)
33-
end
34-
elseif p === nothing
35-
return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel()
36-
else
37-
return transform(SqExponentialKernel(), 1 ./ ell) *
38-
transform(CosineKernel(), 1 ./ p)
39-
end
40-
end
30+
_lengthscale_transform(::Nothing) = IdentityTransform()
31+
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
32+
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))
33+
34+
_lengthscale(::IdentityTransform) = 1
35+
_lengthscale(t::ScaleTransform) = inv(first(t.s))
36+
_lengthscale(t::ARDTransform) = map(inv, t.v)
4137

4238
function Base.getproperty(k::GaborKernel, v::Symbol)
4339
if v == :kernel
4440
return getfield(k, v)
4541
elseif v == :ell
46-
kernel1 = k.kernel.kernels[1]
47-
if kernel1 isa TransformedKernel
48-
return 1 ./ kernel1.transform.s[1]
49-
else
50-
return 1.0
51-
end
42+
ell_transform = k.kernel.kernels[1].transform
43+
return _lengthscale(ell_transform)
5244
elseif v == :p
53-
kernel2 = k.kernel.kernels[2]
54-
if kernel2 isa TransformedKernel
55-
return 1 ./ kernel2.transform.s[1]
56-
else
57-
return 1.0
58-
end
45+
p_transform = k.kernel.kernels[2].transform
46+
return _lengthscale(p_transform)
5947
else
6048
error("Invalid Property")
6149
end

src/deprecations.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
@deprecate transform(k::Kernel, t::Transform) k t
2+
@deprecate transform(k::TransformedKernel, t::Transform) k.kernel t k.transform
3+
@deprecate transform(k::Kernel, ρ::Real) k ScaleTransform(ρ)
4+
@deprecate transform(k::Kernel, ρ::AbstractVector) k ARDTransform(ρ)

src/kernels/transformedkernel.jl

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,11 @@
33
44
Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.
55
6-
It is preferred to create kernels with input transformations with [`transform`](@ref)
7-
instead of `TransformedKernel` directly since [`transform`](@ref) allows optimized
8-
implementations for specific kernels and transformations.
6+
The preferred way to create kernels with input transformations is to use the composition
7+
operator [`∘`](@ref) or its alias `compose` instead of `TransformedKernel` directly since
8+
this allows optimized implementations for specific kernels and transformations.
99
10-
# Definition
11-
12-
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by
13-
input transformation ``t`` is defined as
14-
```math
15-
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big).
16-
```
10+
See also: [`∘`](@ref)
1711
"""
1812
struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
1913
kernel::Tk
@@ -42,30 +36,37 @@ end
4236
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
4337

4438
"""
45-
transform(k::Kernel, t::Transform)
39+
kernel ∘ transform
40+
∘(kernel, transform)
41+
compose(kernel, transform)
4642
47-
Create a [`TransformedKernel`](@ref) for kernel `k` and transform `t`.
48-
"""
49-
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)
50-
function transform(k::TransformedKernel, t::Transform)
51-
return TransformedKernel(k.kernel, t k.transform)
52-
end
43+
Compose a `kernel` with a transformation `transform` of its inputs.
5344
54-
"""
55-
transform(k::Kernel, ρ::Real)
45+
The prefix forms support chains of multiple transformations:
46+
`∘(kernel, transform1, transform2) = kernel ∘ transform1 ∘ transform2`.
5647
57-
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscale `ρ`.
58-
"""
59-
transform(k::Kernel, ρ::Real) = transform(k, ScaleTransform(ρ))
48+
# Definition
6049
61-
"""
62-
transform(k::Kernel, ρ::AbstractVector)
50+
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by
51+
input transformation ``t`` is defined as
52+
```math
53+
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big).
54+
```
6355
64-
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscales `ρ`.
65-
"""
66-
transform(k::Kernel, ρ::AbstractVector) = transform(k, ARDTransform(ρ))
56+
# Examples
57+
58+
```jldoctest
59+
julia> (SqExponentialKernel() ∘ ScaleTransform(0.5))(0, 2) == exp(-0.5)
60+
true
6761
68-
kernel(κ) = κ.kernel
62+
julia> ∘(ExponentialKernel(), ScaleTransform(2), ScaleTransform(0.5))(1, 2) == exp(-1)
63+
true
64+
```
65+
66+
See also: [`TransformedKernel`](@ref)
67+
"""
68+
Base.:(k::Kernel, t::Transform) = TransformedKernel(k, t)
69+
Base.:(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform t)
6970

7071
Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)
7172

@@ -87,13 +88,13 @@ function kernelmatrix_diag!(
8788
end
8889

8990
function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
90-
return kernelmatrix!(K, kernel(κ), _map.transform, x))
91+
return kernelmatrix!(K, κ.kernel, _map.transform, x))
9192
end
9293

9394
function kernelmatrix!(
9495
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
9596
)
96-
return kernelmatrix!(K, kernel(κ), _map.transform, x), _map.transform, y))
97+
return kernelmatrix!(K, κ.kernel, _map.transform, x), _map.transform, y))
9798
end
9899

99100
function kernelmatrix_diag::TransformedKernel, x::AbstractVector)
@@ -105,9 +106,9 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::Abstract
105106
end
106107

107108
function kernelmatrix::TransformedKernel, x::AbstractVector)
108-
return kernelmatrix(kernel(κ), _map.transform, x))
109+
return kernelmatrix(κ.kernel, _map.transform, x))
109110
end
110111

111112
function kernelmatrix::TransformedKernel, x::AbstractVector, y::AbstractVector)
112-
return kernelmatrix(kernel(κ), _map.transform, x), _map.transform, y))
113+
return kernelmatrix(κ.kernel, _map.transform, x), _map.transform, y))
113114
end

test/basekernels/gabor.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
1111
@test k(v1, v2) k_manual atol = 1e-5
1212

13-
lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2)
14-
rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2)
13+
lhs_manual = (SqExponentialKernel() ScaleTransform(1 / k.ell))(v1, v2)
14+
rhs_manual = (CosineKernel() ScaleTransform(1 / k.p))(v1, v2)
1515
@test k(v1, v2) lhs_manual * rhs_manual atol = 1e-5
1616

1717
k = GaborKernel()
1818
@test k.ell 1.0 atol = 1e-5
1919
@test k.p 1.0 atol = 1e-5
20-
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
20+
@test repr(k) == "Gabor Kernel (ell = 1, p = 1)"
2121

2222
test_interface(k, Vector{Float64})
2323

test/deprecations.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
@testset "deprecations.jl" begin
2+
p = rand()
3+
v = rand(3)
4+
M = rand(3, 3)
5+
v1 = rand(3)
6+
v2 = rand(3)
7+
kernel = SqExponentialKernel()
8+
9+
k1 = @test_deprecated transform(kernel, LinearTransform(M))
10+
@test k1(v1, v2) == (kernel LinearTransform(M))(v1, v2)
11+
12+
k2 = @test_deprecated transform(kernel ScaleTransform(p), ARDTransform(v))
13+
@test k2(v1, v2) == (kernel ARDTransform(v) ScaleTransform(p))(v1, v2)
14+
15+
k3 = @test_deprecated transform(kernel, p)
16+
@test k3(v1, v2) == (kernel ScaleTransform(p))(v1, v2)
17+
18+
k4 = @test_deprecated transform(kernel, v)
19+
@test k4(v1, v2) == (kernel ARDTransform(v))(v1, v2)
20+
end

0 commit comments

Comments
 (0)