Skip to content

Commit 1b64411

Browse files
authored
Remove range types (#366)
* Remove optionally static range types This is complimentary to SciML/Static.jl#88 and would be a big move in disentangling static types from ArrayInterface * Replace CanonicalInt with Static.IntType * Remove internal `_pick_range` method
1 parent 9e58cb2 commit 1b64411

File tree

14 files changed

+73
-524
lines changed

14 files changed

+73
-524
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1313
ArrayInterfaceCore = "0.1.3"
1414
Compat = "3, 4"
1515
IfElse = "0.1"
16-
Static = "0.7"
16+
Static = "0.8"
1717
julia = "1.6"
1818

1919
[extras]

lib/ArrayInterfaceOffsetArrays/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
1010
[compat]
1111
ArrayInterface = "5, 6"
1212
OffsetArrays = "1.11"
13-
Static = "0.7"
13+
Static = "0.7, 0.8"
1414
julia = "1.6"
1515

1616
[extras]

lib/ArrayInterfaceStaticArrays/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ Adapt = "3"
1616
ArrayInterface = "6"
1717
ArrayInterfaceCore = "0.1.21"
1818
ArrayInterfaceStaticArraysCore = "0.1"
19-
Static = "0.7"
19+
Static = "0.8"
2020
StaticArrays = "1.2.5, 1.3, 1.4"
2121
julia = "1.6"
2222

lib/ArrayInterfaceStaticArrays/src/ArrayInterfaceStaticArrays.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import ArrayInterfaceStaticArraysCore
99

1010
const CanonicalInt = Union{Int,StaticInt}
1111

12+
function Static.OptionallyStaticUnitRange(::StaticArrays.SOneTo{N}) where {N}
13+
Static.OptionallyStaticUnitRange(StaticInt(1), StaticInt(N))
14+
end
1215
ArrayInterface.known_first(::Type{<:StaticArrays.SOneTo}) = 1
1316
ArrayInterface.known_last(::Type{StaticArrays.SOneTo{N}}) where {N} = N
1417
ArrayInterface.known_length(::Type{StaticArrays.SOneTo{N}}) where {N} = N

src/ArrayInterface.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ import ArrayInterfaceCore: known_first, known_step, known_last
2323

2424
using Static
2525
using Static: Zero, One, nstatic, eq, ne, gt, ge, lt, le, eachop, eachop_tuple,
26-
permute, invariant_permutation, field_type, reduce_tup, find_first_eq
26+
permute, invariant_permutation, field_type, reduce_tup, find_first_eq,
27+
OptionallyStaticUnitRange, OptionallyStaticStepRange, OptionallyStaticRange, IntType,
28+
SOneTo, SUnitRange
2729

2830
using IfElse
2931

@@ -43,10 +45,6 @@ _sub1(@nospecialize x) = x - oneunit(x)
4345
Tuple{X.parameters...,Y.parameters...}
4446
end
4547

46-
const CanonicalInt = Union{Int,StaticInt}
47-
canonicalize(x::Integer) = Int(x)
48-
canonicalize(@nospecialize(x::StaticInt)) = x
49-
5048
abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end
5149

5250
Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
@@ -93,10 +91,10 @@ end
9391
@inline static_last(x) = Static.maybe_static(known_last, last, x)
9492
@inline static_step(x) = Static.maybe_static(known_step, step, x)
9593

96-
@inline function _to_cartesian(a, i::CanonicalInt)
94+
@inline function _to_cartesian(a, i::IntType)
9795
@inbounds(CartesianIndices(ntuple(dim -> indices(a, dim), Val(ndims(a))))[i])
9896
end
99-
@inline function _to_linear(a, i::Tuple{CanonicalInt,Vararg{CanonicalInt}})
97+
@inline function _to_linear(a, i::Tuple{IntType,Vararg{IntType}})
10098
_strides2int(offsets(a), size_to_strides(size(a), static(1)), i) + static(1)
10199
end
102100

src/array_index.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ struct StrideIndex{N,R,C,S,O} <: ArrayIndex{N}
2020
end
2121

2222
## getindex
23-
@propagate_inbounds Base.getindex(x::ArrayIndex, i::CanonicalInt, ii::CanonicalInt...) = x[NDIndex(i, ii...)]
23+
@propagate_inbounds Base.getindex(x::ArrayIndex, i::IntType, ii::IntType...) = x[NDIndex(i, ii...)]
2424

2525
@inline function Base.getindex(x::StrideIndex{N}, i::AbstractCartesianIndex) where {N}
2626
return _strides2int(offsets(x), strides(x), Tuple(i)) + static(1)

src/axes.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ Base.keys(x::LazyAxis) = keys(parent(x))
185185

186186
Base.IndexStyle(T::Type{<:LazyAxis}) = IndexStyle(parent_type(T))
187187

188+
function Static.OptionallyStaticUnitRange(x::LazyAxis)
189+
OptionallyStaticUnitRange(static_first(x), static_last(x))
190+
end
188191
ArrayInterfaceCore.can_change_size(@nospecialize T::Type{<:LazyAxis}) = can_change_size(fieldtype(T, :parent))
189192

190193
ArrayInterfaceCore.known_first(::Type{<:LazyAxis{N,P}}) where {N,P} = known_offsets(P, static(N))
@@ -219,7 +222,7 @@ Base.axes1(x::Slice{LazyAxis{N,A}}) where {N,A} = indices(getfield(x.indices, :p
219222
Base.axes1(x::Slice{LazyAxis{:,A}}) where {A} = indices(getfield(x.indices, :parent))
220223
Base.to_shape(x::LazyAxis) = Base.length(x)
221224

222-
@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt)
225+
@propagate_inbounds function Base.getindex(x::LazyAxis, i::IntType)
223226
@boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i))
224227
return Int(i)
225228
end

