Skip to content

Commit 8551780

Browse files
committed
Removed unnecessary functions and removed "view" tests
1 parent 30f1c96 commit 8551780

File tree

2 files changed

+15
-19
lines changed

2 files changed

+15
-19
lines changed

src/utils.jl

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,14 @@ end
1515
1616
A lightweight box for an `AbstractMatrix` to make it behave like a vector of vectors.
1717
"""
18-
struct ColVecs{T, TX<:AbstractMatrix{T}} <: AbstractVector{Vector{T}}
18+
struct ColVecs{T, TX<:AbstractMatrix{T}} <: AbstractVector{SubArray}
1919
X::TX
2020
ColVecs(X::TX) where {T, TX<:AbstractMatrix{T}} = new{T, TX}(X)
2121
end
2222

23-
Base.:(==)(D1::ColVecs, D2::ColVecs) = D1.X == D2.X
2423
Base.size(D::ColVecs) = (size(D.X, 2),)
25-
Base.length(D::ColVecs) = size(D.X, 2)
26-
Base.getindex(D::ColVecs, n::Int) = D.X[:, n]
27-
Base.getindex(D::ColVecs, n::CartesianIndex{1}) = getindex(D, n[1])
28-
Base.getindex(D::ColVecs, n) = ColVecs(D.X[:, n])
29-
Base.view(D::ColVecs, n::Int) = view(D.X, :, n)
30-
Base.view(D::ColVecs, n) = ColVecs(view(D.X, :, n))
31-
Base.eltype(::Type{<:ColVecs{T}}) where T = Vector{T}
32-
Base.zero(D::ColVecs) = ColVecs(zero(D.X))
33-
Base.iterate(D::ColVecs) = (view(D.X, :, 1), 2)
34-
Base.iterate(D::ColVecs, state) = state > length(D) ? nothing : (view(D.X, :, state), state + 1)
35-
24+
Base.getindex(D::ColVecs, i::Int) = view(D.X, :, i)
25+
Base.getindex(D::ColVecs, i::AbstractVector{Int}) = ColVecs(view(D.X, :, i))
3626

3727
# Take highest Float among possibilities
3828
# function promote_float(Tₖ::DataType...)

test/utils.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
using Test
2+
using KernelFunctions
3+
using Random
4+
using KernelFunctions: ColVecs
5+
rng, N, D = MersenneTwister(123456), 10, 2
6+
x, X = randn(rng, N), randn(rng, D, N)
7+
DX = ColVecs(X)
8+
9+
110
@testset "utils" begin
211
using KernelFunctions: ColVecs
312
rng, N, D = MersenneTwister(123456), 10, 2
@@ -9,16 +18,13 @@
918
@test DX == DX
1019
@test size(DX) == (N,)
1120
@test length(DX) == N
12-
@test getindex(DX, 5) isa Vector
21+
@test getindex(DX, 5) isa AbstractVector
1322
@test getindex(DX, 5) == X[:, 5]
1423
@test getindex(DX, 1:2:6) isa ColVecs
1524
@test getindex(DX, 1:2:6) == ColVecs(X[:, 1:2:6])
16-
@test view(DX, 4) isa AbstractVector
17-
@test view(DX, 4) == view(X, :, 4)
18-
@test view(DX, 1:2:4) isa ColVecs
19-
@test view(DX, 1:2:4) == ColVecs(view(X, :, 1:2:4))
20-
@test eltype(DX) == Vector{Float64}
2125
@test eachindex(DX) == 1:N
26+
@test first(DX) == X[:, 1]
27+
2228

2329
let
2430
@test Zygote.pullback(ColVecs, X)[1] == DX

0 commit comments

Comments
 (0)