Skip to content

Commit f659d9f

Browse files
authored
Vector indexing for OneElement (#346)
* Vector indexing for OneElement * Handle non-int Integer indices * Restrict to AbstractUnitRanges to avoid repeated indices
1 parent 61af4ca commit f659d9f

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

src/oneelement.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,25 @@ OneElement{T}(inds::Int, sz::Int) where T = OneElement(one(T), inds, sz)
4242

4343
Base.size(A::OneElement) = map(length, A.axes)
4444
Base.axes(A::OneElement) = A.axes
45+
Base.getindex(A::OneElement{T,0}) where {T} = getindex_value(A)
4546
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{Int,N}) where {T,N}
4647
@boundscheck checkbounds(A, kj...)
4748
ifelse(kj == A.ind, A.val, zero(T))
4849
end
50+
const VectorInds = Union{AbstractUnitRange{<:Integer}, Integer} # no index is repeated for these indices
51+
const VectorIndsWithColon = Union{VectorInds, Colon}
52+
# retain the values from Ainds corresponding to the vector indices in inds
53+
_index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds))
54+
_index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...)
55+
_index_shape(::Tuple{}, ::Tuple{}) = ()
56+
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N}
57+
I = to_indices(A, inds) # handle Bool, and convert to compatible index types
58+
@boundscheck checkbounds(A, I...)
59+
shape = _index_shape(I, I)
60+
nzind = _index_shape(A.ind, I) .- first.(shape) .+ firstindex.(shape)
61+
containsval = all(in.(A.ind, I))
62+
OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1))
63+
end
4964