src/dimensions.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ to `:_`, then `false` is returned.
111111
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
112112
have a name.
113113
"""
114-
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), canonicalize(dim))
114+
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), IntType(dim))
115115
known_dimnames(x) = known_dimnames(typeof(x))
116116
function known_dimnames(@nospecialize T::Type{<:VecAdjTrans})
117117
(:_, getfield(known_dimnames(parent_type(T)), 1))
@@ -159,7 +159,7 @@ end
159159
_unknown_dimnames(::Base.HasShape{N}) where {N} = ntuple(Compat.Returns(:_), StaticInt(N))
160160
_unknown_dimnames(::Any) = (:_,)
161161

162-
@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
162+
@inline function _known_dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N}
163163
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
164164
(dim > N || dim < 1) && return :_
165165
return @inbounds(x[dim])
@@ -173,7 +173,7 @@ end
173173
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
174174
have a name.
175175
"""
176-
@inline dimnames(x, dim) = _dimname(dimnames(x), canonicalize(dim))
176+
@inline dimnames(x, dim) = _dimname(dimnames(x), IntType(dim))
177177
@inline function dimnames(x::Union{PermutedDimsArray,MatAdjTrans})
178178
map(GetIndex{false}(dimnames(parent(x))), to_parent_dims(x))
179179
end
@@ -214,7 +214,7 @@ end
214214
return ntuple(Compat.Returns(static(:_)), StaticInt(ndims(x)))
215215
end
216216
end
217-
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::CanonicalInt) where {N}
217+
@inline function _dimname(x::Tuple{Vararg{Any,N}}, dim::IntType) where {N}
218218
# we cannot have `@boundscheck`, else this will depend on bounds checking being enabled
219219
# for calls such as `dimnames(view(x, :, 1, :))`
220220
(dim > N || dim < 1) && return static(:_)
@@ -228,7 +228,7 @@ end
228228
This returns the dimension(s) of `x` corresponding to `dim`.
229229
"""
230230
to_dims(x, dim::Colon) = dim
231-
to_dims(x, @nospecialize(dim::CanonicalInt)) = dim
231+
to_dims(x, @nospecialize(dim::IntType)) = dim
232232
to_dims(x, dim::Integer) = Int(dim)
233233
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)
234234
function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N}

src/indexing.jl

Lines changed: 23 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,4 @@
11

2-
function known_lastindex(::Type{T}) where {T}
3-
if known_offset1(T) === nothing || known_length(T) === nothing
4-
return nothing
5-
else
6-
return known_length(T) - known_offset1(T) + 1
7-
end
8-
end
9-
known_lastindex(@nospecialize x) = known_lastindex(typeof(x))
10-
11-
@inline static_lastindex(x) = Static.maybe_static(known_lastindex, lastindex, x)
12-
13-
function Base.first(x::AbstractVector, n::StaticInt)
14-
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
15-
start = offset1(x)
16-
@inbounds x[start:min((start - one(start)) + n, static_lastindex(x))]
17-
end
18-
19-
function Base.last(x::AbstractVector, n::StaticInt)
20-
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
21-
stop = static_lastindex(x)
22-
@inbounds x[max(offset1(x), (stop + one(stop)) - n):stop]
23-
end
24-
252
"""
263
ArrayInterface.to_indices(A, I::Tuple) -> Tuple
274
@@ -162,16 +139,16 @@ to_index(::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i)
162139
@inline to_index(x, i::NDIndex) = getfield(i, 1)
163140
@inline to_index(x, i::AbstractArray{<:AbstractCartesianIndex}) = i
164141
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(<),typeof(isless)},<:Union{Base.BitInteger,StaticInt}})
165-
static_first(x):min(_sub1(canonicalize(i.x)), static_last(x))
142+
static_first(x):min(_sub1(IntType(i.x)), static_last(x))
166143
end
167144
@inline function to_index(x, i::Base.Fix2{typeof(<=),<:Union{Base.BitInteger,StaticInt}})
168-
static_first(x):min(canonicalize(i.x), static_last(x))
145+
static_first(x):min(IntType(i.x), static_last(x))
169146
end
170147
@inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}})
171-
max(canonicalize(i.x), static_first(x)):static_last(x)
148+
max(IntType(i.x), static_first(x)):static_last(x)
172149
end
173150
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
174-
max(_add1(canonicalize(i.x)), static_first(x)):static_last(x)
151+
max(_add1(IntType(i.x)), static_first(x)):static_last(x)
175152
end
176153
# integer indexing
177154
to_index(x, i::AbstractArray{<:Integer}) = i
@@ -232,7 +209,7 @@ indices calling [`to_axis`](@ref).
232209
end
233210
end
234211
# drop this dimension
235-
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
212+
to_axes(A, a::Tuple, i::Tuple{<:IntType,Vararg{Any}}) = to_axes(A, _maybe_tail(a), tail(i))
236213
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
237214
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
238215
return (to_axis(_maybe_first(axs), first(inds)), to_axes(A, _maybe_tail(axs), tail(inds))...)
@@ -309,15 +286,15 @@ function unsafe_getindex(a::A) where {A}
309286
end
310287

