Skip to content

Commit 6f467a3

Browse files
committed
Solved kernel matrix construction issues
1 parent b070c02 commit 6f467a3

File tree

4 files changed

+20
-18
lines changed

4 files changed

+20
-18
lines changed

src/KernelFunctions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ export MaternKernel, Matern32Kernel, Matern52Kernel
77
export LinearKernel, PolynomialKernel
88
export ConstantKernel, WhiteKernel, ZeroKernel
99

10-
export Transform, ScaleTransform
10+
1111

1212
using Distances, LinearAlgebra
1313
using Zygote: @adjoint
@@ -16,7 +16,7 @@ using StatsFuns: logtwo
1616

1717
const defaultobs = 2
1818

19-
include("zygote_rules.jl")
19+
# include("zygote_rules.jl")
2020
include("utils.jl")
2121
include("distances/dotproduct.jl")
2222
include("distances/delta.jl")

src/kernelmatrix.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@ function kernelmatrix!(
88
K::Matrix{T₁},
99
κ::Kernel{T},
1010
X::AbstractMatrix{T₂};
11-
obsdim::Int = defaultobs,
12-
symmetrize::Bool = true
11+
obsdim::Int = defaultobs
1312
) where {T,T₁<:Real,T₂<:Real}
14-
@assert check_dims(K,X,X,obsdim) "Dimensions of the target array are not consistent with X and Y"
13+
@assert check_dims(K,X,X,feature_dim(obsdim),obsdim) "Dimensions of the target array are not consistent with X and Y"
1514
map!(x->kappa(κ,x),K,pairwise(metric(κ),transform(κ,X,obsdim),dims=obsdim))
1615
end
1716

@@ -28,16 +27,15 @@ function kernelmatrix!(
2827
Y::AbstractMatrix{T₃};
2928
obsdim::Int = defaultobs
3029
) where {T,T₁,T₂,T₃}
31-
@assert check_dims(K,X,Y,obsdim) "Dimensions of the target array are not consistent with X and Y"
30+
@assert check_dims(K,X,Y,feature_dim(obsdim),obsdim) "Dimensions $(size(K)) of the target array K are not consistent with X ($(size(X))) and Y ($(size(Y)))"
3231
map!(x->kappa(κ,x),K,pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
3332
end
3433

3534
"""
3635
```
3736
kernel(κ::Kernel, x, y; obsdim=2)
3837
```
39-
Apply the kernel `κ` to ``x`` and ``y`` where ``x`` and ``y`` are vectors or scalars of
40-
some subtype of ``Real``.
38+
Apply the kernel `κ` to `x` and `y`.
4139
"""
4240
function kernel::Kernel{T}, x::Real, y::Real) where {T}
4341
kernel(κ, [T(x)], [T(y)])
@@ -58,6 +56,8 @@ end
5856
kernelmatrix(κ::Kernel, X::Matrix ; obsdim::Int=2, symmetrize::Bool=true)
5957
```
6058
Calculate the kernel matrix of `X` with respect to kernel `κ`.
59+
`obsdim=1` means the matrix `X` has size #samples x #dimension
60+
`obsdim=2` means the matrix `X` has size #dimension x #samples
6161
"""
6262
function kernelmatrix(
6363
κ::Kernel{T,<:Transform},
@@ -73,13 +73,16 @@ end
7373
kernelmatrix(κ::Kernel, X::Matrix, Y::Matrix; obsdim::Int=2)
7474
```
7575
Calculate the base matrix of `X` and `Y` with respect to kernel `κ`.
76+
`obsdim=1` means the matrices `X` and `Y` have sizes #samples x #dimension
77+
`obsdim=2` means the matrices `X` and `Y` have size #dimension x #samples
7678
"""
7779
function kernelmatrix(
7880
κ::Kernel{T},
7981
X::AbstractMatrix{T₁},
8082
Y::AbstractMatrix{T₂};
8183
obsdim=defaultobs
8284
) where {T,T₁<:Real,T₂<:Real}
85+
@assert check_dims(X,Y,feature_dim(obsdim),obsdim) "X ($(size(X))) and Y ($(size(Y))) do not have the same number of features on the dimension obsdim : $(feature_dim(obsdim))"
8386
K = map(x->kappa(κ,x),pairwise(metric(κ),transform(κ,X,obsdim),transform(κ,Y,obsdim),dims=obsdim))
8487
return K
8588
end
@@ -89,6 +92,8 @@ end
8992
kerneldiagmatrix(κ::Kernel, X::Matrix; obsdim::Int=2)
9093
```
9194
Calculate the diagonal matrix of `X` with respect to kernel `κ`
95+
`obsdim=1` means the matrix `X` has size #samples x #dimension
96+
`obsdim=2` means the matrix `X` has size #dimension x #samples
9297
"""
9398
function kerneldiagmatrix(
9499
κ::Kernel{T},

src/utils.jl

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,9 @@ function promote_float(Tₖ::DataType...)
1919
return T <: Real ? T : Float64
2020
end
2121

22-
function check_dims(K,X,Y,obsdim)
23-
if size(X,obsdim) == size(Y,obsdim)
24-
if obsdim == 1
25-
return size(K) == (size(X,1),size(Y,1))
26-
elseif obsdim == 2
27-
return size(K) == (size(X,2),size(Y,2))
28-
end
29-
end
30-
return false
31-
end
22+
check_dims(K,X,Y,featdim,obsdim) = check_dims(X,Y,featdim,obsdim) && (size(K) == (size(X,obsdim),size(Y,obsdim)))
23+
24+
check_dims(X,Y,featdim,obsdim) = size(X,featdim) == size(Y,featdim)
25+
26+
27+
feature_dim(obsdim::Int) = obsdim == 1 ? 2 : 1

test/kernelmatrix.jl renamed to test/test_kernelmatrix.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using Distances, LinearAlgebra
2+
using Test
23
using KernelFunctions
34

45
dims = [10,5]

0 commit comments

Comments
 (0)