Skip to content

Commit cd6e5d2

Browse files
authored
NDIndex - Optionally static CartesianIndex (#140)
Add NDIndex
1 parent d65bcc0 commit cd6e5d2

File tree

6 files changed

+357
-143
lines changed

6 files changed

+357
-143
lines changed

src/ArrayInterface.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ else
2727
end
2828
end
2929

30-
static_ndims(x) = static(ndims(x))
31-
3230
if VERSION v"1.6.0-DEV.1581"
3331
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = true
3432
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = false
@@ -51,6 +49,8 @@ const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
5149
@inline static_last(x) = Static.maybe_static(known_last, last, x)
5250
@inline static_step(x) = Static.maybe_static(known_step, step, x)
5351

52+
include("ndindex.jl")
53+
5454
"""
5555
parent_type(::Type{T})
5656
@@ -70,6 +70,7 @@ parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A
7070
parent_type(::Type{LoTri{T,M}}) where {T,M} = M
7171
parent_type(::Type{UpTri{T,M}}) where {T,M} = M
7272
parent_type(::Type{Diagonal{T,V}}) where {T,V} = V
73+
7374
"""
7475
has_parent(::Type{T}) -> StaticBool
7576
@@ -591,7 +592,7 @@ safevec(v::Number) = v
591592
safevec(v::AbstractVector) = v
592593

593594
"""
594-
zeromatrix(u::AbstractVector)
595+
zeromatrix(u::AbstractVector)
595596
596597
Creates the zero'd matrix version of `u`. Note that this is unique because
597598
`similar(u,length(u),length(u))` returns a mutable type, so it is not type-matching,
@@ -607,7 +608,7 @@ function zeromatrix(u)
607608
end
608609

609610
"""
610-
restructure(x,y)
611+
restructure(x,y)
611612
612613
Restructures the object `y` into a shape of `x`, keeping its values intact. For
613614
simple objects like an `Array`, this simply amounts to a reshape. However, for

src/indexing.jl

Lines changed: 91 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -28,60 +28,25 @@ argdims(s::ArrayStyle, arg) = argdims(s, typeof(arg))
2828
argdims(::ArrayStyle, ::Type{T}) where {T} = static(0)
2929
argdims(::ArrayStyle, ::Type{T}) where {T<:Colon} = static(1)
3030
argdims(::ArrayStyle, ::Type{T}) where {T<:AbstractArray} = static(ndims(T))
31-
argdims(::ArrayStyle, ::Type{T}) where {N,T<:CartesianIndex{N}} = static(N)
32-
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{CartesianIndex{N}}} = static(N)
31+
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractCartesianIndex{N}} = static(N)
32+
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:AbstractCartesianIndex{N}}} = static(N)
3333
argdims(::ArrayStyle, ::Type{T}) where {N,T<:AbstractArray{<:Any,N}} = static(N)
3434
argdims(::ArrayStyle, ::Type{T}) where {N,T<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = static(N)
3535
_argdims(s::ArrayStyle, ::Type{I}, i::StaticInt) where {I} = argdims(s, _get_tuple(I, i))
3636
function argdims(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
3737
return eachop(_argdims, nstatic(Val(N)), s, T)
3838
end
3939

40-
is_element_index(i) = is_element_index(typeof(i))
41-
is_element_index(::Type{T}) where {T} = static(false)
42-
is_element_index(::Type{T}) where {T<:AbstractCartesianIndex} = static(true)
43-
is_element_index(::Type{T}) where {T<:Integer} = static(true)
44-
_is_element_index(::Type{T}, i::StaticInt) where {T} = is_element_index(_get_tuple(T, i))
45-
function is_element_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
46-
return static(all(eachop(_is_element_index, nstatic(Val(N)), T)))
47-
end
48-
49-
"""
50-
UnsafeIndex(::ArrayStyle, ::Type{I})
51-
52-
`UnsafeIndex` controls how indices that have been bounds checked and converted to
53-
native axes' indices are used to return the stored values of an array. For example,
54-
if the indices at each dimension are single integers then `UnsafeIndex(array, inds)` returns
55-
`UnsafeGetElement()`. Conversely, if any of the indices are vectors then `UnsafeGetCollection()`
56-
is returned, indicating that a new array needs to be reconstructed. This method permits
57-
customizing the terminal behavior of the indexing pipeline based on arguments passed
58-
to `ArrayInterface.getindex`. New subtypes of `UnsafeIndex` should define `promote_rule`.
59-
"""
60-
abstract type UnsafeIndex end
61-
62-
struct UnsafeGetElement <: UnsafeIndex end
63-
64-
struct UnsafeGetCollection <: UnsafeIndex end
65-
66-
UnsafeIndex(x, i) = UnsafeIndex(x, typeof(i))
67-
UnsafeIndex(x, ::Type{I}) where {I} = UnsafeIndex(ArrayStyle(x), I)
68-
UnsafeIndex(s::ArrayStyle, i) = UnsafeIndex(s, typeof(i))
69-
UnsafeIndex(::ArrayStyle, ::Type{I}) where {I} = UnsafeGetElement()
70-
UnsafeIndex(::ArrayStyle, ::Type{I}) where {I<:AbstractArray} = UnsafeGetCollection()
71-
72-
Base.promote_rule(::Type{X}, ::Type{Y}) where {X<:UnsafeIndex,Y<:UnsafeGetElement} = X
73-
74-
@generated function UnsafeIndex(s::ArrayStyle, ::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
75-
if N === 0
76-
return UnsafeGetElement()
77-
else
78-
e = Expr(:call, promote_type)
79-
for p in T.parameters
80-
push!(e.args, :(typeof(ArrayInterface.UnsafeIndex(s, $p))))
81-
end
82-
return Expr(:block, Expr(:meta, :inline), Expr(:call, e))
83-
end
40+
_is_element_index(i) = _is_element_index(typeof(i))
41+
_is_element_index(::Type{T}) where {T} = static(false)
42+
_is_element_index(::Type{T}) where {T<:AbstractCartesianIndex} = static(true)
43+
_is_element_index(::Type{T}) where {T<:Integer} = static(true)
44+
__is_element_index(::Type{T}, i::StaticInt) where {T} = _is_element_index(_get_tuple(T, i))
45+
function _is_element_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}}
46+
return static(all(eachop(__is_element_index, nstatic(Val(N)), T)))
8447
end
48+
# empty tuples refer to the single element of 0-dimensional arrays
49+
_is_element_index(::Type{Tuple{}}) = static(true)
8550

8651
# are the indexing arguments provided a linear collection into a multidim collection
8752
is_linear_indexing(A, args::Tuple{Arg}) where {Arg} = argdims(A, Arg) < 2
@@ -181,6 +146,22 @@ to_index(::IndexLinear, axis, arg::CartesianIndices{1}) = axes(arg, 1)
181146
@propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractCartesianIndex{1})
182147
return to_index(axis, first(Tuple(arg)))
183148
end
149+
function to_index(::IndexLinear, x, arg::AbstractCartesianIndex{N}) where {N}
150+
inds = Tuple(arg)
151+
o = offsets(x)
152+
s = size(x)
153+
return first(inds) + (offset1(x) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds))
154+
end
155+
@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg})
156+
i = ((first(inds) - first(o)) * stride)
157+
return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds))
158+
end
159+
function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any})
160+
return (first(inds) - first(o)) * stride
161+
end
162+
# trailing inbounds can only be 1 or 1:1
163+
_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)
164+
184165
@propagate_inbounds function to_index(::IndexLinear, x, arg::Union{Array{Bool}, BitArray})
185166
@boundscheck checkbounds(x, arg)
186167
return LogicalIndex{Int}(arg)
@@ -194,7 +175,7 @@ end
194175
return arg
195176
end
196177
@propagate_inbounds function to_index(::IndexLinear, x, arg::Integer)
197-
@boundscheck checkindex(Bool, x, arg) || throw(BoundsError(x, arg))
178+
@boundscheck checkindex(Bool, indices(x), arg) || throw(BoundsError(x, arg))
198179
return _int(arg)
199180
end
200181
@propagate_inbounds function to_index(::IndexLinear, axis, arg::AbstractArray{Bool})
@@ -209,25 +190,11 @@ end
209190
@boundscheck checkindex(Bool, indices(axis), arg) || throw(BoundsError(axis, arg))
210191
return static_first(arg):static_step(arg):static_last(arg)
211192
end
212-
to_index(::IndexLinear, x, inds::Tuple{Any}) = first(inds)
213-
function to_index(::IndexLinear, x, inds::Tuple{Any,Vararg{Any}})
214-
o = offsets(x)
215-
s = size(x)
216-
return first(inds) + (offset1(x) - first(o)) + _subs2int(first(s), tail(s), tail(o), tail(inds))
217-
end
218-
@inline function _subs2int(stride, s::Tuple{Any,Vararg}, o::Tuple{Any,Vararg}, inds::Tuple{Any,Vararg})
219-
i = ((first(inds) - first(o)) * stride)
220-
return i + _subs2int(stride * first(s), tail(s), tail(o), tail(inds))
221-
end
222-
function _subs2int(stride, s::Tuple{Any}, o::Tuple{Any}, inds::Tuple{Any})
223-
return (first(inds) - first(o)) * stride
224-
end
225-
# trailing inbounds can only be 1 or 1:1
226-
_subs2int(stride, ::Tuple{}, ::Tuple{}, ::Tuple{Any}) = static(0)
227193

