Skip to content

Commit 764fcd7

Browse files
committed
Implementation of all methods
1 parent 1cb4868 commit 764fcd7

File tree

1 file changed

+33
-6
lines changed

1 file changed

+33
-6
lines changed

src/matrix/kernelmatrix.jl

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@ In-place version of [`kernelmatrix`](@ref) where pre-allocated matrix `K` will b
66
"""
77
kernelmatrix!
88

9+
## Wrapper for vector of reals
10+
function kernelmatrix!(
11+
K::AbstractMatrix,
12+
κ::Kernel,
13+
X::AbstractVector{<:Real}
14+
)
15+
kernelmatrix!(K, κ, reshape(X, 1, :), obsdim = 2)
16+
end
17+
918
function kernelmatrix!(
1019
K::AbstractMatrix,
1120
κ::SimpleKernel,
@@ -27,9 +36,9 @@ function kernelmatrix!(
2736
)
2837
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
2938
if obsdim == 1
30-
@compat kernelmatrix!(K, κ, ColVecs(X))
39+
kernelmatrix!(K, κ, ColVecs(X))
3140
else
32-
@compat kernelmatrix!(K, κ, RowVecs(X))
41+
kernelmatrix!(K, κ, ColVecs(X'))
3342
end
3443
end
3544

@@ -44,6 +53,16 @@ function kernelmatrix!(
4453
map!(κ, K, X, X')
4554
end
4655

56+
## Wrapper for vector of reals
57+
function kernelmatrix!(
58+
K::AbstractMatrix,
59+
κ::Kernel,
60+
X::AbstractVector{<:Real},
61+
Y::AbstractVector{<:Real}
62+
)
63+
kernelmatrix!(K, κ, reshape(X, 1, :), reshape(Y, 1, :), obsdim = 2)
64+
end
65+
4766
function kernelmatrix!(
4867
K::AbstractMatrix,
4968
κ::SimpleKernel,
@@ -108,8 +127,6 @@ function kernelmatrix(κ::Kernel, X::AbstractVector, Y::AbstractVector)
108127
κ.(X, Y')
109128
end
110129

111-
112-
113130
function kernelmatrix::SimpleKernel, X::AbstractMatrix; obsdim::Int = defaultobs)
114131
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
115132
K = map(x -> kappa(κ, x), pairwise(metric(κ), X, dims = obsdim))
@@ -124,6 +141,14 @@ function kernelmatrix(κ::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
124141
end
125142
end
126143

144+
function kernelmatrix(
145+
κ::Kernel,
146+
X::AbstractVector{<:Real},
147+
Y::AbstractMatrix{<:Real}
148+
)
149+
kernelmatrix(κ, reshape(X, 1, :), reshape(Y, 1, :), obsdim = 2)
150+
end
151+
127152
function kernelmatrix(
128153
κ::SimpleKernel,
129154
X::AbstractMatrix,
@@ -147,16 +172,18 @@ Calculate the diagonal matrix of `X` with respect to kernel `κ`
147172
`obsdim = 1` means the matrix `X` has size #samples x #dimension
148173
`obsdim = 2` means the matrix `X` has size #dimension x #samples
149174
"""
175+
kerneldiagmatrix
176+
150177
function kerneldiagmatrix(
151178
κ::Kernel,
152179
X::AbstractMatrix;
153180
obsdim::Int = defaultobs
154181
)
155182
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
156183
if obsdim == 1
157-
@compat kerneldiagmatrix(κ, ColVecs(X)) #[@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
184+
kerneldiagmatrix(κ, ColVecs(X)) #[@views _kernel(κ,X[i,:],X[i,:]) for i in 1:size(X,obsdim)]
158185
elseif obsdim == 2
159-
@compat kerneldiagmatrix(κ, RowVecs(X)) #[@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
186+
kerneldiagmatrix(κ, ColVecs(X')) #[@views _kernel(κ,X[:,i],X[:,i]) for i in 1:size(X,obsdim)]
160187
end
161188
end
162189

0 commit comments

Comments
 (0)