Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
ArrayInterfaceCore = "0.1.3"
Compat = "3, 4"
IfElse = "0.1"
Static = "0.7"
Static = "0.8"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion lib/ArrayInterfaceOffsetArrays/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
[compat]
ArrayInterface = "5, 6"
OffsetArrays = "1.11"
Static = "0.7"
Static = "0.7, 0.8"
julia = "1.6"

[extras]
Expand Down
2 changes: 1 addition & 1 deletion lib/ArrayInterfaceStaticArrays/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Adapt = "3"
ArrayInterface = "6"
ArrayInterfaceCore = "0.1.21"
ArrayInterfaceStaticArraysCore = "0.1"
Static = "0.7"
Static = "0.8"
StaticArrays = "1.2.5, 1.3, 1.4"
julia = "1.6"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import ArrayInterfaceStaticArraysCore

const CanonicalInt = Union{Int,StaticInt}

function Static.OptionallyStaticUnitRange(::StaticArrays.SOneTo{N}) where {N}
Static.OptionallyStaticUnitRange(StaticInt(1), StaticInt(N))
end
ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N
Expand Down
12 changes: 5 additions & 7 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ import ArrayInterfaceCore: known_first, known_step, known_last

using Static
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
permute, invariant_permutation, field_type, reduce_tup, find_first_eq
permute, invariant_permutation, field_type, reduce_tup, find_first_eq,
OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange, IntType,
SOneTo, SUnitRange

using IfElse

Expand All @@ -43,10 +45,6 @@ _sub1(@nospecialize x) = x - oneunit(x)
Tuple{X.parameters...,Y.parameters...}
end

const CanonicalInt = Union{Int,StaticInt}
canonicalize(x::Integer) = Int(x)
canonicalize(@nospecialize(x::StaticInt)) = x

abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end

Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
Expand Down Expand Up @@ -93,10 +91,10 @@ end
@inline static_last(x) = Static.maybe_static(known_last, last, x)
@inline static_step(x) = Static.maybe_static(known_step, step, x)

@inline function _to_cartesian(a, i::CanonicalInt)
@inline function _to_cartesian(a, i::IntType)
@inbounds(CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i])
end
@inline function _to_linear(a, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
@inline function _to_linear(a, i::Tuple{IntType,Vararg{IntType}})
_strides2int(offsets(a), size_to_strides(size(a), static(1)), i) + static(1)
end

Expand Down
2 changes: 1 addition & 1 deletion src/array_index.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N}
end

## getindex
@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)]
@propagate_inbounds Base.getindex(x::ArrayIndex, i::IntType, ii::IntType...) = x[NDIndex(i, ii...)]

@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N}
return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1)
Expand Down
5 changes: 4 additions & 1 deletion src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ Base.keys(x::LazyAxis) = keys(parent(x))

Base.IndexStyle(T::Type{<:LazyAxis}) = IndexStyle(parent_type(T))

function Static.OptionallyStaticUnitRange(x::LazyAxis)
OptionallyStaticUnitRange(static_first(x), static_last(x))
end
ArrayInterfaceCore.can_change_size(@nospecialize T::Type{<:LazyAxis}) = can_change_size(fieldtype(T, :parent))

