Skip to content

Commit 9ccc040

Browse files
authored
Merge pull request #75 from Tokazama/master
Rework UnsafeIndex
2 parents ad711a6 + 8b356b6 commit 9ccc040

File tree

3 files changed

+128
-102
lines changed

3 files changed

+128
-102
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.13.3"
3+
version = "2.13.4"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/indexing.jl

Lines changed: 102 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,83 @@
11

22
"""
3-
argdims(::IndexStyle, ::Type{T})
3+
ArrayStyle(::Type{A})
4+
5+
Used to customize the meaning of indexing arguments in the context of a given array `A`.
6+
7+
See also: [`argdims`](@ref), [`UnsafeIndex`](@ref)
8+
"""
9+
abstract type ArrayStyle end
10+
11+
struct DefaultArrayStyle <: ArrayStyle end
12+
13+
ArrayStyle(A) = ArrayStyle(typeof(A))
14+
ArrayStyle(::Type{A}) where {A} = DefaultArrayStyle()
15+
16+
"""
17+
argdims(::ArrayStyle, ::Type{T})
418
519
Whats the dimensionality of the indexing argument of type `T`?
620
"""
7-
argdims(A, x) = argdims(IndexStyle(A), typeof(x))
8-
argdims(s::IndexStyle, x) = argdims(s, typeof(x))
21+
argdims(x, arg) = argdims(x, typeof(arg))
22+
argdims(x, ::Type{T}) where {T} = argdims(ArrayStyle(x), T)
23+
argdims(s::ArrayStyle, arg) = argdims(s, typeof(arg))
924
# single elements initially map to 1 dimension but that dimension is subsequently dropped.
10-
argdims(::IndexStyle, ::Type{T}) where {T} = 0
11-
argdims(::IndexStyle, ::Type{T}) where {T<:Colon} = 1
12-
argdims(::IndexStyle, ::Type{T}) where {T<:AbstractArray} = ndims(T)
13-
argdims(::IndexStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = N
14-
argdims(::IndexStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = N
15-
argdims(::IndexStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N
16-
argdims(::IndexStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N
17-
@generated function argdims(s::IndexStyle, ::Type{T}) where {N,T<:Tuple{Vararg{<:Any,N}}}
25+
argdims(::ArrayStyle, ::Type{T}) where {T} = 0
26+
argdims(::ArrayStyle, ::Type{T}) where {T<:Colon} = 1
27+
argdims(::ArrayStyle, ::Type{T}) where {T<:AbstractArray} = ndims(T)
28+
argdims(::ArrayStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = N
29+
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = N
30+
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = N
31+
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = N
32+
@generated function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{<:Any,N}}}
1833
e = Expr(:tuple)
1934
for p in T.parameters
2035
push!(e.args, :(ArrayInterface.argdims(s, $p)))
2136
end
2237
Expr(:block, Expr(:meta, :inline), e)
2338
end
2439

40+
"""
41+
UnsafeIndex(::ArrayStyle, ::Type{I})
42+
43+
`UnsafeIndex` controls how indices that have been bounds checked and converted to
44+
native axes' indices are used to return the stored values of an array. For example,
45+
if the indices at each dimension are single integers then `UnsafeIndex(array, inds)` returns
46+
`UnsafeGetElement()`. Conversely, if any of the indices are vectors then `UnsafeGetCollection()`
47+
is returned, indicating that a new array needs to be reconstructed. This method permits
48+
customizing the terminal behavior of the indexing pipeline based on arguments passed
49+
to `ArrayInterface.getindex`. New subtypes of `UnsafeIndex` should define `promote_rule`.
50+
"""
51+
abstract type UnsafeIndex end
52+
53+
struct UnsafeGetElement <: UnsafeIndex end
54+
55+
struct UnsafeGetCollection <: UnsafeIndex end
56+
57+
UnsafeIndex(x, i) = UnsafeIndex(x, typeof(i))
58+
UnsafeIndex(x, ::Type{I}) where {I} = UnsafeIndex(ArrayStyle(x), I)
59+
UnsafeIndex(s::ArrayStyle, i) = UnsafeIndex(s, typeof(i))
60+
UnsafeIndex(::ArrayStyle, ::Type{I}) where {I} = UnsafeGetElement()
61+
UnsafeIndex(::ArrayStyle, ::Type{I}) where {I<:AbstractArray} = UnsafeGetCollection()
62+
63+
Base.promote_rule(::Type{X}, ::Type{Y}) where {X<:UnsafeIndex,Y<:UnsafeGetElement} = X
64+
65+
@generated function UnsafeIndex(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{<:Any,N}}}
66+
if N === 0
67+
return UnsafeGetElement()
68+
else
69+
e = Expr(:call, promote_type)
70+
for p in T.parameters
71+
push!(e.args, :(typeof(ArrayInterface.UnsafeIndex(s, $p))))
72+
end
73+
return Expr(:block, Expr(:meta, :inline), Expr(:call, e))
74+
end
75+
end
76+
77+
# are the indexing arguments provided a linear collection into a multidim collection
78+
is_linear_indexing(A, args::Tuple{Arg}) where {Arg} = argdims(A, Arg) < 2
79+
is_linear_indexing(A, args::Tuple{Arg,Vararg{Any}}) where {Arg} = false
80+
2581
"""
2682
flatten_args(A, args::Tuple{Arg,Vararg{Any}}) -> Tuple
2783
@@ -133,27 +189,15 @@ be accomplished using `to_index(axis, arg)`.
133189
@propagate_inbounds function to_indices(A, args::Tuple)
134190
if can_flatten(A, args)
135191
return to_indices(A, flatten_args(A, args))
192+
elseif is_linear_indexing(A, args)
193+
return (to_index(eachindex(IndexLinear(), A), first(args)),)
136194
else
137195
return to_indices(A, axes(A), args)
138196
end
139197
end
140-
@propagate_inbounds function to_indices(A, args::Tuple{Arg}) where {Arg}
141-
if can_flatten(A, args)
142-
return to_indices(A, flatten_args(A, args))
143-
else
144-
if argdims(IndexStyle(A), Arg) > 1
145-
return to_indices(A, axes(A), args)
146-
else
147-
if ndims(A) === 1
148-
return (to_index(axes(A, 1), first(args)),)
149-
else
150-
return to_indices(A, (eachindex(A),), args)
151-
end
152-
end
153-
end
154-
end
198+
@propagate_inbounds to_indices(A, args::Tuple{}) = to_indices(A, axes(A), ())
155199
@propagate_inbounds function to_indices(A, axs::Tuple, args::Tuple{Arg,Vararg{Any}}) where {Arg}
156-
N = argdims(IndexStyle(A), Arg)
200+
N = argdims(A, Arg)
157201
if N > 1
158202
axes_front, axes_tail = Base.IteratorsMD.split(axs, Val(N))
159203
return (to_multi_index(axes_front, first(args)), to_indices(A, axes_tail, tail(args))...)
@@ -172,8 +216,21 @@ end
172216
end
173217
to_indices(A, axs::Tuple{}, args::Tuple{}) = ()
174218

219+
220+
_multi_check_index(axs::Tuple, arg) = _multi_check_index(axs, axes(arg))
221+
function _multi_check_index(axs::Tuple, arg::AbstractArray{T}) where {T<:CartesianIndex}
222+
return checkindex(Bool, axs, arg)
223+
end
224+
_multi_check_index(::Tuple{}, ::Tuple{}) = true
225+
function _multi_check_index(axs::Tuple, args::Tuple)
226+
if checkindex(Bool, first(axs), first(args))
227+
return _multi_check_index(tail(axs), tail(args))
228+
else
229+
return false
230+
end
231+
end
175232
@propagate_inbounds function to_multi_index(axs::Tuple, arg)
176-
@boundscheck if !Base.checkbounds_indices(Bool, axs, (arg,))
233+
@boundscheck if !_multi_check_index(axs, arg)
177234
throw(BoundsError(axs, arg))
178235
end
179236
return arg
@@ -236,7 +293,6 @@ function unsafe_reconstruct(A::OneTo, data; kwargs...)
236293
end
237294
end
238295
end
239-
240296
function unsafe_reconstruct(A::UnitRange, data; kwargs...)
241297
if can_change_size(A)
242298
return typeof(A)(data)
@@ -248,7 +304,6 @@ function unsafe_reconstruct(A::UnitRange, data; kwargs...)
248304
end
249305
end
250306
end
251-
252307
function unsafe_reconstruct(A::OptionallyStaticUnitRange, data; kwargs...)
253308
if can_change_size(A)
254309
return typeof(A)(data)
@@ -260,7 +315,6 @@ function unsafe_reconstruct(A::OptionallyStaticUnitRange, data; kwargs...)
260315
end
261316
end
262317
end
263-
264318
function unsafe_reconstruct(A::AbstractUnitRange, data; kwargs...)
265319
return static_first(data):static_last(data)
266320
end
@@ -284,7 +338,7 @@ end
284338
to_axes(A, ::Tuple{Ax,Vararg{Any}}, ::Tuple{}) where {Ax} = ()
285339
to_axes(A, ::Tuple{}, ::Tuple{}) = ()
286340
@propagate_inbounds function to_axes(A, axs::Tuple{Ax,Vararg{Any}}, inds::Tuple{I,Vararg{Any}}) where {Ax,I}
287-
N = argdims(IndexStyle(A), I)
341+
N = argdims(A, I)
288342
if N === 0
289343
# drop this dimension
290344
return to_axes(A, tail(axs), tail(inds))
@@ -330,53 +384,15 @@ Changing indexing based on a given argument from `args` should be done through
330384
"""
331385
@propagate_inbounds getindex(A, args...) = unsafe_getindex(A, to_indices(A, args))
332386

333-
"""
334-
UnsafeIndex <: Function
335-
336-
`UnsafeIndex` controls how indices that have been bounds checked and converted to
337-
native axes' indices are used to return the stored values of an array. For example,
338-
if the indices at each dimension are single integers than `UnsafeIndex(inds)` returns
339-
`UnsafeElement()`. Conversely, if any of the indices are vectors then `UnsafeCollection()`
340-
is returned, indicating that a new array needs to be reconstructed. This method permits
341-
customizing the terimnal behavior of the indexing pipeline based on arguments passed
342-
to `ArrayInterface.getindex`
343-
"""
344-
abstract type UnsafeIndex <: Function end
345-
346-
struct UnsafeElement <: UnsafeIndex end
347-
const unsafe_element = UnsafeElement()
348-
349-
struct UnsafeCollection <: UnsafeIndex end
350-
const unsafe_collection = UnsafeCollection()
351-
352-
# 1-arg
353-
UnsafeIndex(x) = UnsafeIndex(typeof(x))
354-
UnsafeIndex(x::UnsafeIndex) = x
355-
UnsafeIndex(::Type{T}) where {T<:Integer} = unsafe_element
356-
UnsafeIndex(::Type{T}) where {T<:AbstractArray} = unsafe_collection
357-
358-
# 2-arg
359-
UnsafeIndex(x::UnsafeIndex, y::UnsafeElement) = x
360-
UnsafeIndex(x::UnsafeElement, y::UnsafeIndex) = y
361-
UnsafeIndex(x::UnsafeElement, y::UnsafeElement) = x
362-
UnsafeIndex(x::UnsafeCollection, y::UnsafeCollection) = x
363-
364-
365-
# tuple
366-
UnsafeIndex(x::Tuple{I}) where {I} = UnsafeIndex(I)
367-
@inline function UnsafeIndex(x::Tuple{I,Vararg{Any}}) where {I}
368-
return UnsafeIndex(UnsafeIndex(I), UnsafeIndex(tail(x)))
369-
end
370-
371387
"""
372388
unsafe_getindex(A, inds)
373389
374390
Indexes into `A` given `inds`. This method assumes that `inds` have already been
375391
bounds checked.
376392
"""
377-
unsafe_getindex(A, inds) = unsafe_getindex(UnsafeIndex(inds), A, inds)
378-
unsafe_getindex(::UnsafeElement, A, inds) = unsafe_get_element(A, inds)
379-
unsafe_getindex(::UnsafeCollection, A, inds) = unsafe_get_collection(A, inds)
393+
unsafe_getindex(A, inds) = unsafe_getindex(UnsafeIndex(A, inds), A, inds)
394+
unsafe_getindex(::UnsafeGetElement, A, inds) = unsafe_get_element(A, inds)
395+
unsafe_getindex(::UnsafeGetCollection, A, inds) = unsafe_get_collection(A, inds)
380396

381397
"""
382398
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
@@ -389,7 +405,9 @@ function unsafe_get_element(A, inds)
389405
throw(MethodError(unsafe_getindex, (A, inds)))
390406
end
391407
function unsafe_get_element(A::Array, inds)
392-
if inds isa Tuple{Vararg{Int}}
408+
if length(inds) === 0
409+
return Base.arrayref(false, A, 1)
410+
elseif inds isa Tuple{Vararg{Int}}
393411
return Base.arrayref(false, A, inds...)
394412
else
395413
throw(MethodError(unsafe_get_element, (A, inds)))
@@ -443,14 +461,12 @@ end
443461
end
444462
end
445463
@inline function unsafe_get_collection(A::LinearIndices{N}, inds) where {N}
446-
if can_preserve_indices(typeof(inds))
464+
if is_linear_indexing(A, inds)
465+
return @inbounds(eachindex(A)[first(inds)])
466+
elseif can_preserve_indices(typeof(inds))
447467
return LinearIndices(to_axes(A, _ints2range.(inds)))
448468
else
449-
if length(inds) === 1
450-
return @inbounds(eachindex(A)[first(inds)])
451-
else
452-
return Base._getindex(IndexStyle(A), A, inds...)
453-
end
469+
return Base._getindex(IndexStyle(A), A, inds...)
454470
end
455471
end
456472

@@ -474,9 +490,9 @@ end
474490
Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been
475491
bounds checked. This step of the processing pipeline can be customized by
476492
"""
477-
unsafe_setindex!(A, val, inds::Tuple) = unsafe_setindex!(UnsafeIndex(inds), A, val, inds)
478-
unsafe_setindex!(::UnsafeElement, A, val, inds::Tuple) = unsafe_set_element!(A, val, inds)
479-
unsafe_setindex!(::UnsafeCollection, A, val, inds::Tuple) = unsafe_set_collection!(A, val, inds)
493+
unsafe_setindex!(A, val, inds::Tuple) = unsafe_setindex!(UnsafeIndex(A, inds), A, val, inds)
494+
unsafe_setindex!(::UnsafeGetElement, A, val, inds::Tuple) = unsafe_set_element!(A, val, inds)
495+
unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple) = unsafe_set_collection!(A, val, inds)
480496

481497
"""
482498
unsafe_set_element!(A, val, inds::Tuple)
@@ -489,7 +505,9 @@ function unsafe_set_element!(A, val, inds)
489505
throw(MethodError(unsafe_set_element!, (A, val, inds)))
490506
end
491507
function unsafe_set_element!(A::Array{T}, val, inds::Tuple) where {T}
492-
if inds isa Tuple{Vararg{Int}}
508+
if length(inds) === 0
509+
return Base.arrayset(false, A, convert(T, val)::T, 1)
510+
elseif inds isa Tuple{Vararg{Int}}
493511
return Base.arrayset(false, A, convert(T, val)::T, inds...)
494512
else
495513
throw(MethodError(unsafe_set_element!, (A, inds)))

test/indexing.jl

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
11

22
@testset "argdims" begin
3-
static_argdims(x) = Val(ArrayInterface.argdims(IndexLinear(), x))
3+
static_argdims(x) = Val(ArrayInterface.argdims(ArrayInterface.DefaultArrayStyle(), x))
44
@test @inferred(static_argdims((1, CartesianIndex(1,2)))) === Val((0, 2))
55
@test @inferred(static_argdims((1, [CartesianIndex(1,2), CartesianIndex(1,3)]))) === Val((0, 2))
66
@test @inferred(static_argdims((1, CartesianIndex((2,2))))) === Val((0, 2))
77
@test @inferred(static_argdims((CartesianIndex((2,2)), :, :))) === Val((2, 1, 1))
88
end
99

10+
@testset "UnsafeIndex" begin
11+
@test @inferred(ArrayInterface.UnsafeIndex(ones(2,2,2), typeof((1,[1,2],1)))) == ArrayInterface.UnsafeGetCollection()
12+
@test @inferred(ArrayInterface.UnsafeIndex(ones(2,2,2), typeof((1,1,1)))) == ArrayInterface.UnsafeGetElement()
13+
end
14+
1015
@testset "to_index" begin
1116
axis = 1:3
1217
@test @inferred(ArrayInterface.to_index(axis, 1)) === 1
@@ -20,7 +25,6 @@ end
2025
@test_throws BoundsError ArrayInterface.to_index(axis, [true, false, false, true])
2126
end
2227

23-
2428
@testset "to_indices" begin
2529
a = ones(2,2,1)
2630
v = ones(2)
@@ -47,7 +51,12 @@ end
4751
@test @inferred ArrayInterface.to_indices(a, ([CartesianIndex(1,1,1), CartesianIndex(1,2,1)],)) == (CartesianIndex{3}[CartesianIndex(1, 1, 1), CartesianIndex(1, 2, 1)],)
4852
@test @inferred ArrayInterface.to_indices(a, ([CartesianIndex(1,1), CartesianIndex(1,2)],1:1)) == (CartesianIndex{2}[CartesianIndex(1, 1), CartesianIndex(1, 2)], 1:1)
4953
@test_throws ErrorException ArrayInterface.to_indices(ones(2,2,2), (1, 1))
54+
end
5055

56+
@testset "0-dimensional" begin
57+
x = Array{Int,0}(undef)
58+
ArrayInterface.setindex!(x, 1)
59+
@test @inferred(ArrayInterface.getindex(x)) == 1
5160
end
5261

5362
@testset "1-dimensional" begin
@@ -66,8 +75,9 @@ end
6675
@test_throws BoundsError ArrayInterface.getindex(CartesianIndices((3,)), 2, 2)
6776
# ambiguity btw cartesian indexing and linear indexing in 1d when
6877
# indices may be nontraditional
69-
@test_throws ArgumentError Base._sub2ind((1:3,), 2)
70-
@test_throws ArgumentError Base._ind2sub((1:3,), 2)
78+
# TODO should this be implemented in ArrayInterface with vectorization?
79+
#@test_throws ArgumentError Base._sub2ind((1:3,), 2)
80+
#@test_throws ArgumentError Base._ind2sub((1:3,), 2)
7181
end
7282

7383
@testset "2-dimensional" begin
@@ -81,17 +91,17 @@ end
8191
@test @inferred(ArrayInterface.getindex(LinearIndices(map(Base.Slice, (0:3,3:5))), i-1, j+2)) == k
8292
@test @inferred(ArrayInterface.getindex(CartesianIndices(map(Base.Slice, (0:3,3:5))), k)) == CartesianIndex(i-1,j+2)
8393
end
84-
@test @inferred(getindex(linear, linear)) == linear
85-
@test @inferred(getindex(linear, vec(linear))) == vec(linear)
86-
@test @inferred(getindex(linear, cartesian)) == linear
87-
@test @inferred(getindex(linear, vec(cartesian))) == vec(linear)
88-
@test @inferred(getindex(cartesian, linear)) == cartesian
89-
@test @inferred(getindex(cartesian, vec(linear))) == vec(cartesian)
90-
@test @inferred(getindex(cartesian, cartesian)) == cartesian
91-
@test @inferred(getindex(cartesian, vec(cartesian))) == vec(cartesian)
92-
@test @inferred(getindex(linear, 2:3)) === 2:3
93-
@test @inferred(getindex(linear, 3:-1:1)) === 3:-1:1
94-
@test_throws BoundsError linear[4:13]
94+
@test @inferred(ArrayInterface.getindex(linear, linear)) == linear
95+
@test @inferred(ArrayInterface.getindex(linear, vec(linear))) == vec(linear)
96+
@test @inferred(ArrayInterface.getindex(linear, cartesian)) == linear
97+
@test @inferred(ArrayInterface.getindex(linear, vec(cartesian))) == vec(linear)
98+
@test @inferred(ArrayInterface.getindex(cartesian, linear)) == cartesian
99+
@test @inferred(ArrayInterface.getindex(cartesian, vec(linear))) == vec(cartesian)
100+
@test @inferred(ArrayInterface.getindex(cartesian, cartesian)) == cartesian
101+
@test @inferred(ArrayInterface.getindex(cartesian, vec(cartesian))) == vec(cartesian)
102+
@test @inferred(ArrayInterface.getindex(linear, 2:3)) === 2:3
103+
@test @inferred(ArrayInterface.getindex(linear, 3:-1:1)) === 3:-1:1
104+
@test_throws BoundsError ArrayInterface.getindex(linear, 4:13)
95105
end
96106

97107
@testset "3-dimensional" begin
@@ -126,5 +136,3 @@ end
126136
@test @inferred(ArrayInterface.getindex(LinearIndices(A),ArrayInterface.getindex(CartesianIndices(A),i))) == i
127137
end
128138
end
129-
130-

0 commit comments

Comments
 (0)