Skip to content

Commit c3554df

Browse files
committed
Cleaned up the kernelmatrix file
1 parent 3b3eb63 commit c3554df

File tree

3 files changed

+41
-74
lines changed

3 files changed

+41
-74
lines changed
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11

2-
"""Get method for the kernel metric"""
32
@inline metric::Kernel) = κ.metric
4-
"""Apply functions of a kernel on a distance"""
5-
# @inline (κ::K)(d::Real) where {K<:Kernel} = kappa(κ,d)
3+
@inline::K)(d::Real) where {K<:Kernel} = kappa(κ,d)
4+
5+
### Transform generics
66

77
@inline transform::Kernel) = κ.transform
88
@inline transform::Kernel,x::AbstractVecOrMat) = transform.transform,x)
99
@inline transform::Kernel,x::AbstractVecOrMat,obsdim::Int) = transform.transform,x,obsdim)
1010

11-
# @inline (κ::Kernel)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kappa(κ,evaluate(κ.(metric),x,y))
11+
@inline::Kernel)(x::AbstractVector{<:Real},y::AbstractVector{<:Real}) = kernel(κ,evaluate(κ.(metric),x,y))

src/kernelmatrix.jl

Lines changed: 26 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,62 +1,37 @@
1-
2-
function _kappamatrix!::Kernel, P::AbstractMatrix{T₁}) where {T₁<:Real}
3-
for i in eachindex(P)
4-
@inbounds P[i] = kappa(κ, P[i])
5-
end
6-
P
7-
end
8-
9-
function _symmetric_kappamatrix!(
10-
κ::Kernel,
11-
P::AbstractMatrix{T₁},
12-
symmetrize::Bool
13-
) where {T₁<:Real}
14-
if !((n = size(P,1)) == size(P,2))
15-
throw(DimensionMismatch("Pairwise matrix must be square."))
16-
end
17-
for j = 1:n, i = (1:j)
18-
@inbounds P[i,j] = kappa(κ, P[i,j])
19-
end
20-
symmetrize ? LinearAlgebra.copytri!(P, 'U') : P
21-
end
22-
23-
241
"""
252
```
26-
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer=2)
3+
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer=2, symmetrize::Bool=true)
274
```
285
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
296
"""
307
function kernelmatrix!(
31-
K::AbstractMatrix{T₁},
8+
K::Matrix{T₁},
329
κ::Kernel{T},
33-
X::AbstractMatrix{T₂},
34-
Y::AbstractMatrix{T₃};
35-
obsdim::Int = defaultobs
36-
) where {T,T₁,T₂,T₃}
37-
#TODO Check dimension consistency
38-
_kappamatrix!(κ, pairwise!(K, metric(κ), X, Y, dims=obsdim))
10+
X::AbstractMatrix{T₂};
11+
obsdim::Int = defaultobs,
12+
symmetrize::Bool = true
13+
) where {T,T₁<:Real,T₂<:Real}
14+
@assert check_dims(K,X,X,obsdim) "Dimensions of the target array are not consistent with X and Y"
15+
map!(K,x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
3916
end
4017

4118
"""
4219
```
43-
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix; obsdim::Integer=2, symmetrize::Bool=true)
20+
kernelmatrix!(K::Matrix, κ::Kernel, X::Matrix, Y::Matrix; obsdim::Integer=2)
4421
```
4522
In-place version of `kernelmatrix` where pre-allocated matrix `K` will be overwritten with the kernel matrix.
4623
"""
4724
function kernelmatrix!(
48-
K::Matrix{T₁},
25+
K::AbstractMatrix{T₁},
4926
κ::Kernel{T},
50-
X::AbstractMatrix{T₂};
51-
obsdim::Int = defaultobs,
52-
symmetrize::Bool = true
53-
) where {T,T₁<:Real,T₂<:Real}
54-
#TODO Check dimension consistency
55-
_symmetric_kappamatrix!(κ,pairwise!(K, metric(κ), X, dims=obsdim), symmetrize)
27+
X::AbstractMatrix{T₂},
28+
Y::AbstractMatrix{T₃};
29+
obsdim::Int = defaultobs
30+
) where {T,T₁,T₂,T₃}
31+
@assert check_dims(K,X,Y,obsdim) "Dimensions of the target array are not consistent with X and Y"
32+
map!(K,x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
5633
end
5734

58-
# Convenience Methods ======================================================================
59-
6035
"""
6136
```
6237
kernel(κ::Kernel, x, y; obsdim=2)
@@ -65,7 +40,7 @@ Apply the kernel `κ` to ``x`` and ``y`` where ``x`` and ``y`` are vectors or sc
6540
some subtype of ``Real``.
6641
"""
6742
function kernel::Kernel{T}, x::Real, y::Real) where {T}
68-
kernel(κ, T(x), T(y))
43+
kernel(κ, [T(x)], [T(y)])
6944
end
7045