228194
## IndexCartesian ##
229195
to_index(::IndexCartesian, x, arg::Colon) = CartesianIndices(x)
230196
to_index(::IndexCartesian, x, arg::CartesianIndices{0}) = arg
197+
to_index(::IndexCartesian, x, arg::AbstractCartesianIndex) = arg
231198
function to_index(::IndexCartesian, x, arg)
232199
@boundscheck _multi_check_index(axes(x), arg) || throw(BoundsError(x, arg))
233200
return arg
@@ -253,15 +220,13 @@ end
253220
@boundscheck checkbounds(x, arg)
254221
return LogicalIndex{Int}(arg)
255222
end
256-
to_index(::IndexCartesian, x, i::Integer) = _int2subs(axes(x), i - offset1(x))
257-
@inline function _int2subs(axs::Tuple{Any,Vararg{Any}}, i)
258-
axis = first(axs)
259-
len = static_length(axis)
223+
to_index(::IndexCartesian, x, i::Integer) = NDIndex(_int2subs(offsets(x), size(x), i - offset1(x)))
224+
@inline function _int2subs(o::Tuple{Any,Vararg{Any}}, s::Tuple{Any,Vararg{Any}}, i)
225+
len = first(s)
260226
inext = div(i, len)
261-
return (_int(i - len * inext + static_first(axis)), _int2subs(tail(axs), inext)...)
227+
return (_int(i - len * inext + first(o)), _int2subs(tail(o), tail(s), inext)...)
262228
end
263-
_int2subs(axs::Tuple{Any}, i) = _int(i + static_first(first(axs)))
264-
229+
_int2subs(o::Tuple{Any}, s::Tuple{Any}, i) = _int(i + first(o))
265230

