Skip to content

Commit 73d68b8

Browse files
committed
WIP : Refactoring with Transform
1 parent ef39ba7 commit 73d68b8

File tree

6 files changed

+119
-39
lines changed

6 files changed

+119
-39
lines changed

src/KernelFunctions.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@ module KernelFunctions
33
export kernelmatrix, kernelmatrix!, kappa
44
export Kernel, SquaredExponentialKernel
55

6+
export Transform, ScaleTransform
7+
68
using Distances, LinearAlgebra
79

810
const defaultobs = 2
9-
abstract type Kernel{T<:Real} end
11+
abstract type Kernel{T,Tr} end
1012

1113
include("utils.jl")
1214
include("common.jl")
15+
include("transform/transform.jl")
1316
include("kernelmatrix.jl")
1417

1518
kernels = ["squaredexponential"]

src/common.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,8 @@
44
"""Apply functions of a kernel on a distance"""
55
# @inline (κ::K)(d::Real) where {K<:Kernel} = kappa(κ,d)
66

7+
@inline transform::Kernel) = κ.transform
8+
@inline transform::Kernel,x::AbstractVecOrMat) = transform.transform,x)
9+
@inline transform::Kernel,x::AbstractVecOrMat,obsdim::Int) = transform.transform,x,obsdim)
710

811
# @inline (κ::Kernel)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate(κ.(metric),x,y))

src/kernelmatrix.jl

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11

2-
function _kappamatrix!::Kernel{T}, P::AbstractMatrix{T₁}) where {T<:Real,T₁<:Real}
2+
function _kappamatrix!::Kernel, P::AbstractMatrix{T₁}) where {T₁<:Real}
33
for i in eachindex(P)
44
@inbounds P[i] = kappa(κ, P[i])
55
end
66
P
77
end
88