5065
"""
5166
nzind(A::OneElement{T,N}) -> CartesianIndex{N}

test/runtests.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,10 +2151,12 @@ end
21512151
@test FillArrays.nzind(A) == CartesianIndex()
21522152
@test A == Fill(2, ())
21532153
@test A[] === 2
2154+
@test A[1] === A[1,1] === 2
21542155

21552156
e₁ = OneElement(2, 5)
21562157
@test e₁ == [0,1,0,0,0]
21572158
@test FillArrays.nzind(e₁) == CartesianIndex(2)
2159+
@test e₁[2] === e₁[2,1] === e₁[2,1,1] === 1
21582160
@test_throws BoundsError e₁[6]
21592161

21602162
f₁ = AbstractArray{Float64}(e₁)
@@ -2196,6 +2198,89 @@ end
21962198
@test A[1,1] === A[1,2] === A[2,1] === zero(S)
21972199
end
21982200

2201+
@testset "Vector indexing" begin
2202+
@testset "1D" begin
2203+
A = OneElement(2, 2, 4)
2204+
@test @inferred(A[:]) === @inferred(A[axes(A)...]) === A
2205+
@test @inferred(A[3:4]) isa OneElement{Int,1}
2206+
@test @inferred(A[3:4]) == Zeros(2)
2207+
@test @inferred(A[1:2]) === OneElement(2, 2, 2)
2208+
@test @inferred(A[2:3]) === OneElement(2, 1, 2)
2209+
@test @inferred(A[Base.IdentityUnitRange(2:3)]) isa OneElement{Int,1}
2210+
@test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),))
2211+
@test A[:,:] == reshape(A, size(A)..., 1)
2212+
2213+
@test A[reverse(axes(A,1))] == A[collect(reverse(axes(A,1)))]
2214+
2215+
@testset "repeated indices" begin
2216+
@test A[StepRangeLen(2, 0, 3)] == A[fill(2, 3)]
2217+
end
2218+
2219+
B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),))
2220+
@test @inferred(A[:]) === @inferred(A[axes(A)...]) === A
2221+
@test @inferred(A[3:4]) isa OneElement{Int,1}
2222+
@test @inferred(A[3:4]) == Zeros(2)
2223+
@test @inferred(A[2:3]) === OneElement(2, 1, 2)
2224+
2225+
C = OneElement(2, (2,), (Base.OneTo(big(4)),))
2226+
@test @inferred(C[1:4]) === OneElement(2, 2, 4)
2227+
2228+
D = OneElement(2, (2,), (InfiniteArrays.OneToInf(),))
2229+
D2 = D[:]
2230+
@test axes(D2) == axes(D)
2231+
@test D2[2] == D[2]
2232+
D3 = D[axes(D)...]
2233+
@test axes(D3) == axes(D)
2234+
@test D3[2] == D[2]
2235+
end
2236+
@testset "2D" begin
2237+
A = OneElement(2, (2,3), (4,5))
2238+
@test @inferred(A[:,:]) === @inferred(A[axes(A)...]) === A
2239+
@test @inferred(A[:,1]) isa OneElement{Int,1}
2240+
@test @inferred(A[:,1]) == Zeros(4)
2241+
@test A[:, Int64(1)] === A[:, Int32(1)]
2242+
@test @inferred(A[1,:]) isa OneElement{Int,1}
2243+
@test @inferred(A[1,:]) == Zeros(5)
2244+
@test @inferred(A[:,3]) === OneElement(2, 2, 4)
2245+
@test @inferred(A[2,:]) === OneElement(2, 3, 5)
2246+
@test @inferred(A[1:1,:]) isa OneElement{Int,2}
2247+
@test @inferred(A[1:1,:]) == Zeros(1,5)
2248+
@test @inferred(A[4:4,:]) isa OneElement{Int,2}
2249+
@test @inferred(A[4:4,:]) == Zeros(1,5)
2250+
@test @inferred(A[2:2,:]) === OneElement(2, (1,3), (1,5))
2251+
@test @inferred(A[1:4,:]) === OneElement(2, (2,3), (4,5))
2252+
@test @inferred(A[:,3:3]) === OneElement(2, (2,1), (4,1))
2253+
@test @inferred(A[:,1:5]) === OneElement(2, (2,3), (4,5))
2254+
@test @inferred(A[1:4,1:4]) === OneElement(2, (2,3), (4,4))
2255+
@test @inferred(A[2:4,2:4]) === OneElement(2, (1,2), (3,3))
2256+
@test @inferred(A[2:4,3:4]) === OneElement(2, (1,1), (3,2))
2257+
@test @inferred(A[4:4,5:5]) isa OneElement{Int,2}
2258+
@test @inferred(A[4:4,5:5]) == Zeros(1,1)
2259+
@test @inferred(A[Base.IdentityUnitRange(2:4), :]) isa OneElement{Int,2}
2260+
@test axes(A[Base.IdentityUnitRange(2:4), :]) == (Base.IdentityUnitRange(2:4), axes(A,2))
2261+
@test @inferred(A[:,:,:]) == reshape(A, size(A)...,1)
2262+
2263+
B = OneElement(2, (2,3), (Base.IdentityUnitRange(2:4),Base.IdentityUnitRange(2:5)))
2264+
@test @inferred(B[:,:]) === @inferred(B[axes(B)...]) === B
2265+
@test @inferred(B[:,3]) === OneElement(2, (2,), (Base.IdentityUnitRange(2:4),))
2266+
@test @inferred(B[3:4, 4:5]) isa OneElement{Int,2}
2267+
@test @inferred(B[3:4, 4:5]) == Zeros(2,2)
2268+
b = @inferred(B[Base.IdentityUnitRange(3:4), Base.IdentityUnitRange(4:5)])
2269+
@test b == Zeros(axes(b))
2270+
2271+
C = OneElement(2, (2,3), (Base.OneTo(big(4)), Base.OneTo(big(5))))
2272+
@test @inferred(C[1:4, 1:5]) === OneElement(2, (2,3), Int.(size(C)))
2273+
2274+
D = OneElement(2, (2,3), (InfiniteArrays.OneToInf(), InfiniteArrays.OneToInf()))
2275+
D2 = @inferred D[:,:]
2276+
@test axes(D2) == axes(D)
2277+
@test D2[2,3] == D[2,3]
2278+
D3 = @inferred D[axes(D)...]
2279+
@test axes(D3) == axes(D)
2280+
@test D3[2,3] == D[2,3]
2281+
end
2282+
end
2283+
21992284
@testset "adjoint/transpose" begin
22002285
A = OneElement(3im, (2,4), (4,6))
22012286
@test A' === OneElement(-3im, (4,2), (6,4))

0 commit comments

Comments
 (0)