Skip to content

Commit 1b0a74e

Browse files
authored
Merge pull request #57 from theogf/syntacticsugarallkernels
Created concrete types to call syntactic sugar on all kernels
2 parents 5128791 + b8f7e82 commit 1b0a74e

File tree

8 files changed

+86
-65
lines changed

8 files changed

+86
-65
lines changed

src/generic.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@ printshifted(io::IO,κ::Kernel,shift::Int) = print(io,"$κ")
1515
Base.show(io::IO::Kernel) = print(io,nameof(typeof(κ)))
1616

1717
### Syntactic sugar for creating matrices and using kernel functions
18-
for k in subtypes(BaseKernel)
19-
if k [FBMKernel] continue end #for kernels without `metric` or `kappa`
18+
function concretetypes(k, ktypes::Vector)
19+
isempty(subtypes(k)) ? push!(ktypes, k) : concretetypes.(subtypes(k), Ref(ktypes))
20+
return ktypes
21+
end
22+
23+
for k in concretetypes(Kernel, [])
2024
@eval begin
21-
@inline::$k)(d::Real) = kappa(κ,d) #TODO Add test
2225
@inline::$k)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
2326
@inline::$k)(X::AbstractMatrix{T}, Y::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, Y, obsdim=obsdim)
2427
@inline::$k)(X::AbstractMatrix{T}; obsdim::Integer=defaultobs) where {T} = kernelmatrix(κ, X, obsdim=obsdim)

