Skip to content

Commit 7c7daec

Browse files
authored
Fix getindex with additional inds (#312)
* Fix getindex with additional inds * Make`getindex` works the same as `Base` * Fixes and avoid allocating new array if possible
1 parent b41e6a3 commit 7c7daec

File tree

2 files changed

+79
-6
lines changed

2 files changed

+79
-6
lines changed

src/indexing.jl

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,8 @@ to_index(::MyIndexStyle, axis, arg) = ...
183183
"""
184184
to_index(x, i::Slice) = i
185185
to_index(x, ::Colon) = indices(x)
186+
to_index(::LinearIndices{0,Tuple{}}, ::Colon) = Slice(static(1):static(1))
187+
to_index(::CartesianIndices{0,Tuple{}}, ::Colon) = Slice(static(1):static(1))
186188
# logical indexing
187189
to_index(x, i::AbstractArray{Bool}) = LogicalIndex(i)
188190
to_index(x::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i)
@@ -251,10 +253,10 @@ indices calling [`to_axis`](@ref).
251253
end
252254
end
253255
# drop this dimension
254-
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, tail(a), tail(i))
256+
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
255257
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
256258
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
257-
return (to_axis(first(axs), first(inds)), to_axes(A, tail(axs), tail(inds))...)
259+
return (to_axis(_maybe_first(axs), first(inds)), to_axes(A, _maybe_tail(axs), tail(inds))...)
258260
end
259261
@propagate_inbounds function _to_axes(::StaticInt{N}, A, axs::Tuple, inds::Tuple) where {N}
260262
axes_front, axes_tail = Base.IteratorsMD.split(axs, Val(N))
@@ -268,6 +270,11 @@ end
268270
to_axes(A, ::Tuple{Ax,Vararg{Any}}, ::Tuple{}) where {Ax} = ()
269271
to_axes(A, ::Tuple{}, ::Tuple{}) = ()
270272

273+
_maybe_first(::Tuple{}) = static(1):static(1)
274+
_maybe_first(t::Tuple) = first(t)
275+
_maybe_tail(::Tuple{}) = ()
276+
_maybe_tail(t::Tuple) = tail(t)
277+
271278
"""
272279
to_axis(old_axis, index) -> new_axis
273280
@@ -349,7 +356,9 @@ unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))
349356
end
350357

351358
unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)
352-
unsafe_getindex(A::CartesianIndices, i::CanonicalInt, ii::Vararg{CanonicalInt}) = CartesianIndex(i, ii...)
359+
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{CanonicalInt,N}) where {N} = CartesianIndex(ii...)
360+
unsafe_getindex(A::CartesianIndices, ii::Vararg{CanonicalInt}) =
361+
unsafe_getindex(A, Base.front(ii)...)
353362
unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[i])
354363

355364
unsafe_getindex(A::ReshapedArray, i::CanonicalInt) = @inbounds(parent(A)[i])
@@ -378,18 +387,37 @@ function unsafe_get_collection(A, inds)
378387
end
379388
_ints2range(x::CanonicalInt) = x:x
380389
_ints2range(x::AbstractRange) = x
390+
# apply _ints2range to front N elements
391+
_ints2range_front(::Val{N}, ind, inds...) where {N} =
392+
(_ints2range(ind), _ints2range_front(Val(N - 1), inds...)...)
393+
_ints2range_front(::Val{0}, ind, inds...) = ()
394+
_ints2range_front(::Val{0}) = ()
395+
# get output shape with given indices
396+
_output_shape(::CanonicalInt, inds...) = _output_shape(inds...)
397+
_output_shape(ind::AbstractRange, inds...) = (length(ind), _output_shape(inds...)...)
398+
_output_shape(::CanonicalInt) = ()
399+
_output_shape(x::AbstractRange) = (length(x),)
381400
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
382401
if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()
383402
return Base._getindex(IndexStyle(A), A, inds...)
384403
else
385-
return CartesianIndices(to_axes(A, _ints2range.(inds)))
404+
return reshape(
405+
CartesianIndices(_ints2range_front(Val(N), inds...)),
406+
_output_shape(inds...)
407+
)
386408
end
387409
end
410+
_known_first_isone(ind) = known_first(ind) !== nothing && isone(known_first(ind))
388411
@inline function unsafe_get_collection(A::LinearIndices{N}, inds) where {N}
389412
if Base.length(inds) === 1 && isone(_ndims_index(typeof(inds), static(1)))
390413
return @inbounds(eachindex(A)[first(inds)])
391-
elseif stride_preserving_index(typeof(inds)) === True()
392-
return LinearIndices(to_axes(A, _ints2range.(inds)))
414+
elseif stride_preserving_index(typeof(inds)) === True() &&
415+
reduce_tup(&, map(_known_first_isone, inds))
416+
# create a LinearIndices when first(ind) != 1 is imposable
417+
return reshape(
418+
LinearIndices(_ints2range_front(Val(N), inds...)),
419+
_output_shape(inds...)
420+
)
393421
else
394422
return Base._getindex(IndexStyle(A), A, inds...)
395423
end