266231
"""
267232
unsafe_reconstruct(A, data; kwargs...)
@@ -353,6 +318,9 @@ end
353318
end
354319
to_axis(S::IndexLinear, axis, inds) = StaticInt(1):static_length(inds)
355320

321+
################
322+
### getindex ###
323+
################
356324
"""
357325
ArrayInterface.getindex(A, args...)
358326
@@ -362,14 +330,19 @@ Changing indexing based on a given argument from `args` should be done through,
362330
[`to_index`](@ref), or [`to_axis`](@ref).
363331
"""
364332
@propagate_inbounds getindex(A, args...) = unsafe_get_index(A, to_indices(A, args))
365-
@propagate_inbounds getindex(A; kwargs...) = A[order_named_inds(dimnames(A), kwargs.data)...]
333+
@propagate_inbounds function getindex(A; kwargs...)
334+
return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), kwargs.data)))
335+
end
366336
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
367337
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)
368338

369339
## unsafe_get_index ##
370-
unsafe_get_index(A, inds::Tuple) = _unsafe_get_index(is_element_index(inds), A, inds)
371-
_unsafe_get_index(::True, A, inds::Tuple) = unsafe_get_element(A, inds)
340+
unsafe_get_index(A, inds::Tuple) = _unsafe_get_index(_is_element_index(inds), A, inds)
372341
_unsafe_get_index(::False, A, inds::Tuple) = unsafe_get_collection(A, inds)
342+
_unsafe_get_index(::True, A, inds::Tuple) = __unsafe_get_index(A, inds)
343+
__unsafe_get_index(A, inds::Tuple{}) = unsafe_get_element(A, ())
344+
__unsafe_get_index(A, inds::Tuple{Any}) = unsafe_get_element(A, first(inds))
345+
__unsafe_get_index(A, inds::Tuple{Any,Vararg{Any}}) = unsafe_get_element(A, NDIndex(inds))
373346

374347
"""
375348
unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T
@@ -380,22 +353,30 @@ must define `unsafe_get_element(::NewArrayType, inds)`.
380353
"""
381354
unsafe_get_element(a::A, inds) where {A} = _unsafe_get_element(has_parent(A), a, inds)
382355
_unsafe_get_element(::True, a, inds) = unsafe_get_element(parent(a), inds)
383-
_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds...])
384-
_unsafe_get_element(::False, a::AbstractArray2, inds) = unsafe_get_element_error(a, inds)
356+
_unsafe_get_element(::False, a, inds) = @inbounds(parent(a)[inds])
357+
_unsafe_get_element(::False, a::AbstractArray2, i) = unsafe_get_element_error(a, i)
358+
359+
## Array ##
385360
unsafe_get_element(A::Array, ::Tuple{}) = Base.arrayref(false, A, 1)
386-
unsafe_get_element(A::Array, inds) = Base.arrayref(false, A, Int(to_index(A, inds)))
387-
unsafe_get_element(A::LinearIndices, inds) = Int(to_index(A, inds))
388-
@inline function unsafe_get_element(A::CartesianIndices, inds)
389-
if length(inds) === 1
390-
return CartesianIndex(to_index(A, first(inds)))
391-
else
392-
return CartesianIndex(Base._to_subscript_indices(A, inds...))
393-
end
361+
unsafe_get_element(A::Array, i::Integer) = Base.arrayref(false, A, Int(i))
362+
unsafe_get_element(A::Array, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
363+
364+
## LinearIndices ##
365+
unsafe_get_element(A::LinearIndices, i::Integer) = Int(i)
366+
unsafe_get_element(A::LinearIndices, i::NDIndex) = unsafe_get_element(A, to_index(A, i))
367+
368+
unsafe_get_element(A::CartesianIndices, i::NDIndex) = CartesianIndex(i)
369+
unsafe_get_element(A::CartesianIndices, i::Integer) = unsafe_get_element(A, to_index(A, i))
370+
371+
unsafe_get_element(A::ReshapedArray, i::Integer) = unsafe_get_element(parent(A), i)
372+
function unsafe_get_element(A::ReshapedArray, i::NDIndex)
373+
return unsafe_get_element(parent(A), to_index(IndexLinear(), A, i))
394374
end
395-
unsafe_get_element(A::ReshapedArray, inds) = @inbounds(A[inds...])
396-
unsafe_get_element(A::SubArray, inds) = @inbounds(A[inds...])
397375

398-
unsafe_get_element_error(A, inds) = throw(MethodError(unsafe_get_element, (A, inds)))
376+
unsafe_get_element(A::SubArray, i) = @inbounds(A[i])
377+
function unsafe_get_element_error(@nospecialize(A), @nospecialize(i))
378+
throw(MethodError(unsafe_get_element, (A, i)))
379+
end
399380

400381
# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
401382
"""
@@ -424,7 +405,7 @@ function _generate_unsafe_get_index!_body(N::Int)
424405
# the optimizer is not clever enough to split the union without it
425406
Dy === nothing && return dest
426407
(idx, state) = Dy
427-
dest[idx] = unsafe_get_element(src, Base.Cartesian.@ntuple($N, j))
408+
dest[idx] = unsafe_get_element(src, NDIndex(Base.Cartesian.@ntuple($N, j)))
428409
Dy = iterate(D, state)
429410
end
430411
return dest
@@ -453,37 +434,36 @@ end
453434
end
454435
end
455436