ArrayInterfaceCore.known_first(::Type{<:LazyAxis{N,P}}) where {N,P} = known_offsets(P, static(N))
Expand Down Expand Up @@ -219,7 +222,7 @@ Base.axes1(x::Slice{LazyAxis{N,A}}) where {N,A} = indices(getfield(x.indices, :p
Base.axes1(x::Slice{LazyAxis{:,A}}) where {A} = indices(getfield(x.indices, :parent))
Base.to_shape(x::LazyAxis) = Base.length(x)

@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt)
@propagate_inbounds function Base.getindex(x::LazyAxis, i::IntType)
@boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i))
return Int(i)
end
Expand Down
10 changes: 5 additions & 5 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ to `:_`, then `false` is returned.
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
have a name.
"""
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), canonicalize(dim))
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), IntType(dim))
known_dimnames(x) = known_dimnames(typeof(x))
function known_dimnames(@nospecialize T::Type{<:VecAdjTrans})
(:_, getfield(known_dimnames(parent_type(T)), 1))
Expand Down Expand Up @@ -159,7 +159,7 @@ end
_unknown_dimnames(::Base.HasShape{N}) where {N} = ntuple(Compat.Returns(:_), StaticInt(N))
_unknown_dimnames(::Any) = (:_,)

@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N}
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
(dim > N || dim < 1) && return :_
return @inbounds(x[dim])
Expand All @@ -173,7 +173,7 @@ end
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
have a name.
"""
@inline dimnames(x, dim) = _dimname(dimnames(x), canonicalize(dim))
@inline dimnames(x, dim) = _dimname(dimnames(x), IntType(dim))
@inline function dimnames(x::Union{PermutedDimsArray,MatAdjTrans})
map(GetIndex{false}(dimnames(parent(x))), to_parent_dims(x))
end
Expand Down Expand Up @@ -214,7 +214,7 @@ end
return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x)))
end
end
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N}
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
# for calls such as `dimnames(view(x, :, 1, :))`
(dim > N || dim < 1) && return static(:_)
Expand All @@ -228,7 +228,7 @@ end
This returns the dimension(s) of `x` corresponding to `dim`.
"""
to_dims(x, dim::Colon) = dim
to_dims(x, @nospecialize(dim::CanonicalInt)) = dim
to_dims(x, @nospecialize(dim::IntType)) = dim
to_dims(x, dim::Integer) = Int(dim)
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)
function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N}
Expand Down
69 changes: 23 additions & 46 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,4 @@

function known_lastindex(::Type{T}) where {T}
if known_offset1(T) === nothing || known_length(T) === nothing
return nothing
else
return known_length(T) - known_offset1(T) + 1
end
end
known_lastindex(@nospecialize x) = known_lastindex(typeof(x))

@inline static_lastindex(x) = Static.maybe_static(known_lastindex, lastindex, x)

function Base.first(x::AbstractVector, n::StaticInt)
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
start = offset1(x)
@inbounds x[start:min((start - one(start)) + n, static_lastindex(x))]
end

function Base.last(x::AbstractVector, n::StaticInt)
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
stop = static_lastindex(x)
@inbounds x[max(offset1(x), (stop + one(stop)) - n):stop]
end

"""
ArrayInterface.to_indices(A, I::Tuple) -> Tuple

Expand Down Expand Up @@ -162,16 +139,16 @@ to_index(::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i)
@inline to_index(x, i::NDIndex) = getfield(i, 1)
@inline to_index(x, i::AbstractArray{<:AbstractCartesianIndex}) = i
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(<),typeof(isless)},<:Union{Base.BitInteger,StaticInt}})
static_first(x):min(_sub1(canonicalize(i.x)), static_last(x))
static_first(x):min(_sub1(IntType(i.x)), static_last(x))
end
@inline function to_index(x, i::Base.Fix2{typeof(<=),<:Union{Base.BitInteger,StaticInt}})
static_first(x):min(canonicalize(i.x), static_last(x))
static_first(x):min(IntType(i.x), static_last(x))
end
@inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}})
max(canonicalize(i.x), static_first(x)):static_last(x)
max(IntType(i.x), static_first(x)):static_last(x)
end
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
max(_add1(canonicalize(i.x)), static_first(x)):static_last(x)
max(_add1(IntType(i.x)), static_first(x)):static_last(x)
end
# integer indexing
to_index(x, i::AbstractArray{<:Integer}) = i
Expand Down Expand Up @@ -232,7 +209,7 @@ indices calling [`to_axis`](@ref).
end
end
# drop this dimension
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
to_axes(A, a::Tuple, i::Tuple{<:IntType,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
return (to_axis(_maybe_first(axs), first(inds)), to_axes(A, _maybe_tail(axs), tail(inds))...)
Expand Down Expand Up @@ -309,15 +286,15 @@ function unsafe_getindex(a::A) where {A}
end

# TODO Need to manage index transformations between nested layers of arrays
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
function unsafe_getindex(a::A, i::IntType) where {A}
if IndexStyle(A) === IndexLinear()
is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (A, i)))
return unsafe_getindex(parent(a), i)
else
return unsafe_getindex(a, _to_cartesian(a, i)...)
end
end
function unsafe_getindex(a::A, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
function unsafe_getindex(a::A, i::IntType, ii::Vararg{IntType}) where {A}
if IndexStyle(A) === IndexLinear()
return unsafe_getindex(a, _to_linear(a, (i, ii...)))
else
Expand All @@ -329,24 +306,24 @@ end
unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i)