7146
function kernel(
@@ -74,7 +49,7 @@ function kernel(
7449
y::AbstractArray{T₂};
7550
obsdim::Int = defaultobs
7651
) where {T,T₁<:Real,T₂<:Real}
77-
# TODO Verify dimensions
52+
@assert length(x) == length(y) "x and y don't have the same dimension!"
7853
kappa(κ, evaluate(metric(κ),transform(κ,x),transform(κ,y)))
7954
end
8055

@@ -83,51 +58,32 @@ end
8358
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool=true)
8459
```
8560
Calculate the kernel matrix of `X` with respect to kernel `κ`.
86-
# USED
8761
"""
8862
function kernelmatrix(
8963
κ::Kernel{T,<:Transform},
9064
X::AbstractMatrix;
9165
obsdim::Int = defaultobs,
9266
symmetrize::Bool = true
93-
) where {T,A}
94-
# Tₖ = typeof(zero(eltype(X))*zero(T))
95-
# m = size(X,obsdim)
96-
#WARNING TEMP FIX
97-
# X̂ = transform(κ,X,obsdim)
98-
# K = map(x->kappa(κ,x),pairwise(metric(κ),X̂,X̂,dims=obsdim))
67+
) where {T}
9968
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
100-
return K
10169
end
10270

10371
"""
10472
```
10573
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int=2)
10674
```
10775
Calculate the base matrix of `X` and `Y` with respect to kernel `κ`.
108-
# USED
10976
"""
11077
function kernelmatrix(
11178
κ::Kernel{T},
11279
X::AbstractMatrix{T₁},
11380
Y::AbstractMatrix{T₂};
11481
obsdim=defaultobs
11582
) where {T,T₁<:Real,T₂<:Real}
116-
# Tₖ = typeof(zero(eltype(X))*zero(T))
117-
# m = size(X,obsdim)
11883
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
119-
# K = Matrix{Tₖ}(undef,m,m)
120-
# for i in 1:m
121-
# tx = transform(κ,@view X[i,:])
122-
# for j in 1:i
123-
# K[i,j] = kappa(κ,kernel(κ,tx,transform(@view X[j,:])))
124-
# end
125-
# end
12684
return K
127-
# return kernelmatrix!(Matrix{Tₖ}(undef,m,m),κ,X,obsdim=obsdim,symmetrize=symmetrize)
12885
end
12986

130-
13187
"""
13288
```
13389
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int=2)
@@ -137,20 +93,20 @@ Calculate the diagonal matrix of `X` with respect to kernel `κ`
13793
function kerneldiagmatrix(
13894
κ::Kernel{T},
13995
X::AbstractMatrix{T₁};
140-
obsdim::Int = 2
96+
obsdim::Int = defaultobs
14197
) where {T,T₁}
142-
n = size(X,obsdim)
143-
Tₖ = typeof(zero(T)*zero(eltype(X)))
144-
K = Vector{Tₖ}(undef,n)
145-
kerneldiagmatrix!(K,κ,X,obsdim=obsdim)
146-
return K
98+
if obsdim == 1
99+
[@views kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
100+
elseif obsdim == 2
101+
[@views kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
102+
end
147103
end
148104

149105
function kerneldiagmatrix!(
150106
K::AbstractVector{T₁},
151107
κ::Kernel{T},
152108
X::AbstractMatrix{T₂};
153-
obsdim::Int = 2
109+
obsdim::Int = defaultobs
154110
) where {T,T₁,T₂}
155111
if obsdim == 1
156112
for i in eachindex(K)

src/utils.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,14 @@ function promote_float(Tₖ::DataType...)
1818
T = promote_type(Tₖ...)
1919
return T <: Real ? T : Float64
2020
end
21+
22+
function check_dims(K,X,Y,obsdim)
23+
if size(X,obsdim) == size(Y,obsdim)
24+
if obsdim == 1
25+
return size(K) == (size(X,2),size(Y,2))
26+
elseif obsdim == 2
27+
return size(K) == (size(X,1),size(Y,1))
28+
end
29+
end
30+
return false
31+
end

0 commit comments

Comments
 (0)