Skip to content

Commit 0526c5a

Browse files
committed
Changed the approach for the relaxation of validate_inputs
1 parent dfced70 commit 0526c5a

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

src/utils.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ function vec_of_vecs(X::AbstractMatrix; obsdim::Int = 2)
2020
end
2121
end
2222

23-
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
24-
dim(x::AbstractVector{Tuple{Any,Int}}) = 1
25-
2623
"""
2724
ColVecs(X::AbstractMatrix)
2825
@@ -94,6 +91,21 @@ For a transform return its parameters, for a `ChainTransform` return a vector of
9491
"""
9592
#params
9693

94+
dim(x) = 0
95+
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
96+
dim(x::AbstractVector{<:Real}) = 1
97+
dim(x::AbstractVector{Tuple{Any,Int}}) = 1
98+
99+
100+
function validate_inputs(x, y)
101+
if dim(x) != dim(y) # Passes by default if `dim` is not defined
102+
throw(DimensionMismatch(
103+
"Dimensionality of x ($(dim(x))) not equality to that of y ($(dim(y)))",
104+
))
105+
end
106+
return nothing
107+
end
108+
97109

98110
function validate_inplace_dims(K::AbstractMatrix, x::AbstractVector, y::AbstractVector)
99111
validate_inputs(x, y)
@@ -117,13 +129,3 @@ function validate_inplace_dims(K::AbstractVector, x::AbstractVector)
117129
))
118130
end
119131
end
120-
121-
validate_inputs(x, y) = nothing
122-
123-
function validate_inputs(x::V, y::V) where {V<:Union{RowVecs, ColVecs, AbstractVector{<:AbstractVector{<:Real}}}}
124-
if dim(x) != dim(y)
125-
throw(DimensionMismatch(
126-
"Dimensionality of x ($(dim(x))) not equality to that of y ($(dim(y)))",
127-
))
128-
end
129-
end

test/matrix/kernelmatrix.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ KernelFunctions.kappa(::ToySimpleKernel, d) = exp(-d / 2)
9090

9191
tmp_diag = Vector{Float64}(undef, length(x))
9292
@test kerneldiagmatrix!(tmp_diag, k, x) kerneldiagmatrix(k, x)
93-
@test_throws DimensionMismatch kerneldiagmatrix!(tmp_diag, k, y)
93+
@test_throws DimensionMismatch kerneldiagmatrix!(tmp_diag, k, y)
9494
end
9595
end
9696

0 commit comments

Comments
 (0)