Skip to content

Commit dfced70

Browse files
committed
One more fix and tests on validate_inputs and validate_inplace_dims
1 parent 606b3dc commit dfced70

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

src/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ function vec_of_vecs(X::AbstractMatrix; obsdim::Int = 2)
2020
end
2121
end
2222

23-
dim(x::AbstractVector{<:Real}) = 1
23+
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
2424
dim(x::AbstractVector{Tuple{Any,Int}}) = 1
2525

2626
"""
@@ -96,7 +96,7 @@ For a transform return its parameters, for a `ChainTransform` return a vector of
9696

9797

9898
function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::AbstractVector)
99-
validate_dims(x, y)
99+
validate_inputs(x, y)
100100
if size(K) != (length(x), length(y))
101101
throw(DimensionMismatch(
102102
"Size of the target matrix K ($(size(K))) not consistent with lengths of " *
@@ -120,7 +120,7 @@ end
120120

121121
validate_inputs(x, y) = nothing
122122

123-
function validate_inputs(x::V, y::V) where {V<:Union{RowVecs, ColVecs}}
123+
function validate_inputs(x::V, y::V) where {V<:Union{RowVecs, ColVecs, AbstractVector{<:AbstractVector{<:Real}}}}
124124
if dim(x) != dim(y)
125125
throw(DimensionMismatch(
126126
"Dimensionality of x ($(dim(x))) not equality to that of y ($(dim(y)))",

test/utils.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,26 @@
7575
@test back(ones(size(X)))[1].X == ones(size(X))
7676
end
7777
end
78+
@testset "input checks" begin
79+
D = 3; D⁻ = 2
80+
N1 = 2; N2 = 3
81+
x = [rand(rng, D) for _ in 1:N1]
82+
x⁻ = [rand(rng, D⁻) for _ in 1:N1]
83+
y = [rand(rng, D) for _ in 1:N2]
84+
xx = [rand(rng, D, D) for _ in 1:N1]
85+
xx⁻ = [rand(rng, D, D⁻) for _ in 1:N1]
86+
yy = [rand(rng, D, D) for _ in 1:N2]
87+
@test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1, N2), x, y)
88+
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N1), x, y)
89+
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N2), x⁻, y)
90+
@test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1, N1), x)
91+
@test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1), x)
92+
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N2), x)
93+
94+
@test_nowarn KernelFunctions.validate_inputs(x, y)
95+
@test_throws DimensionMismatch KernelFunctions.validate_inputs(x⁻, y)
96+
97+
@test_nowarn KernelFunctions.validate_inputs(xx, yy)
98+
@test_nowarn KernelFunctions.validate_inputs(xx⁻, yy)
99+
end
78100
end

0 commit comments

Comments
 (0)