Skip to content

Commit 2315bfe

Browse files
committed
Reworked dim and added tests for it
1 parent d2b5a06 commit 2315bfe

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,9 @@ For a transform return its parameters, for a `ChainTransform` return a vector of
9292
#params
9393

9494
dim(x) = 0 # This is the passes-by-default choice. For a proper check, implement `KernelFunctions.dim` for your datatype.
95+
dim(x::AbstractVector) = dim(first(x))
9596
dim(x::AbstractVector{<:AbstractVector{<:Real}}) = length(first(x))
9697
dim(x::AbstractVector{<:Real}) = 1
97-
dim(x::AbstractVector{Tuple{Any,Int}}) = 1
9898

9999

100100
function validate_inputs(x, y)

test/utils.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@
8484
xx = [rand(rng, D, D) for _ in 1:N1]
8585
xx⁻ = [rand(rng, D, D⁻) for _ in 1:N1]
8686
yy = [rand(rng, D, D) for _ in 1:N2]
87+
88+
@test KernelFunctions.dim("string") == 0
89+
@test KernelFunctions.dim(["string", "string2"]) == 0
90+
@test KernelFunctions.dim(rand(rng, 4)) == 1
91+
@test KernelFunctions.dim(x) == D
92+
8793
@test_nowarn KernelFunctions.validate_inplace_dims(zeros(N1, N2), x, y)
8894
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N1), x, y)
8995
@test_throws DimensionMismatch KernelFunctions.validate_inplace_dims(zeros(N1, N2), x⁻, y)

0 commit comments

Comments
 (0)