99
function _symmetric_kappamatrix!(
10-
κ::Kernel{T},
10+
κ::Kernel,
1111
P::AbstractMatrix{T₁},
1212
symmetrize::Bool
13-
) where {T<:Real,T<:Real}
13+
) where {T₁<:Real}
1414
if !((n = size(P,1)) == size(P,2))
1515
throw(DimensionMismatch("Pairwise matrix must be square."))
1616
end
@@ -38,7 +38,12 @@ function kernelmatrix!(
3838
_kappamatrix!(κ, pairwise!(K, metric(κ), X, Y, dims=obsdim))
3939
end
4040

41-
41+
"""
42+
```
43+
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer=2, symmetrize::Bool=true)
44+
```
45+
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
46+
"""
4247
function kernelmatrix!(
4348
K::Matrix{T₁},
4449
κ::Kernel{T},
@@ -70,24 +75,33 @@ function kernel(
7075
obsdim::Int = defaultobs
7176
) where {T,T₁<:Real,T₂<:Real}
7277
# TODO Verify dimensions
73-
kappa(κ, evaluate(metric(κ),x,y))
78+
kappa(κ, evaluate(metric(κ),transform(κ,x),transform(κ,y)))
7479
end
7580

7681
"""
7782
```
78-
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool)
83+
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool=true)
7984
```
8085
Calculate the kernel matrix of `X` with respect to kernel `κ`.
8186
"""
8287
function kernelmatrix(
83-
κ::Kernel{T},
84-
X::AbstractMatrix{T₁};
88+
κ::Kernel{T,<:Transform{A}},
89+
X::AbstractMatrix;
8590
obsdim::Int = defaultobs,
8691
symmetrize::Bool = true
87-
) where {T,T₁<:Real}
88-
Tₛ = typeof(zero(eltype(X))*zero(T))
89-
m = size(X,obsdim)
90-
return kernelmatrix!(Matrix{promote_float(T,T₁)}(undef,m,m),κ,X,obsdim=obsdim,symmetrize=symmetrize)
92+
) where {T,A}
93+
# Tₖ = typeof(zero(eltype(X))*zero(T))
94+
# m = size(X,obsdim)
95+
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
96+
# K = Matrix{Tₖ}(undef,m,m)
97+
# for i in 1:m
98+
# tx = transform(κ,@view X[i,:])
99+
# for j in 1:i
100+
# K[i,j] = kappa(κ,kernel(κ,tx,transform(@view X[j,:])))
101+
# end
102+
# end
103+
return K
104+
# return kernelmatrix!(Matrix{Tₖ}(undef,m,m),κ,X,obsdim=obsdim,symmetrize=symmetrize)
91105
end
92106

93107
"""
@@ -102,10 +116,10 @@ function kernelmatrix(
102116
Y::AbstractMatrix{T₂};
103117
obsdim=defaultobs
104118
) where {T,T₁<:Real,T₂<:Real}
105-
Tₛ = typeof(zero(eltype(X))*zero(eltype(Y))*zero(T))
119+
Tₖ = typeof(zero(eltype(X))*zero(eltype(Y))*zero(T))
106120
m = size(X,obsdim)
107121
n = size(Y,obsdim)
108-
kernelmatrix!(Matrix{Tₛ}(undef,m,n),κ,X,Y,obsdim=obsdim)
122+
kernelmatrix!(Matrix{Tₖ}(undef,m,n),κ,X,Y,obsdim=obsdim)
109123
end
110124

111125

@@ -117,8 +131,30 @@ Calculate the diagonal matrix of `X` with respect to kernel `κ`
117131
"""
118132
function kerneldiagmatrix(
119133
κ::Kernel{T},
120-
X::AbstractMatrix{T₁}
134+
X::AbstractMatrix{T₁};
135+
obsdim::Int = 2
136+
) where {T,T₁}
137+
n = size(X,obsdim)
138+
Tₖ = typeof(zero(T)*zero(eltype(X)))
139+
K = Vector{Tₖ}(undef,n)
140+
kerneldiagmatrix!(K,κ,X,obsdim=obsdim)
141+
return K
142+
end
143+
144+
function kerneldiagmatrix!(
145+
K::AbstractVector{T₁},
146+
κ::Kernel{T},
147+
X::AbstractMatrix{T₂};
148+
obsdim::Int = 2
121149
) where {T,T₁,T₂}
122-
@error "Not implemented yet"
123-
#TODO
150+
if obsdim == 1
151+
for i in eachindex(K)
152+
@inbounds @views K[i] = kernel(κ, X[i,:],X[i,:])
153+
end
154+
else
155+
for i in eachindex(K)
156+
@inbounds @views K[i] = kernel(κ,X[:,i],X[:,i])
157+
end
158+
end
159+
return K
124160
end

src/kernels/squaredexponential.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
@doc raw"""
1+
"""
22
SquaredExponentialKernel([α=1])
33
44
The squared exponential kernel is an isotropic Mercer kernel given by the formula:
55
66
```
7-
κ(x,y) = exp(α‖x-y‖²) α > 0
7+
κ(x,y) = exp(-‖x-y‖²)
88
```
99
1010
where `α` is a positive scaling parameter. See also [`ExponentialKernel`](@ref) for a
@@ -23,29 +23,31 @@ julia> SquaredExponentialKernel([2.0,3.0])
2323
SquaredExponentialKernel{Float64,Array{Float64}}(1.0)
2424
```
2525
"""
26-
struct SquaredExponentialKernel{T<:Real,A} <: Kernel{T}
27-
α::A
26+
struct SquaredExponentialKernel{T,Tr<:Transform} <: Kernel{T,Tr}
27+
transform::Tr
2828
metric::SemiMetric
29-
function SquaredExponentialKernel{T}::A=T(1)) where {A<:Union{Real,AbstractVector{<:Real}},T<:Real}
30-
@check_args(SquaredExponentialKernel, α, all.> zero(T)), "α > 0")
31-
if A <: Real
32-
return new{eltype(A),A}(α,SqEuclidean())
33-
else
34-
return new{eltype(A),A}(α,WeightedSqEuclidean(α))
35-
end
29+
function SquaredExponentialKernel{T,Tr}(transform::Tr) where {T,Tr<:Transform}
30+
return new{T,Tr}(transform,SqEuclidean())
3631
end
3732
end
3833

39-
function SquaredExponentialKernel::Union{T,AbstractVector{T}}=1.0) where {T<:Real}
40-
SquaredExponentialKernel{promote_float(T)})
34+
function SquaredExponentialKernel::T=1.0) where {T<:Real}
35+
SquaredExponentialKernel{T,ScaleTransform{T}}(ScaleTransform(α))
4136
end
4237

43-
@inline kappa::SquaredExponentialKernel{T,<:Real}, d²::Real) where {T} = exp(-κ.α*d²)
44-
@inline kappa::SquaredExponentialKernel{T}, d²::Real) where {T} = exp(-d²)
38+
function SquaredExponentialKernel::A) where {A<:AbstractVector{<:Real}}
39+
SquaredExponentialKernel{eltype(A),ScaleTransform{A}}(ScaleTransform(α))
40+
end
4541

46-
function convert(
47-
::Type{K},
48-
κ::SquaredExponentialKernel
49-
) where {K>:SquaredExponentialKernel{T,A} where {T,A}}
50-
return SquaredExponentialKernel{T}(T.(κ.α))
42+
function SquaredExponentialKernel(t::T) where {T<:Transform}
43+
SquaredExponentialKernel{eltype(t),T}(t)
5144
end
45+
46+
@inline kappa::SquaredExponentialKernel, d²::Real) where {T} = exp(-d²)
47+
48+
# function convert(
49+
# ::Type{K},
50+
# κ::SquaredExponentialKernel
51+
# ) where {K>:SquaredExponentialKernel{T,A} where {T,A}}
52+
# return SquaredExponentialKernel{T}(T.(κ.α))
53+
# end

src/transform/transform.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
abstract type Transform{T} end
2+
3+
struct TransformChain{T} <: Transform{T}
4+
end
5+
6+
7+
8+
struct InputTransform{T} <: Transform{T}
9+
10+
end
11+
12+
struct ScaleTransform{T<:Union{Real,AbstractVector{<:Real}}} <: Transform{T}
13+
s::T
14+
end
15+
16+
17+
function ScaleTransform(s::T=1.0) where {T<:Real}
18+
@check_args(ScaleTransform, s, s > zero(T), "s > 0")
19+
ScaleTransform{T}(s)
20+
end
21+
22+
function ScaleTransform(s::T,dims::Integer) where {T<:Real}
23+
@check_args(ScaleTransform, s, s > zero(T), "s > 0")
24+
ScaleTransform{Vector{T}}(fill(s,dims))
25+
end
26+
27+
function ScaleTransform(s::A) where {A<:AbstractVector{<:Real}}
28+
@check_args(ScaleTransform, s, all(s.>zero(eltype(A))), "s > 0")
29+
ScaleTransform{A}(s)
30+
end
31+
32+
transform(t::ScaleTransform{<:AbstractVector{<:Real}},x::AbstractVector{<:Real}) = t.s.*x
33+
transform(t::ScaleTransform{<:AbstractVector{<:Real}},X::AbstractMatrix{<:Real},obsdim::Int) = obsdim == 1 ? t.s'.*X : t.s.*X
34+
35+
transform(t::ScaleTransform{<:Real},x::AbstractVecOrMat) = t.s*x

test/constructors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ vl = [l,l]
77
## SquaredExponentialKernel
88
@testset "SquaredExponentialKernel" begin
99
@test KernelFunctions.metric(SquaredExponentialKernel(l)) == SqEuclidean()
10-
@test KernelFunctions.metric(SquaredExponentialKernel(vl)) == WeightedSqEuclidean(vl)
10+
@test KernelFunctions.transform(SquaredExponentialKernel(l)) == ScaleTransform(l)
11+
@test KernelFunctions.transform(SquaredExponentialKernel(vl)) == ScaleTransform(vl)
1112
end
1213

1314
SquaredExponentialKernel(l)

0 commit comments

Comments
 (0)