437+
#################
438+
### setindex! ###
439+
#################
456440
"""
457441
ArrayInterface.setindex!(A, args...)
458442
459443
Store the given values at the given key or index within a collection.
460444
"""
461445
@propagate_inbounds function setindex!(A, val, args...)
462446
if can_setindex(A)
463-
return unsafe_setindex!(A, val, to_indices(A, args))
447+
return unsafe_set_index!(A, val, to_indices(A, args))
464448
else
465449
error("Instance of type $(typeof(A)) are not mutable and cannot change elements after construction.")
466450
end
467451
end
468452
@propagate_inbounds function setindex!(A, val; kwargs...)
469-
if has_dimnames(A)
470-
return setindex!(A, val, order_named_inds(dimnames(A), kwargs.data)...)
471-
else
472-
return unsafe_setindex!(A, val, to_indices(A, ()))
473-
end
453+
return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), kwargs.data)))
474454
end
475455

476-
"""
477-
unsafe_setindex!(A, val, inds::Tuple)
478-
479-
Sets indices (`inds`) of `A` to `val`. This method assumes that `inds` have already been
480-
bounds-checked. This step of the processing pipeline can be customized by:
481-
"""
482-
unsafe_setindex!(A, val, i::Tuple) = unsafe_setindex!(UnsafeIndex(A, i), A, val, i)
483-
unsafe_setindex!(::UnsafeGetElement, A, val, i::Tuple) = unsafe_set_element!(A, val, i)
484-
unsafe_setindex!(::UnsafeGetCollection, A, v, i::Tuple) = unsafe_set_collection!(A, v, i)
456+
unsafe_set_index!(A, v, inds::Tuple) = _unsafe_set_index!(_is_element_index(inds), A, v, inds)
457+
_unsafe_set_index!(::False, A, v, inds::Tuple) = unsafe_set_collection!(A, v, inds)
458+
_unsafe_set_index!(::True, A, v, inds::Tuple) = __unsafe_set_index!(A, v, inds)
459+
__unsafe_set_index!(A, v, inds::Tuple{}) = unsafe_set_element!(A, v, ())
460+
function __unsafe_set_index!(A, v, inds::Tuple{Any})
461+
return unsafe_set_element!(A, v, to_index(A, first(inds)))
462+
end
463+
function __unsafe_set_index!(A, v, inds::Tuple{Any,Vararg{Any}})
464+
return unsafe_set_element!(A, v, to_index(A, NDIndex(inds)))
465+
end
485466

