Skip to content

Commit cbb024a

Browse files
committed
Added tests for the new methods. And corrected bugs
1 parent fd4cfd7 commit cbb024a

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

src/matrix/kernelmatrix.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ function kernelmatrix!(
4747
κ::Kernel,
4848
X::AbstractVector
4949
)
50-
if !check_dims(K, X, X, feature_dim(obsdim), obsdim)
50+
if (size(K, 1) != size(K, 2)) || (length(X) != size(K, 1))
5151
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
5252
end
53-
map!(κ, K, X, X')
53+
K .= κ.(X, X')
5454
end
5555

5656
## Wrapper for vector of reals
@@ -98,7 +98,10 @@ function kernelmatrix!(
9898
X::AbstractVector,
9999
Y::AbstractVector
100100
)
101-
map!(K, κ, X, Y')
101+
if (size(K, 1) != length(X)) || (size(K, 2) != length(Y))
102+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X)) and Y $(size(Y))"))
103+
end
104+
K .= κ.(X, Y')
102105
end
103106

104107
"""
@@ -162,6 +165,15 @@ function kernelmatrix(
162165
_kernelmatrix(κ, X, Y, obsdim)
163166
end
164167

168+
function kernelmatrix::Kernel, X::AbstractMatrix, Y::AbstractMatrix; obsdim::Int = defaultobs)
169+
@assert obsdim [1, 2] "obsdim should be 1 or 2 (see docs of `kernelmatrix`))"
170+
if obsdim == 1
171+
kernelmatrix(κ, ColVecs(X'), ColVecs(Y'))
172+
else
173+
kernelmatrix(κ, ColVecs(X), ColVecs(Y))
174+
end
175+
end
176+
165177
@inline _kernelmatrix::SimpleKernel, X, Y, obsdim) =
166178
map(x -> kappa(κ, x), pairwise(metric(κ), X, Y, dims = obsdim))
167179

src/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ end
2525

2626
Base.size(D::ColVecs) = (size(D.X, 2),)
2727
Base.getindex(D::ColVecs, i::Int) = view(D.X, :, i)
28+
Base.getindex(D::ColVecs, i::CartesianIndex{1}) = view(D.X, :, i)
2829
Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i))
2930

3031
# Take highest Float among possibilities

test/matrix/kernelmatrix.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,50 @@
22

33
rng = MersenneTwister(123456)
44
dims = [10,5]
5-
6-
A = rand(rng, dims...)
7-
B = rand(rng, dims...)
5+
vA = [rand(rng, dims[1]) for _ in 1:dims[2]]
6+
A = hcat(vA...)
7+
vB = [rand(rng, dims[1]) for _ in 1:dims[2]]
8+
B = hcat(vB...)
9+
x = rand(rng, dims[1])
10+
X = collect(reshape(x, 1, :))
11+
y = rand(rng, dims[2])
12+
Y = collect(reshape(y, 1 , :))
13+
KX = zeros(dims[1], dims[1])
14+
KXY = zeros(dims[1], dims[2])
815
C = rand(rng, 8, 9)
916
K = [zeros(dims[1],dims[1]),zeros(dims[2],dims[2])]
1017
Kdiag = [zeros(dims[1]),zeros(dims[2])]
1118
s = rand(rng)
1219
k = SqExponentialKernel()
20+
struct baseSE <: KernelFunctions.BaseKernel end
21+
(k::baseSE)(x, y) = exp(-evaluate(SqEuclidean(), x, y))
22+
newk = baseSE()
1323
kt = transform(SqExponentialKernel(),s)
1424

1525
@testset "Kernel Matrix Operations" begin
1626
@testset "Inplace Kernel Matrix" begin
27+
@test kernelmatrix!(KX, k, x) kernelmatrix!(KX, k, X)
28+
@test kernelmatrix!(KXY, k, x, y) kernelmatrix!(KXY, k, X, Y)
29+
@test kernelmatrix!(K[2], k, vA) kernelmatrix(k, A) atol = 1e-5
30+
@test kernelmatrix!(K[2], k, vA, vB) kernelmatrix(k, A, B) atol = 1e-5
1731
for obsdim in [1,2]
32+
@show obsdim
1833
@test kernelmatrix!(K[obsdim], k, A, B, obsdim = obsdim) == kernelmatrix(k, A, B, obsdim = obsdim)
1934
@test kernelmatrix!(K[obsdim], k, A, obsdim = obsdim) == kernelmatrix(k, A, obsdim = obsdim)
2035
@test kerneldiagmatrix!(Kdiag[obsdim], k, A, obsdim = obsdim) == kerneldiagmatrix(k, A, obsdim = obsdim)
2136
@test_throws DimensionMismatch kernelmatrix!(K[obsdim], k, A, C, obsdim=obsdim)
2237
@test_throws DimensionMismatch kernelmatrix!(K[obsdim], k, C, obsdim=obsdim)
2338
@test_throws DimensionMismatch kerneldiagmatrix!(Kdiag[obsdim], k, C, obsdim=obsdim)
39+
@test kernelmatrix!(K[obsdim], newk, A, B, obsdim = obsdim) kernelmatrix(k, A, B, obsdim = obsdim)
40+
@test kernelmatrix!(K[obsdim], newk, A, obsdim = obsdim) kernelmatrix(k, A, obsdim = obsdim)
41+
@test kerneldiagmatrix!(Kdiag[obsdim], newk, A, obsdim = obsdim) kerneldiagmatrix(k, A, obsdim = obsdim)
2442
end
2543
end
2644
@testset "Kernel matrix" begin
45+
@test kernelmatrix(k, x) kernelmatrix(k, X)
46+
@test kernelmatrix(k, x, y) kernelmatrix(k, X, Y)
47+
@test kernelmatrix(k, vA) kernelmatrix(k, A) atol = 1e-5
48+
@test kernelmatrix(k, vA, vB) kernelmatrix(k, A, B) atol = 1e-5
2749
for obsdim in [1,2]
2850
@test kernelmatrix(k,A,B,obsdim=obsdim) == kappa.(k,pairwise(KernelFunctions.metric(k),A,B,dims=obsdim))
2951
@test kernelmatrix(k,A,obsdim=obsdim) == kappa.(k,pairwise(KernelFunctions.metric(k),A,dims=obsdim))
@@ -32,6 +54,9 @@
3254
@test k(A,obsdim=obsdim) == kernelmatrix(k,A,obsdim=obsdim)
3355
# @test KernelFunctions._kernel(k,1.0,2.0) == KernelFunctions._kernel(k,[1.0],[2.0])
3456
@test_throws DimensionMismatch kernelmatrix(k,A,C,obsdim=obsdim)
57+
@test kernelmatrix!(K[obsdim], newk, A, B, obsdim = obsdim) kernelmatrix(k, A, B, obsdim = obsdim)
58+
@test kernelmatrix!(K[obsdim], newk, A, obsdim = obsdim) kernelmatrix(k, A, obsdim = obsdim)
59+
@test kerneldiagmatrix!(Kdiag[obsdim], newk, A, obsdim = obsdim) kerneldiagmatrix(k, A, obsdim = obsdim)
3560
end
3661
end
3762
@testset "Transformed Kernel Matrix Operations" begin

0 commit comments

Comments
 (0)