Skip to content

Commit 70c135e

Browse files
authored
Optimizations on to_indices (#227)
* Optimizations on to_indices * Remove bounds checking from `to_indices` Although including bounds checking in the `to_indices` should avoid revisiting every argument site, the compiler has trouble optimizing it when done all at once. Now we do boundschecking prior to `to_indices` like in Base. * Replace canonicalize code with formal interface through `to_indices` * Extended help section for to_indices * Greatly simplify internals for `to_indices` After a lot of testing I found out that we can just rely on `ndims_index` for aligning axes in a generated function and then flatten out tuples returned from `to_index`. This greatly simplifies the explanation of internals and it also makes accomodating new indexing types simpler. * Minor version bump for new features
1 parent 5c9681d commit 70c135e

File tree

8 files changed

+272
-271
lines changed

8 files changed

+272
-271
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.40"
3+
version = "3.2"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

docs/src/api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ ArrayInterface.fast_scalar_indexing
1717
ArrayInterface.has_dimnames
1818
ArrayInterface.has_parent
1919
ArrayInterface.has_sparsestruct
20-
ArrayInterface.is_canonical
2120
ArrayInterface.is_column_major
2221
ArrayInterface.is_lazy_conjugate
2322
ArrayInterface.ismutable
2423
ArrayInterface.issingular
2524
ArrayInterface.isstructured
25+
ArrayInterface.is_splat_index
2626
ArrayInterface.known_first
2727
ArrayInterface.known_last
2828
ArrayInterface.known_length
@@ -31,6 +31,7 @@ ArrayInterface.known_offsets
3131
ArrayInterface.known_size
3232
ArrayInterface.known_step
3333
ArrayInterface.known_strides
34+
ArrayInterface.ndims_index
3435
```
3536

3637
## Functions
@@ -43,7 +44,6 @@ ArrayInterface.axes
4344
ArrayInterface.axes_types
4445
ArrayInterface.broadcast_axis
4546
ArrayInterface.buffer
46-
ArrayInterface.canonicalize
4747
ArrayInterface.deleteat
4848
ArrayInterface.dense_dims
4949
ArrayInterface.findstructralnz

docs/src/index.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ end
2929
```
3030

3131
Most traits in `ArrayInterface` are a variant on this pattern.
32+
If the trait in question may be altered by a wrapper array, this pattern should be altered or may be inappropriate.
3233

3334
## Static Traits
3435

@@ -174,3 +175,19 @@ Defining these two methods ensures that other array types that wrap `OffsetArray
174175
It is entirely optional to define `ArrayInterface.size` for `OffsetArray` because the size can be derived from the axes.
175176
However, in this particularly case we should also define
176177
`ArrayInterface.size(A::OffsetArray) = ArrayInterface.size(parent(A))` because the relative offsets attached to `OffsetArray` do not change the size but may hide static sizes if using a relative offset that is defined with an `Int`.
178+
179+
## Processing Indices (`to_indices`)
180+
181+
For most users, the only reason you should use `ArrayInterface.to_indices` over `Base.to_indices` is that it's faster and perhaps some of the more detailed benefits described in the [`to_indices`](@ref) doc string.
182+
For those interested in how this is accomplished, the following steps (beginning with the `to_indices(A::AbstractArray, I::Tuple)`) are used to accomplish this:
183+
184+
1. The number of dimensions that each indexing argument in `I` corresponds to is determined using using the [`ndims_index`](@ref) and [`is_splat_index`](@ref) traits.
185+
2. A non-allocating reference to each axis of `A` is created (`lazy_axes(A) -> axs`). These are aligned to each the index arguments using information from the first step. For example, if an index argument maps to a single dimension then it is paired with `axs[dim]`. In the case of multiple dimensions it is paired with `CartesianIndices(axs[dim_1], ... axs[dim_n])`. These pairs are further processed using `to_index(axis, I[n])`.
186+
3. Tuples returned from `to_index` are flattened out so that there are no nested tuples returned from `to_indices`.
187+
188+
Entry points:
189+
190+
* `to_indices(::ArrayType, indices)` : dispatch on unique array type `ArrayType`
191+
* `to_index(axis, ::IndexType)` : dispatch on a unique indexing type, `IndexType`. `ArrayInterface.ndims_index(::Type{IndexType})` should also be defined in this case.
192+
* `to_index(S::IndexStyle, axis, index)` : The index style `S` that corresponds to `axis`. This is
193+

src/ArrayInterface.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretAr
1414
ReshapedArray, AbstractCartesianIndex
1515

1616
const CanonicalInt = Union{Int,StaticInt}
17+
canonicalize(x::Integer) = Int(x)
18+
canonicalize(@nospecialize(x::StaticInt)) = x
1719

1820
@static if isdefined(Base, :ReshapedReinterpretArray)
1921
_is_reshaped(::Type{<:Base.ReshapedReinterpretArray}) = true
@@ -29,8 +31,6 @@ parameterless_type(x::Type) = __parameterless_type(x)
2931

3032
const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}}
3133
const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}
32-
const UpTri{T,M} = Union{UpperTriangular{T,M},UnitUpperTriangular{T,M}}
33-
const LoTri{T,M} = Union{LowerTriangular{T,M},UnitLowerTriangular{T,M}}
3434

3535
@inline static_length(a::UnitRange{T}) where {T} = last(a) - first(a) + oneunit(T)
3636
@inline static_length(x) = Static.maybe_static(known_length, length, x)
@@ -55,15 +55,13 @@ parent_type(::Type{<:LinearAlgebra.AbstractTriangular{T,S}}) where {T,S} = S
5555
parent_type(::Type{<:PermutedDimsArray{T,N,I1,I2,A}}) where {T,N,I1,I2,A} = A
5656
parent_type(::Type{Slice{T}}) where {T} = T
5757
parent_type(::Type{T}) where {T} = T
58-
parent_type(::Type{R}) where {S,T,A,N,R<:Base.ReinterpretArray{T,N,S,A}} = A
59-
parent_type(::Type{LoTri{T,M}}) where {T,M} = M
60-
parent_type(::Type{UpTri{T,M}}) where {T,M} = M
58+
parent_type(::Type{R}) where {S,T,A,N,R<:ReinterpretArray{T,N,S,A}} = A
6159
parent_type(::Type{Diagonal{T,V}}) where {T,V} = V
6260

6361
"""
6462
has_parent(::Type{T}) -> StaticBool
6563
66-
Returns `True` if `parent_type(T)` a type unique to `T`.
64+
Returns `static(true)` if `parent_type(T)` a type unique to `T`.
6765
"""
6866
has_parent(x) = has_parent(typeof(x))
6967
has_parent(::Type{T}) where {T} = _has_parent(parent_type(T), T)

src/dimensions.jl

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,6 @@ function is_increasing(perm::Tuple{StaticInt{X},StaticInt{Y}}) where {X, Y}
2222
end
2323
is_increasing(::Tuple{StaticInt{X}}) where {X} = True()
2424

25-
#=
26-
ndims_index(::Type{I})::StaticInt
27-
28-
The number of dimensions an instance of `I` maps to when indexing an instance of `A`.
29-
=#
30-
ndims_index(i) = ndims_index(typeof(i))
31-
ndims_index(::Type{I}) where {I} = static(1)
32-
ndims_index(::Type{I}) where {N,I<:AbstractCartesianIndex{N}} = static(N)
33-
ndims_index(::Type{I}) where {I<:AbstractArray} = ndims_index(eltype(I))
34-
ndims_index(::Type{I}) where {I<:AbstractArray{Bool}} = static(ndims(I))
35-
ndims_index(::Type{I}) where {N,I<:LogicalIndex{<:Any,<:AbstractArray{Bool,N}}} = static(N)
36-
_ndims_index(::Type{I}, i::StaticInt) where {I} = ndims_index(_get_tuple(I, i))
37-
ndims_index(::Type{I}) where {N,I<:Tuple{Vararg{Any,N}}} = eachop(_ndims_index, nstatic(Val(N)), I)
38-
3925
"""
4026
from_parent_dims(::Type{T}) -> Tuple{Vararg{Union{Int,StaticInt}}}
4127
from_parent_dims(::Type{T}, dim) -> Union{Int,StaticInt}
@@ -191,7 +177,8 @@ end
191177
This returns the dimension(s) of `x` corresponding to `d`.
192178
"""
193179
to_dims(x, dim) = to_dims(typeof(x), dim)
194-
to_dims(::Type{T}, dim::Integer) where {T} = canonicalize(dim)
180+
to_dims(::Type{T}, dim::StaticInt) where {T} = dim
181+
to_dims(::Type{T}, dim::Integer) where {T} = Int(dim)
195182
to_dims(::Type{T}, dim::Colon) where {T} = dim
196183
function to_dims(::Type{T}, dim::StaticSymbol) where {T}
197184
i = find_first_eq(dim, dimnames(T))

0 commit comments

Comments
 (0)