Skip to content

Commit 7640550

Browse files
committed
Applied suggested corrections
1 parent d117f25 commit 7640550

File tree

2 files changed

+19
-9
lines changed

2 files changed

+19
-9
lines changed

src/matrix/kernelmatrix.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function kernelmatrix!(
1212
κ::Kernel,
1313
X::AbstractVector{<:Real}
1414
)
15-
kernelmatrix!(K, κ, reshape(X, 1, :), obsdim = 2)
15+
kernelmatrix!(K, κ, ColVecs(reshape(X, 1, :)))
1616
end
1717

1818
function kernelmatrix!(
@@ -47,7 +47,7 @@ function kernelmatrix!(
4747
κ::Kernel,
4848
X::AbstractVector
4949
)
50-
if (size(K, 1) != size(K, 2)) || (length(X) != size(K, 1))
50+
if !check_dims(K, X, X)
5151
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X))"))
5252
end
5353
K .= κ.(X, X')
@@ -60,7 +60,7 @@ function kernelmatrix!(
6060
X::AbstractVector{<:Real},
6161
Y::AbstractVector{<:Real}
6262
)
63-
kernelmatrix!(K, κ, reshape(X, 1, :), reshape(Y, 1, :), obsdim = 2)
63+
kernelmatrix!(K, κ, ColVecs(reshape(X, 1, :)), ColVecs(reshape(Y, 1, :)))
6464
end
6565

6666
function kernelmatrix!(
@@ -98,7 +98,7 @@ function kernelmatrix!(
9898
X::AbstractVector,
9999
Y::AbstractVector
100100
)
101-
if (size(K, 1) != length(X)) || (size(K, 2) != length(Y))
101+
if !check_dims(K, X, Y)
102102
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(size(X)) and Y $(size(Y))"))
103103
end
104104
K .= κ.(X, Y')
@@ -147,9 +147,9 @@ end
147147
function kernelmatrix(
148148
κ::Kernel,
149149
X::AbstractVector{<:Real},
150-
Y::AbstractMatrix{<:Real}
150+
Y::AbstractVector{<:Real}
151151
)
152-
kernelmatrix(κ, reshape(X, 1, :), reshape(Y, 1, :), obsdim = 2)
152+
kernelmatrix(κ, ColVecs(reshape(X, 1, :)), ColVecs(reshape(Y, 1, :)))
153153
end
154154

155155
function kernelmatrix(
@@ -231,5 +231,8 @@ function kerneldiagmatrix!(
231231
κ::Kernel,
232232
X::AbstractVector
233233
)
234+
if length(K) != length(X)
235+
throw(DimensionMismatch("Dimensions of the target array K $(size(K)) are not consistent with X $(length(X))"))
236+
end
234237
map!(κ, K, X, X)
235238
end

src/utils.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,18 @@ Base.getindex(D::ColVecs, i) = ColVecs(view(D.X, :, i))
3737
# return T <: Real ? T : Float64
3838
# end
3939

40-
check_dims(K, X, Y, featdim, obsdim) =
41-
check_dims(X, Y, featdim, obsdim) &&
40+
function check_dims(K, X::AbstractVector, Y::AbstractVector)
41+
size(K) == (length(X), length(Y))
42+
end
43+
44+
45+
## Won't be needed with full ColVecs implementation
46+
function check_dims(K, X::AbstractMatrix, Y::AbstractMatrix, featdim, obsdim)
47+
check_dims(X, Y, featdim) &&
4248
(size(K) == (size(X, obsdim), size(Y, obsdim)))
49+
end
4350

44-
check_dims(X, Y, featdim, obsdim) = size(X, featdim) == size(Y, featdim)
51+
check_dims(X::AbstractMatrix, Y::AbstractMatrix, featdim) = size(X, featdim) == size(Y, featdim)
4552

4653

4754
feature_dim(obsdim::Int) = obsdim == 1 ? 2 : 1

0 commit comments

Comments
 (0)