Skip to content

Commit 2599a55

Browse files
committed
All needed components for kernel matrices
1 parent 16aba4b commit 2599a55

File tree

4 files changed

+96
-40
lines changed

4 files changed

+96
-40
lines changed

src/KernelFunctions.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
module KernelFunctions
22

3+
export kernelmatrix, kernelmatrix!, kappa
4+
export Kernel, SquaredExponentialKernel
5+
36
using Distances, LinearAlgebra
47

58
const defaultobs = 2
6-
abstract type Kernel{T} where {T<:Real} end
9+
abstract type Kernel{T<:Real} end
710

11+
include("utils.jl")
12+
include("common.jl")
813
include("kernelmatrix.jl")
9-
include("kernels/common.jl")
1014

11-
kernels = ("squaredexponential")
15+
kernels = ["squaredexponential"]
1216
for k in kernels
1317
include(joinpath("kernels",k*".jl"))
1418
end

src/common.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""Get method for the kernel metric"""
33
@inline metric::Kernel) = κ.metric
44
"""Apply functions of a kernel on a distance"""
5-
@inline::Kernel)(d::Real) = kappa(κ,d)
5+
# @inline (κ::K)(d::Real) where {K<:Kernel} = kappa(κ,d)
66

77

8-
@inline::Kernel)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate(κ.(metric),x,y))
8+
# @inline (κ::Kernel)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate(κ.(metric),x,y))

src/kernelmatrix.jl

Lines changed: 69 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,61 @@
1+
2+
function _kappamatrix!::Kernel{T}, P::AbstractMatrix{T}) where {T<:Real}
3+
for i in eachindex(P)
4+
@inbounds P[i] = kappa(κ, P[i])
5+
end
6+
P
7+
end
8+
9+
function _symmetric_kappamatrix!(
10+
κ::Kernel{T},
11+
P::AbstractMatrix{T},
12+
symmetrize::Bool
13+
) where {T<:Real}
14+
if !((n = size(P,1)) == size(P,2))
15+
throw(DimensionMismatch("Pairwise matrix must be square."))
16+
end
17+
for j = 1:n, i = (1:j)
18+
@inbounds P[i,j] = kappa(κ, P[i,j])
19+
end
20+
symmetrize ? LinearAlgebra.copytri!(P, 'U') : P
21+
end
22+
23+
124
"""
225
```
326
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer=2)
427
```
528
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
629
"""
730
function kernelmatrix!(
8-
K::Matrix{T},
31+
K::Matrix{T},
932
κ::Kernel{T},
10-
X::AbstractMatrix{T},
11-
Y::AbstractMatrix{T};
12-
obsdim::Integer = defaultobs
13-
) where {T<:Real}
14-
basematrix!(σ, K, basefunction(κ), κ.α, X, Y)
15-
kappamatrix!(κ, K)
33+
X::AbstractMatrix{T},
34+
Y::AbstractMatrix{T};
35+
obsdim::Int = defaultobs
36+
) where {T,T₁,T₂,T₃}
37+
#TODO Check dimension consistency
38+
_kappamatrix!(κ, pairwise!(K,metric(κ), X, Y, dims=obsdim))
1639
end
1740

18-
function kernelmatrix(
41+
42+
function kernelmatrix!(
43+
K::Matrix{T₁},
1944
κ::Kernel{T},
20-
X::AbstractMatrix{T};
45+
X::AbstractMatrix{T};
2146
obsdim::Int = defaultobs,
2247
symmetrize::Bool = true
23-
) where {T<:Real}
24-
return symmetric_kappamatrix!(κ,pairwise(basefunction(κ),X,dims=obsdim),symmetrize)
48+
) where {T,T₁<:Real,T₂<:Real}
49+
#TODO Check dimension consistency
50+
_symmetric_kappamatrix!(κ,pairwise!(metric(κ),X,dims=obsdim),symmetrize)
2551
end
2652

27-
function kernelmatrix(
28-
κ::Kernel{T},
29-
X::AbstractMatrix{T},
30-
Y::AbstractMatrix{T};
31-
obsdim::Int = defaultobs
32-
) where {T<:Real}
33-
kappamatrix!(κ, pairwise(basefunction(κ), X, Y, dims=obsdim))
34-
end
35-
36-
3753
# Convenience Methods ======================================================================
3854

3955
"""
40-
kernel(κ::Kernel, x, y)
41-
56+
```
57+
kernel(κ::Kernel, x, y; obsdim=2)
58+
```
4259
Apply the kernel `κ` to ``x`` and ``y`` where ``x`` and ``y`` are vectors or scalars of
4360
some subtype of ``Real``.
4461
"""
@@ -48,11 +65,12 @@ end
4865

4966
function kernel(
5067
κ::Kernel{T},
51-
x::AbstractArray{T1},
52-
y::AbstractArray{T2};
68+
x::AbstractArray{T₁},
69+
y::AbstractArray{T₂};
5370
obsdim::Int = defaultobs
54-
) where {T,T1<:Real,T2<:Real}
55-
kappamatrix!(κ, pairwise(metric(κ),X,Y,dims=obsdim))
71+
) where {T,T₁<:Real,T₂<:Real}
72+
# TODO Verify dimensions
73+
_kappamatrix!(κ, pairwise(metric(κ),X,Y,dims=obsdim))
5674
end
5775

5876
"""
@@ -63,23 +81,39 @@ Calculate the kernel matrix of `X` with respect to kernel `κ`.
6381
"""
6482
function kernelmatrix(
6583
κ::Kernel{T},
66-
X::AbstractMatrix{T1};
84+
X::AbstractMatrix{T₁};
6785
obsdim::Int = defaultobs,
6886
symmetrize::Bool = true
69-
) where {T,T1}
70-
return symmetric_kappamatrix!(κ,pairwise(basefunction(κ),X,dims=obsdim),symmetrize)
87+
) where {T,T₁<:Real}
88+
return _symmetric_kappamatrix!(κ,pairwise(metric(κ),X,dims=obsdim),symmetrize)
7189
end
7290

7391
"""
92+
```
7493
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int=2)
75-
94+
```
7695
Calculate the base matrix of `X` and `Y` with respect to kernel `κ`.
7796
"""
7897
function kernelmatrix(
7998
κ::Kernel{T},
80-
X::AbstractMatrix{T1},
81-
Y::AbstractMatrix{T2};
99+
X::AbstractMatrix{T₁},
100+
Y::AbstractMatrix{T₂};
82101
obsdim=defaultobs
83-
) where {T,T1,T2}
84-
kappamatrix!(κ, pairwise(basefunction(κ), X, Y, dims=dim(σ)))
102+
) where {T,T₁<:Real,T₂<:Real}
103+
_kappamatrix!(κ, pairwise(metric(κ), X, Y, dims=obsdim))
104+
end
105+
106+
107+
"""
108+
```
109+
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int=2)
110+
```
111+
Calculate the diagonal matrix of `X` with respect to kernel `κ`
112+
"""
113+
function kerneldiagmatrix(
114+
κ::Kernel{T},
115+
X::AbstractMatrix{T₁}
116+
) where {T,T₁,T₂}
117+
@error "Not implemented yet"
118+
#TODO
85119
end

src/utils.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
## Macro for checking
2+
macro check_args(K, param, cond, desc=string(cond))
3+
quote
4+
if !($(esc(cond)))
5+
throw(ArgumentError(string(
6+
$(string(K)), ": ", $(string(param)), " = ", $(esc(param)), " does not ",
7+
"satisfy the constraint ", $(string(desc)), ".")))
8+
end
9+
end
10+
end
11+
12+
function promote_float(Tₖ::DataType...)
13+
if length(Tₖ) == 0
14+
return Float64
15+
end
16+
T = promote_type(Tₖ...)
17+
return T <: Real ? T : Float64
18+
end

0 commit comments

Comments
 (0)