Skip to content

Commit fd4cfd7

Browse files
committed
Correct version of kernelamtrix
1 parent 32a1e68 commit fd4cfd7

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

src/matrix/kernelmatrix.jl

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,21 @@ end
3030

3131
function kernelmatrix!(
3232
K::AbstractMatrix,
33-
κ::BaseKernel,
33+
κ::Kernel,
3434
X::AbstractMatrix;
3535
obsdim::Int = defaultobs
3636
)
3737
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
3838
if obsdim == 1
39-
kernelmatrix!(K, κ, ColVecs(X))
40-
else
4139
kernelmatrix!(K, κ, ColVecs(X'))
40+
else
41+
kernelmatrix!(K, κ, ColVecs(X))
4242
end
4343
end
4444

4545
function kernelmatrix!(
4646
K::AbstractMatrix,
47-
κ::BaseKernel,
47+
κ::Kernel,
4848
X::AbstractVector
4949
)
5050
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
@@ -74,27 +74,27 @@ function kernelmatrix!(
7474
if !check_dims(K, X, Y, feature_dim(obsdim), obsdim)
7575
throw(DimensionMismatch("Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"))
7676
end
77-
map!(κ, K, pairwise(metric(κ), X, Y, dims = obsdim))
77+
map!(x -> kappa(κ, x), K, pairwise(metric(κ), X, Y, dims = obsdim))
7878
end
7979

8080
function kernelmatrix!(
8181
K::AbstractMatrix,
82-
κ::BaseKernel,
82+
κ::Kernel,
8383
X::AbstractMatrix,
8484
Y::AbstractMatrix;
8585
obsdim::Int = defaultobs
8686
)
8787
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
8888
if obsdim == 1
89-
@compat kernelmatrix!(K, κ, ColVecs(X), ColVecs(Y))
89+
kernelmatrix!(K, κ, ColVecs(X'), ColVecs(Y'))
9090
else
91-
@compat kernelmatrix!(K, κ, RowVecs(X), RowVecs(Y))
91+
kernelmatrix!(K, κ, ColVecs(X), ColVecs(Y))
9292
end
9393
end
9494

9595
function kernelmatrix!(
9696
K::AbstractMatrix,
97-
κ::BaseKernel,
97+
κ::Kernel,
9898
X::AbstractVector,
9999
Y::AbstractVector
100100
)
@@ -135,9 +135,9 @@ end
135135
function kernelmatrix::Kernel, X::AbstractMatrix; obsdim::Int = defaultobs)
136136
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
137137
if obsdim == 1
138-
kernelmatrix(κ, ColVecs(X))
138+
kernelmatrix(κ, ColVecs(X'))
139139
else
140-
kernelmatrix(κ, RowVecs(X))
140+
kernelmatrix(κ, ColVecs(X))
141141
end
142142
end
143143

@@ -181,9 +181,9 @@ function kerneldiagmatrix(
181181
)
182182
@assert obsdim [1,2] "obsdim should be 1 or 2 (see docs of kernelmatrix))"
183183
if obsdim == 1
184-
kerneldiagmatrix(κ, ColVecs(X))
185-
elseif obsdim == 2
186184
kerneldiagmatrix(κ, ColVecs(X'))
185+
else
186+
kerneldiagmatrix(κ, ColVecs(X))
187187
end
188188
end
189189

@@ -207,13 +207,9 @@ function kerneldiagmatrix!(
207207
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
208208
end
209209
if obsdim == 1
210-
for i in eachindex(K)
211-
@inbounds @views K[i] = κ(X[i,:], X[i,:])
212-
end
210+
kerneldiagmatrix!(K, κ, ColVecs(X'))
213211
else
214-
for i in eachindex(K)
215-
@inbounds @views K[i] = κ(X[:,i], X[:,i])
216-
end
212+
kerneldiagmatrix!(K, κ, ColVecs(X))
217213
end
218214
return K
219215
end

0 commit comments

Comments
 (0)