test/indexing.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,51 @@ end
9696

9797
@test ArrayInterface.to_axis(axis, axis) === axis
9898
@test ArrayInterface.to_axis(axis, ArrayInterface.indices(axis)) === axis
99+
100+
@test @inferred(ArrayInterface.to_axes(A, (), (inds,))) === (inds,)
101+
end
102+
103+
@testset "getindex with additional inds" begin
104+
A = reshape(1:12, (3, 4))
105+
subA = view(A, :, :)
106+
LA = LinearIndices(A)
107+
CA = CartesianIndices(A)
108+
@test @inferred(ArrayInterface.getindex(A, 1, 1, 1)) == 1
109+
@test @inferred(ArrayInterface.getindex(A, 1, 1, :)) == [1]
110+
@test @inferred(ArrayInterface.getindex(A, 1, 1, 1:1)) == [1]
111+
@test @inferred(ArrayInterface.getindex(A, 1, 1, :, :)) == ones(1, 1)
112+
@test @inferred(ArrayInterface.getindex(A, :, 1, 1)) == 1:3
113+
@test @inferred(ArrayInterface.getindex(A, 2:3, 1, 1)) == 2:3
114+
@test @inferred(ArrayInterface.getindex(A, static(1):2, 1, 1)) == 1:2
115+
@test @inferred(ArrayInterface.getindex(A, :, 1, :)) == reshape(1:3, 3, 1)
116+
@test @inferred(ArrayInterface.getindex(subA, 1, 1, 1)) == 1
117+
@test @inferred(ArrayInterface.getindex(subA, 1, 1, :)) == [1]
118+
@test @inferred(ArrayInterface.getindex(subA, 1, 1, 1:1)) == [1]
119+
@test @inferred(ArrayInterface.getindex(subA, 1, 1, :, :)) == ones(1, 1)
120+
@test @inferred(ArrayInterface.getindex(subA, :, 1, 1)) == 1:3
121+
@test @inferred(ArrayInterface.getindex(subA, 2:3, 1, 1)) == 2:3
122+
@test @inferred(ArrayInterface.getindex(subA, static(1):2, 1, 1)) == 1:2
123+
@test @inferred(ArrayInterface.getindex(subA, :, 1, :)) == reshape(1:3, 3, 1)
124+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, 1)) == 1
125+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, :)) == [1]
126+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, 1:1)) == [1]
127+
@test @inferred(ArrayInterface.getindex(LA, 1, 1, :, :)) == ones(1, 1)
128+
@test @inferred(ArrayInterface.getindex(LA, :, 1, 1)) == 1:3
129+
@test @inferred(ArrayInterface.getindex(LA, 2:3, 1, 1)) == 2:3
130+
@test @inferred(ArrayInterface.getindex(LA, static(1):2, 1, 1)) == 1:2
131+
@test @inferred(ArrayInterface.getindex(LA, :, 1, :)) == reshape(1:3, 3, 1)
132+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, 1)) == CartesianIndex(1, 1)
133+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, :)) == [CartesianIndex(1, 1)]
134+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, 1:1)) == [CartesianIndex(1, 1)]
135+
@test @inferred(ArrayInterface.getindex(CA, 1, 1, :, :)) == fill(CartesianIndex(1, 1), 1, 1)
136+
@test @inferred(ArrayInterface.getindex(CA, :, 1, 1)) ==
137+
reshape(CartesianIndex(1, 1):CartesianIndex(3, 1), 3)
138+
@test @inferred(ArrayInterface.getindex(CA, 2:3, 1, 1)) ==
139+
reshape(CartesianIndex(2, 1):CartesianIndex(3, 1), 2)
140+
@test @inferred(ArrayInterface.getindex(CA, static(1):2, 1, 1)) ==
141+
reshape(CartesianIndex(1, 1):CartesianIndex(2, 1), 2)
142+
@test @inferred(ArrayInterface.getindex(CA, :, 1, :)) ==
143+
reshape(CartesianIndex(1, 1):CartesianIndex(3, 1), 3, 1)
99144
end
100145

101146
@testset "0-dimensional" begin

0 commit comments

Comments
 (0)