src/kernels/fbm.jl

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,30 @@ For `h=1/2`, this is the Wiener Kernel, for `h>1/2`, the increments are
1010
positively correlated and for `h<1/2` the increments are negatively correlated.
1111
"""
1212
struct FBMKernel{T<:Real} <: BaseKernel
13-
h::T
13+
h::Vector{T}
1414
function FBMKernel(; h::T=0.5) where {T<:Real}
15-
@assert h<=1.0 && h>=0.0 "FBMKernel: Given Hurst index h is invalid."
16-
return new{T}(h)
15+
@assert 0.0 <= h <= 1.0 "FBMKernel: Given Hurst index h is invalid."
16+
return new{T}([h])
1717
end
1818
end
1919

20-
Base.show(io::IO, κ::FBMKernel) = print(io, "Fractional Brownian Motion Kernel (h = $(k.h))")
20+
Base.show(io::IO, κ::FBMKernel) = print(io, "Fractional Brownian Motion Kernel (h = $(first(k.h)))")
21+
22+
const sqroundoff = 1e-15
2123

2224
_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h)/2
2325

2426
function kernelmatrix::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
2527
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
26-
modX = sum(abs2, X; dims = 3 - obsdim)
27-
modXX = pairwise(SqEuclidean(), X, dims = obsdim)
28+
modX = sum(abs2, X; dims = feature_dim(obsdim))
29+
modXX = pairwise(SqEuclidean(sqroundoff), X, dims = obsdim)
2830
return _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
2931
end
3032

3133
function kernelmatrix!(K::AbstractMatrix, κ::FBMKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
3234
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
33-
modX = sum(abs2, X; dims = 3 - obsdim)
34-
modXX = pairwise(SqEuclidean(), X, dims = obsdim)
35+
modX = sum(abs2, X; dims = feature_dim(obsdim))
36+
modXX = pairwise(SqEuclidean(sqroundoff), X, dims = obsdim)
3537
K .= _fbm.(vec(modX), reshape(modX, 1, :), modXX, κ.h)
3638
return K
3739
end
@@ -43,9 +45,9 @@ function kernelmatrix(
4345
obsdim::Int = defaultobs,
4446
)
4547
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
46-
modX = sum(abs2, X, dims=3-obsdim)
47-
modY = sum(abs2, Y, dims=3-obsdim)
48-
modXY = pairwise(SqEuclidean(), X, Y,dims=obsdim)
48+
modX = sum(abs2, X, dims = feature_dim(obsdim))
49+
modY = sum(abs2, Y, dims = feature_dim(obsdim))
50+
modXY = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
4951
return _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
5052
end
5153

@@ -57,9 +59,9 @@ function kernelmatrix!(
5759
obsdim::Int = defaultobs,
5860
)
5961
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
60-
modX = sum(abs2, X, dims=3-obsdim)
61-
modY = sum(abs2, Y, dims=3-obsdim)
62-
modXY = pairwise(SqEuclidean(), X, Y,dims=obsdim)
62+
modX = sum(abs2, X, dims = feature_dim(obsdim))
63+
modY = sum(abs2, Y, dims = feature_dim(obsdim))
64+
modXY = pairwise(SqEuclidean(sqroundoff), X, Y,dims = obsdim)
6365
K .= _fbm.(vec(modX), reshape(modY, 1, :), modXY, κ.h)
6466
return K
6567
end
@@ -72,23 +74,15 @@ function _kernel(
7274
obsdim::Int = defaultobs
7375
)
7476
@assert length(x) == length(y) "x and y don't have the same dimension!"
75-
return κ(x,y)
77+
return kappa(κ, x, y)
7678
end
7779

78-
#Syntactic Sugar
79-
function::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
80+
function kappa::FBMKernel, x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
8081
modX = sum(abs2, x)
8182
modY = sum(abs2, y)
82-
modXY = sqeuclidean(x, y)
83-
return (modX^κ.h + modY^κ.h - modXY^κ.h)/2
83+
modXY = evaluate(SqEuclidean(sqroundoff), x, y)
84+
h = first.h)
85+
return (modX^h + modY^h - modXY^h)/2
8486
end
8587

86-
::FBMKernel)(x::Real, y::Real) = (abs2(x)^κ.h + abs2(y)^κ.h - abs2(x-y)^κ.h)/2
87-
88-
function::FBMKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim::Integer=defaultobs)
89-
return kernelmatrix(κ, X, Y, obsdim=obsdim)
90-
end
91-
92-
function::FBMKernel)(X::AbstractMatrix{<:Real}; obsdim::Integer=defaultobs)
93-
return kernelmatrix(κ, X, obsdim=obsdim)
94-
end
88+
::FBMKernel)(x::Real, y::Real) = (abs2(x)^first.h) + abs2(y)^first.h) - abs2(x-y)^first.h))/2

src/matrix/kernelmatrix.jl

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
"""
2-
```
3-
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer=2)
4-
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer=2)
5-
```
6-
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
2+
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer = 2)
3+
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer = 2)
4+
5+
In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will be overwritten with the kernel matrix.
76
"""
87
kernelmatrix!
98

@@ -21,7 +20,7 @@ function kernelmatrix!(
2120
map!(x->kappa(κ,x),K,pairwise(metric(κ),X,dims=obsdim))
2221
end
2322

24-
kernelmatrix!(K::Matrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
23+
kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, X::AbstractMatrix; obsdim::Int = defaultobs) =
2524
kernelmatrix!(K, kernel(κ), apply.transform, X, obsdim = obsdim), obsdim = obsdim)
2625

2726
function kernelmatrix!(
@@ -61,13 +60,12 @@ _kernel(κ::TransformedKernel, x::AbstractVector, y::AbstractVector; obsdim::Int
6160
_kernel(kernel(κ), apply.transform, x), apply.transform, y), obsdim = obsdim)
6261

6362
"""
64-
```
65-
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2)
66-
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int=2)
67-
```
63+
kernelmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
64+
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int = 2)
65+
6866
Calculate the kernel matrix of `X` (and `Y`) with respect to kernel `κ`.
69-
`obsdim=1` means the matrix `X` (and `Y`) has size #samples x #dimension
70-
`obsdim=2` means the matrix `X` (and `Y`) has size #dimension x #samples
67+
`obsdim = 1` means the matrix `X` (and `Y`) has size #samples x #dimension
68+
`obsdim = 2` means the matrix `X` (and `Y`) has size #dimension x #samples
7169
"""
7270
kernelmatrix
7371

@@ -109,12 +107,11 @@ kernelmatrix(κ::TransformedKernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim
109107
kernelmatrix(kernel(κ), apply.transform, X, obsdim = obsdim), apply.transform, Y, obsdim = obsdim), obsdim = obsdim)
110108

111109
"""
112-
```
113-
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int=2)
114-
```
110+
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int = 2)
111+
115112
Calculate the diagonal matrix of `X` with respect to kernel `κ`
116-
`obsdim=1` means the matrix `X` has size #samples x #dimension
117-
`obsdim=2` means the matrix `X` has size #dimension x #samples
113+
`obsdim = 1` means the matrix `X` has size #samples x #dimension
114+
`obsdim = 2` means the matrix `X` has size #dimension x #samples
118115
"""
119116
function kerneldiagmatrix(
120117
κ::Kernel,
@@ -130,10 +127,9 @@ function kerneldiagmatrix(
130127
end
131128

132129
"""
133-
```
134-
kerneldiagmatrix!(K::AbstractVector,κ::Kernel, X::Matrix; obsdim::Int=2)
135-
```
136-
In place version of `kerneldiagmatrix`
130+
kerneldiagmatrix!(K::AbstractVector,κ::Kernel, X::Matrix; obsdim::Int = 2)
131+
132+
In place version of [`kerneldiagmatrix`](@ref)
137133
"""
138134
function kerneldiagmatrix!(
139135
K::AbstractVector,

src/trainable.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import .Flux.trainable
44

55
trainable(k::ConstantKernel) = (k.c,)
66

7+
trainable(k::FBMKernel) = (k.h,)
8+
79
trainable(k::GammaExponentialKernel) = (k.γ,)
810

911
trainable(k::GammaRationalQuadraticKernel) = (k.α, k.γ)

test/kernels/custom.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ KernelFunctions.kappa(::MyKernel, d2::Real) = exp(-d2)
55
KernelFunctions.metric(::MyKernel) = SqEuclidean()
66

77
# some syntactic sugar
8-
::MyKernel)(d::Real) = kappa(κ, d)
98
::MyKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real}) = kappa(κ, x, y)
109
::MyKernel)(X::AbstractMatrix{<:Real}, Y::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X, Y; obsdim = obsdim)
1110
::MyKernel)(X::AbstractMatrix{<:Real}; obsdim = 2) = kernelmatrix(κ, X; obsdim = obsdim)
@@ -17,7 +16,6 @@ KernelFunctions.metric(::MyKernel) = SqEuclidean()
1716
@test kernelmatrix(MyKernel(), [1 2; 3 4], [5 6; 7 8]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4], [5 6; 7 8])
1817
@test kernelmatrix(MyKernel(), [1 2; 3 4]) == kernelmatrix(SqExponentialKernel(), [1 2; 3 4])
1918

20-
@test MyKernel()(3) == SqExponentialKernel()(3)
2119
@test MyKernel()([1, 2], [3, 4]) == SqExponentialKernel()([1, 2], [3, 4])
2220
@test MyKernel()([1 2; 3 4], [5 6; 7 8]) == SqExponentialKernel()([1 2; 3 4], [5 6; 7 8])
2321
@test MyKernel()([1 2; 3 4]) == SqExponentialKernel()([1 2; 3 4])

test/kernels/fbm.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@testset "FBM" begin
2+
h = 0.3
3+
k = FBMKernel(h = h)
4+
v1 = rand(3); v2 = rand(3)
5+
@test k(v1,v2) (sqeuclidean(v1, zero(v1))^h + sqeuclidean(v2, zero(v2))^h - sqeuclidean(v1-v2, zero(v1-v2))^h)/2 atol=1e-5
6+
7+
# kernelmatrix tests
8+
m1 = rand(3,3)
9+
m2 = rand(3,3)
10+
@test kernelmatrix(k, m1, m1) kernelmatrix(k, m1) atol=1e-5
11+
@test kernelmatrix(k, m1, m2) k(m1, m2) atol=1e-5
12+
13+
14+
x1 = rand()
15+
x2 = rand()
16+
@test kernelmatrix(k, x1*ones(1,1), x2*ones(1,1))[1] k(x1, x2) atol=1e-5
17+
end

test/runtests.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,17 +62,18 @@ using KernelFunctions: metric
6262
end
6363

6464
@testset "kernels" begin
65+
include(joinpath("kernels", "constant.jl"))
66+
include(joinpath("kernels", "cosine.jl"))
6567
include(joinpath("kernels", "exponential.jl"))
68+
include(joinpath("kernels", "exponentiated.jl"))
69+
include(joinpath("kernels", "fbm.jl"))
70+
include(joinpath("kernels", "kernelproduct.jl"))
71+
include(joinpath("kernels", "kernelsum.jl"))
6672
include(joinpath("kernels", "matern.jl"))
6773
include(joinpath("kernels", "polynomial.jl"))
68-
include(joinpath("kernels", "constant.jl"))
6974
include(joinpath("kernels", "rationalquad.jl"))
70-
include(joinpath("kernels", "exponentiated.jl"))
71-
include(joinpath("kernels", "cosine.jl"))
72-
include(joinpath("kernels", "transformedkernel.jl"))
7375
include(joinpath("kernels", "scaledkernel.jl"))
74-
include(joinpath("kernels", "kernelsum.jl"))
75-
include(joinpath("kernels", "kernelproduct.jl"))
76+
include(joinpath("kernels", "transformedkernel.jl"))
7677

7778
# Legacy tests that don't correspond to anything meaningful in src. Unclear how
7879
# helpful these are.

test/trainable.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
@testset "trainable" begin
2-
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5
2+
ν = 2.0; c = 3.0; d = 2.0; γ = 2.0; α = 2.5; h = 0.5
3+
34
kc = ConstantKernel(c=c)
45
@test all(params(kc) .== params([c]))
5-
km = MaternKernel=ν)
6-
@test all(params(km) .== params([ν]))
7-
kl = LinearKernel(c=c)
8-
@test all(params(kl) .== params([c]))
6+
7+
kfbm = FBMKernel(h = h)
8+
@test all(params(kfbm) .== params([h]))
9+
910
kge = GammaExponentialKernel=γ)
1011
@test all(params(kge) .== params([γ]))
12+
1113
kgr = GammaRationalQuadraticKernel=γ, α=α)
1214
@test all(params(kgr) .== params([α], [γ]))
15+
16+
kl = LinearKernel(c=c)
17+
@test all(params(kl) .== params([c]))
18+
19+
km = MaternKernel=ν)
20+
@test all(params(km) .== params([ν]))
21+
1322
kp = PolynomialKernel(c=c, d=d)
1423
@test all(params(kp) .== params([d], [c]))
24+
1525
kr = RationalQuadraticKernel=α)
1626
@test all(params(kr) .== params([α]))
1727

0 commit comments

Comments
 (0)