unsafe_getindex(A::Array) = Base.arrayref(false, A, 1)
unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))
@inline function unsafe_getindex(A::Array, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_getindex(A::Array, i::IntType) = Base.arrayref(false, A, Int(i))
@inline function unsafe_getindex(A::Array, i::IntType, ii::Vararg{IntType})
unsafe_getindex(A, _to_linear(A, (i, ii...)))
end

unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{CanonicalInt,N}) where {N} = CartesianIndex(ii...)
unsafe_getindex(A::CartesianIndices, ii::Vararg{CanonicalInt}) =
unsafe_getindex(A::LinearIndices, i::IntType) = Int(i)
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{IntType,N}) where {N} = CartesianIndex(ii...)
unsafe_getindex(A::CartesianIndices, ii::Vararg{IntType}) =
unsafe_getindex(A, Base.front(ii)...)
unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[i])
unsafe_getindex(A::CartesianIndices, i::IntType) = @inbounds(A[i])

unsafe_getindex(A::ReshapedArray, i::CanonicalInt) = @inbounds(parent(A)[i])
function unsafe_getindex(A::ReshapedArray, i::CanonicalInt, ii::Vararg{CanonicalInt})
unsafe_getindex(A::ReshapedArray, i::IntType) = @inbounds(parent(A)[i])
function unsafe_getindex(A::ReshapedArray, i::IntType, ii::Vararg{IntType})
@inbounds(parent(A)[_to_linear(A, (i, ii...))])
end

unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i])
unsafe_getindex(A::SubArray, i::CanonicalInt, ii::Vararg{CanonicalInt}) = @inbounds(A[i, ii...])
unsafe_getindex(A::SubArray, i::IntType) = @inbounds(A[i])
unsafe_getindex(A::SubArray, i::IntType, ii::Vararg{IntType}) = @inbounds(A[i, ii...])

# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
#=
Expand All @@ -364,17 +341,17 @@ function unsafe_get_collection(A, inds)
end
return dest
end
_ints2range(x::CanonicalInt) = x:x
_ints2range(x::IntType) = x:x
_ints2range(x::AbstractRange) = x
# apply _ints2range to front N elements
_ints2range_front(::Val{N}, ind, inds...) where {N} =
(_ints2range(ind), _ints2range_front(Val(N - 1), inds...)...)
_ints2range_front(::Val{0}, ind, inds...) = ()
_ints2range_front(::Val{0}) = ()
# get output shape with given indices
_output_shape(::CanonicalInt, inds...) = _output_shape(inds...)
_output_shape(::IntType, inds...) = _output_shape(inds...)
_output_shape(ind::AbstractRange, inds...) = (Base.length(ind), _output_shape(inds...)...)
_output_shape(::CanonicalInt) = ()
_output_shape(::IntType) = ()
_output_shape(x::AbstractRange) = (Base.length(x),)
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()
Expand Down Expand Up @@ -426,15 +403,15 @@ function unsafe_setindex!(a::A, v) where {A}
return unsafe_setindex!(parent(a), v)
end
# TODO Need to manage index transformations between nested layers of arrays
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
function unsafe_setindex!(a::A, v, i::IntType) where {A}
if IndexStyle(A) === IndexLinear()
is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (A, v, i)))
return unsafe_setindex!(parent(a), v, i)
else
return unsafe_setindex!(a, v, _to_cartesian(a, i)...)
end
end
function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
function unsafe_setindex!(a::A, v, i::IntType, ii::Vararg{IntType}) where {A}
if IndexStyle(A) === IndexLinear()
return unsafe_setindex!(a, v, _to_linear(a, (i, ii...)))
else
Expand All @@ -446,7 +423,7 @@ end
function unsafe_setindex!(A::Array{T}, v) where {T}
Base.arrayset(false, A, convert(T, v)::T, 1)
end
function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T}
function unsafe_setindex!(A::Array{T}, v, i::IntType) where {T}
return Base.arrayset(false, A, convert(T, v)::T, Int(i))
end

Expand Down
Loading