311288
# TODO Need to manage index transformations between nested layers of arrays
312-
function unsafe_getindex(a::A, i::CanonicalInt) where {A}
289+
function unsafe_getindex(a::A, i::IntType) where {A}
313290
if IndexStyle(A) === IndexLinear()
314291
is_forwarding_wrapper(A) || throw(MethodError(unsafe_getindex, (A, i)))
315292
return unsafe_getindex(parent(a), i)
316293
else
317294
return unsafe_getindex(a, _to_cartesian(a, i)...)
318295
end
319296
end
320-
function unsafe_getindex(a::A, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
297+
function unsafe_getindex(a::A, i::IntType, ii::Vararg{IntType}) where {A}
321298
if IndexStyle(A) === IndexLinear()
322299
return unsafe_getindex(a, _to_linear(a, (i, ii...)))
323300
else
@@ -329,24 +306,24 @@ end
329306
unsafe_getindex(a, i::Vararg{Any}) = unsafe_get_collection(a, i)
330307

331308
unsafe_getindex(A::Array) = Base.arrayref(false, A, 1)
332-
unsafe_getindex(A::Array, i::CanonicalInt) = Base.arrayref(false, A, Int(i))
333-
@inline function unsafe_getindex(A::Array, i::CanonicalInt, ii::Vararg{CanonicalInt})
309+
unsafe_getindex(A::Array, i::IntType) = Base.arrayref(false, A, Int(i))
310+
@inline function unsafe_getindex(A::Array, i::IntType, ii::Vararg{IntType})
334311
unsafe_getindex(A, _to_linear(A, (i, ii...)))
335312
end
336313

337-
unsafe_getindex(A::LinearIndices, i::CanonicalInt) = Int(i)
338-
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{CanonicalInt,N}) where {N} = CartesianIndex(ii...)
339-
unsafe_getindex(A::CartesianIndices, ii::Vararg{CanonicalInt}) =
314+
unsafe_getindex(A::LinearIndices, i::IntType) = Int(i)
315+
unsafe_getindex(A::CartesianIndices{N}, ii::Vararg{IntType,N}) where {N} = CartesianIndex(ii...)
316+
unsafe_getindex(A::CartesianIndices, ii::Vararg{IntType}) =
340317
unsafe_getindex(A, Base.front(ii)...)
341-
unsafe_getindex(A::CartesianIndices, i::CanonicalInt) = @inbounds(A[i])
318+
unsafe_getindex(A::CartesianIndices, i::IntType) = @inbounds(A[i])
342319

