Skip to content

Commit 8994bc9

Browse files
committed
First type and functions
1 parent c56a328 commit 8994bc9

File tree

5 files changed

+161
-0
lines changed

5 files changed

+161
-0
lines changed

src/KernelFunctions.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module KernelFunctions
2+
3+
using Distances, LinearAlgebra
4+
5+
const defaultobs = 2
6+
abstract type Kernel{T} where {T<:Real} end
7+
8+
include("kernelmatrix.jl")
9+
include("kernels/common.jl")
10+
11+
kernels = ("squaredexponential")
12+
for k in kernels
13+
include(joinpath("kernels",k*".jl"))
14+
end
15+
16+
end

src/common.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
2+
"""Get method for the kernel metric"""
3+
@inline metric::Kernel) = κ.metric
4+
"""Apply functions of a kernel on a distance"""
5+
@inline::Kernel)(d::Real) = kappa(κ,d)
6+
7+
8+
@inline::Kernel)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate(κ.(metric),x,y))

src/kernelmatrix.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""
2+
```
3+
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer=2)
4+
```
5+
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
6+
"""
7+
function kernelmatrix!(
8+
K::Matrix{T},
9+
κ::Kernel{T},
10+
X::AbstractMatrix{T},
11+
Y::AbstractMatrix{T};
12+
obsdim::Integer = defaultobs
13+
) where {T<:Real}
14+
basematrix!(σ, K, basefunction(κ), κ.α, X, Y)
15+
kappamatrix!(κ, K)
16+
end
17+
18+
function kernelmatrix(
19+
κ::Kernel{T},
20+
X::AbstractMatrix{T};
21+
obsdim::Int = defaultobs,
22+
symmetrize::Bool = true
23+
) where {T<:Real}
24+
return symmetric_kappamatrix!(κ,pairwise(basefunction(κ),X,dims=obsdim),symmetrize)
25+
end
26+
27+
function kernelmatrix(
28+
κ::Kernel{T},
29+
X::AbstractMatrix{T},
30+
Y::AbstractMatrix{T};
31+
obsdim::Int = defaultobs
32+
) where {T<:Real}
33+
kappamatrix!(κ, pairwise(basefunction(κ), X, Y, dims=obsdim))
34+
end
35+
36+
37+
# Convenience Methods ======================================================================
38+
39+
"""
40+
kernel(κ::Kernel, x, y)
41+
42+
Apply the kernel `κ` to ``x`` and ``y`` where ``x`` and ``y`` are vectors or scalars of
43+
some subtype of ``Real``.
44+
"""
45+
function kernel::Kernel{T}, x::Real, y::Real) where {T}
46+
kernel(κ, T(x), T(y))
47+
end
48+
49+
function kernel(
50+
κ::Kernel{T},
51+
x::AbstractArray{T1},
52+
y::AbstractArray{T2};
53+
obsdim::Int = defaultobs
54+
) where {T,T1<:Real,T2<:Real}
55+
kappamatrix!(κ, pairwise(metric(κ),X,Y,dims=obsdim))
56+
end
57+
58+
"""
59+
```
60+
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool)
61+
```
62+
Calculate the kernel matrix of `X` with respect to kernel `κ`.
63+
"""
64+
function kernelmatrix(
65+
κ::Kernel{T},
66+
X::AbstractMatrix{T1};
67+
obsdim::Int = defaultobs,
68+
symmetrize::Bool = true
69+
) where {T,T1}
70+
return symmetric_kappamatrix!(κ,pairwise(basefunction(κ),X,dims=obsdim),symmetrize)
71+
end
72+
73+
"""
74+
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int=2)
75+
76+
Calculate the base matrix of `X` and `Y` with respect to kernel `κ`.
77+
"""
78+
function kernelmatrix(
79+
κ::Kernel{T},
80+
X::AbstractMatrix{T1},
81+
Y::AbstractMatrix{T2};
82+
obsdim=defaultobs
83+
) where {T,T1,T2}
84+
kappamatrix!(κ, pairwise(basefunction(κ), X, Y, dims=dim(σ)))
85+
end

src/kernels/squaredexponential.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
@doc raw"""
2+
SquaredExponentialKernel([α=1])
3+
4+
The squared exponential kernel is an isotropic Mercer kernel given by the formula:
5+
6+
```
7+
κ(x,y) = exp(α‖x-y‖²) α > 0
8+
```
9+
10+
where `α` is a positive scaling parameter. See also [`ExponentialKernel`](@ref) for a
11+
related form of the kernel or [`GammaExponentialKernel`](@ref) for a generalization.
12+
13+
# Examples
14+
15+
```jldoctest; setup = :(using MLKernels)
16+
julia> SquaredExponentialKernel()
17+
SquaredExponentialKernel{Float64}(1.0)
18+
19+
julia> SquaredExponentialKernel(2.0f0)
20+
SquaredExponentialKernel{Float32}(2.0)
21+
```
22+
"""
23+
struct SquaredExponentialKernel{T<:Real,A} <: Kernel{T}
24+
α::A
25+
metric::SemiMetric
26+
function SquaredExponentialKernel{T}::A=T(1)) where {A<:Union{Real,AbstractVector{<:Real}},T<:Real}
27+
@check_args(SquaredExponentialKernel, α, all.> zero(T)), "α > 0")
28+
if A <: Real
29+
return new{eltype(A),A}(α,SqEuclidean())
30+
else
31+
return new{eltype(A),A}(α,WeightedSqEuclidean(α))
32+
end
33+
end
34+
end
35+
36+
function SquaredExponentialKernel::Union{T,AbstractVector{T}}=1.0) where {T<:Real}
37+
SquaredExponentialKernel{promote_float(T)}(α)
38+
end
39+
40+
@inline kappa::SquaredExponentialKernel{T,<:Real}, d²::T) where {T} = exp(-κ.α*d²)
41+
@inline kappa::SquaredExponentialKernel{T}, d²::T) where {T} = exp(-d²)
42+
43+
function convert(
44+
::Type{K},
45+
κ::SquaredExponentialKernel
46+
) where {K>:SquaredExponentialKernel{T,A} where {T,A}}
47+
return SquaredExponentialKernel{T}(T.(κ.α))
48+
end

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
using Test
2+
using KernelFunctions
3+
4+
@test 1+1 == 2

0 commit comments

Comments
 (0)