486-
unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i)))
487467

488468
"""
489469
unsafe_set_element!(A, val, inds::Tuple)
@@ -494,19 +474,18 @@ must define `unsafe_set_element!(::NewArrayType, val, inds)`.
494474
"""
495475
unsafe_set_element!(a, val, inds) = _unsafe_set_element!(has_parent(a), a, val, inds)
496476
_unsafe_set_element!(::True, a, val, inds) = unsafe_set_element!(parent(a), val, inds)
497-
_unsafe_set_element!(::False, a, val,inds) = @inbounds(parent(a)[inds...] = val)
477+
_unsafe_set_element!(::False, a, val, inds) = @inbounds(parent(a)[inds] = val)
478+
498479
function _unsafe_set_element!(::False, a::AbstractArray2, val, inds)
499480
unsafe_set_element_error(a, val, inds)
500481
end
482+
unsafe_set_element_error(A, v, i) = throw(MethodError(unsafe_set_element!, (A, v, i)))
501483

502-
function unsafe_set_element!(A::Array{T}, val, inds::Tuple) where {T}
503-
if length(inds) === 0
504-
return Base.arrayset(false, A, convert(T, val)::T, 1)
505-
elseif inds isa Tuple{Vararg{Int}}
506-
return Base.arrayset(false, A, convert(T, val)::T, inds...)
507-
else
508-
throw(MethodError(unsafe_set_element!, (A, inds)))
509-
end
484+
function unsafe_set_element!(A::Array{T}, val, ::Tuple{}) where {T}
485+
Base.arrayset(false, A, convert(T, val)::T, 1)
486+
end
487+
function unsafe_set_element!(A::Array{T}, val, i::Integer) where {T}
488+
return Base.arrayset(false, A, convert(T, val)::T, Int(i))
510489
end
511490

512491
# This is based on Base._unsafe_setindex!.
@@ -529,7 +508,7 @@ function _generate_unsafe_setindex!_body(N::Int)
529508
# the optimizer that it does not need to emit error paths
530509
Xy === nothing && break
531510
(val, state) = Xy
532-
unsafe_set_element!(A, val, Base.Cartesian.@ntuple($N, i))
511+
unsafe_set_element!(A, val, NDIndex(Base.Cartesian.@ntuple($N, i)))
533512
Xy = iterate(x′, state)
534513
end
535514
A

0 commit comments

Comments
 (0)