343-
unsafe_getindex(A::ReshapedArray, i::CanonicalInt) = @inbounds(parent(A)[i])
344-
function unsafe_getindex(A::ReshapedArray, i::CanonicalInt, ii::Vararg{CanonicalInt})
320+
unsafe_getindex(A::ReshapedArray, i::IntType) = @inbounds(parent(A)[i])
321+
function unsafe_getindex(A::ReshapedArray, i::IntType, ii::Vararg{IntType})
345322
@inbounds(parent(A)[_to_linear(A, (i, ii...))])
346323
end
347324

348-
unsafe_getindex(A::SubArray, i::CanonicalInt) = @inbounds(A[i])
349-
unsafe_getindex(A::SubArray, i::CanonicalInt, ii::Vararg{CanonicalInt}) = @inbounds(A[i, ii...])
325+
unsafe_getindex(A::SubArray, i::IntType) = @inbounds(A[i])
326+
unsafe_getindex(A::SubArray, i::IntType, ii::Vararg{IntType}) = @inbounds(A[i, ii...])
350327

351328
# This is based on Base._unsafe_getindex from https://github.com/JuliaLang/julia/blob/c5ede45829bf8eb09f2145bfd6f089459d77b2b1/base/multidimensional.jl#L755.
352329
#=
@@ -364,17 +341,17 @@ function unsafe_get_collection(A, inds)
364341
end
365342
return dest
366343
end
367-
_ints2range(x::CanonicalInt) = x:x
344+
_ints2range(x::IntType) = x:x
368345
_ints2range(x::AbstractRange) = x
369346
# apply _ints2range to front N elements
370347
_ints2range_front(::Val{N}, ind, inds...) where {N} =
371348
(_ints2range(ind), _ints2range_front(Val(N - 1), inds...)...)
372349
_ints2range_front(::Val{0}, ind, inds...) = ()
373350
_ints2range_front(::Val{0}) = ()
374351
# get output shape with given indices
375-
_output_shape(::CanonicalInt, inds...) = _output_shape(inds...)
352+
_output_shape(::IntType, inds...) = _output_shape(inds...)
376353
_output_shape(ind::AbstractRange, inds...) = (Base.length(ind), _output_shape(inds...)...)
377-
_output_shape(::CanonicalInt) = ()
354+
_output_shape(::IntType) = ()
378355
_output_shape(x::AbstractRange) = (Base.length(x),)
379356
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
380357
if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()
@@ -426,15 +403,15 @@ function unsafe_setindex!(a::A, v) where {A}
426403
return unsafe_setindex!(parent(a), v)
427404
end
428405
# TODO Need to manage index transformations between nested layers of arrays
429-
function unsafe_setindex!(a::A, v, i::CanonicalInt) where {A}
406+
function unsafe_setindex!(a::A, v, i::IntType) where {A}
430407
if IndexStyle(A) === IndexLinear()
431408
is_forwarding_wrapper(A) || throw(MethodError(unsafe_setindex!, (A, v, i)))
432409
return unsafe_setindex!(parent(a), v, i)
433410
else
434411
return unsafe_setindex!(a, v, _to_cartesian(a, i)...)
435412
end
436413
end
437-
function unsafe_setindex!(a::A, v, i::CanonicalInt, ii::Vararg{CanonicalInt}) where {A}
414+
function unsafe_setindex!(a::A, v, i::IntType, ii::Vararg{IntType}) where {A}
438415
if IndexStyle(A) === IndexLinear()
439416
return unsafe_setindex!(a, v, _to_linear(a, (i, ii...)))
440417
else
@@ -446,7 +423,7 @@ end
446423
function unsafe_setindex!(A::Array{T}, v) where {T}
447424
Base.arrayset(false, A, convert(T, v)::T, 1)
448425
end
449-
function unsafe_setindex!(A::Array{T}, v, i::CanonicalInt) where {T}
426+
function unsafe_setindex!(A::Array{T}, v, i::IntType) where {T}
450427
return Base.arrayset(false, A, convert(T, v)::T, Int(i))
451428
end
452429

0 commit comments

Comments
 (0)