Skip to content

Commit f67325d

Browse files
authored
Pre static changes (#286)
* Reduce dispatch on `Integer` Motivated by SciML/Static.jl#64, this shouldn't change how any code currently works. Just less resrtrictive dispatch patterns on `Integer`.
1 parent fbaab13 commit f67325d

File tree

8 files changed

+29
-34
lines changed

8 files changed

+29
-34
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.5"
3+
version = "6.0.6"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/ArrayInterface.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ Base.@propagate_inbounds function insert(collection, index, item)
196196
return ret
197197
end
198198

199-
function insert(x::Tuple{Vararg{Any,N}}, index::Integer, item) where {N}
199+
function insert(x::Tuple{Vararg{Any,N}}, index, item) where {N}
200200
@boundscheck if !checkindex(Bool, StaticInt{1}():StaticInt{N}(), index)
201201
throw(BoundsError(x, index))
202202
end
@@ -229,7 +229,7 @@ Base.@propagate_inbounds function deleteat(collection::Tuple{Vararg{Any,N}}, ind
229229
return unsafe_deleteat(collection, index)
230230
end
231231

232-
function unsafe_deleteat(src::AbstractVector, index::Integer)
232+
function unsafe_deleteat(src::AbstractVector, index)
233233
dst = similar(src, length(src) - 1)
234234
@inbounds for i in indices(dst)
235235
if i < index
@@ -265,10 +265,10 @@ end
265265
return Tuple(dst)
266266
end
267267

268-
@inline unsafe_deleteat(x::Tuple{T}, i::Integer) where {T} = ()
269-
@inline unsafe_deleteat(x::Tuple{T1,T2}, i::Integer) where {T1,T2} =
268+
@inline unsafe_deleteat(x::Tuple{T}, i) where {T} = ()
269+
@inline unsafe_deleteat(x::Tuple{T1,T2}, i) where {T1,T2} =
270270
isone(i) ? (x[2],) : (x[1],)
271-
@inline function unsafe_deleteat(x::Tuple, i::Integer)
271+
@inline function unsafe_deleteat(x::Tuple, i)
272272
if i === one(i)
273273
return tail(x)
274274
elseif i == length(x)

src/axes.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ end
119119
end
120120
end
121121

122-
@inline function axes(A::SubArray, dim::Integer)
122+
@inline function axes(A::SubArray, dim::CanonicalInt)
123123
if dim > ndims(A)
124124
return OneTo(1)
125125
else
@@ -260,15 +260,15 @@ Base.axes(x::Slice{<:LazyAxis}) = (Base.axes1(x),)
260260
Base.axes1(x::Slice{<:LazyAxis}) = indices(parent(x.indices))
261261
Base.to_shape(x::LazyAxis) = length(x)
262262

263-
@inline function Base.checkindex(::Type{Bool}, x::LazyAxis, i::Integer)
263+
@inline function Base.checkindex(::Type{Bool}, x::LazyAxis, i::CanonicalInt)
264264
if known_first(x) === nothing || known_last(x) === nothing
265265
return checkindex(Bool, parent(x), i)
266266
else # everything is static so we don't have to retrieve the axis
267267
return (!(known_first(x) > i) || !(known_last(x) < i))
268268
end
269269
end
270270

271-
@propagate_inbounds function Base.getindex(x::LazyAxis, i::Integer)
271+
@propagate_inbounds function Base.getindex(x::LazyAxis, i::CanonicalInt)
272272
@boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i))
273273
return Int(i)
274274
end
@@ -294,5 +294,3 @@ lazy_axes(x::CartesianIndices) = axes(x)
294294
@inline lazy_axes(x::MatAdjTrans) = reverse(lazy_axes(parent(x)))
295295
@inline lazy_axes(x::VecAdjTrans) = (SOneTo{1}(), first(lazy_axes(parent(x))))
296296
@inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(x))
297-
298-

src/dimensions.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(I
5252
dim_i = 1
5353
for i in 1:length(I.parameters)
5454
p = I.parameters[i]
55-
if p <: Integer
55+
if p <: CanonicalInt
5656
push!(out.args, :(StaticInt(0)))
5757
else
5858
push!(out.args, :(StaticInt($dim_i)))
@@ -107,7 +107,7 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
107107
out = Expr(:tuple)
108108
n = 1
109109
for p in I.parameters
110-
if !(p <: Integer)
110+
if !(p <: CanonicalInt)
111111
push!(out.args, :(StaticInt($n)))
112112
end
113113
n += 1
@@ -161,7 +161,7 @@ to `:_`, then `false` is returned.
161161
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
162162
have a name.
163163
"""
164-
@inline known_dimnames(x, dim::Integer) = _known_dimname(known_dimnames(x), canonicalize(dim))
164+
@inline known_dimnames(x, dim) = _known_dimname(known_dimnames(x), canonicalize(dim))
165165
known_dimnames(x) = known_dimnames(typeof(x))
166166
known_dimnames(::Type{T}) where {T} = _known_dimnames(T, parent_type(T))
167167
_known_dimnames(::Type{T}, ::Type{T}) where {T} = _unknown_dimnames(Base.IteratorSize(T))
@@ -184,7 +184,7 @@ end
184184
Return the names of the dimensions for `x`. `:_` is used to indicate a dimension does not
185185
have a name.
186186
"""
187-
@inline dimnames(x, dim::Integer) = _dimname(dimnames(x), canonicalize(dim))
187+
@inline dimnames(x, dim) = _dimname(dimnames(x), canonicalize(dim))
188188
@inline dimnames(x) = _dimnames(has_parent(x), x)
189189
@inline function _dimnames(::True, x)
190190
eachop(_inbounds_dimname, to_parent_dims(x), dimnames(parent(x)))
@@ -204,7 +204,8 @@ end
204204
This returns the dimension(s) of `x` corresponding to `dim`.
205205
"""
206206
to_dims(x, dim::Colon) = dim
207-
to_dims(x, dim::Integer) = canonicalize(dim)
207+
to_dims(x, @nospecialize(dim::CanonicalInt)) = dim
208+
to_dims(x, dim::Integer) = Int(dim)
208209
to_dims(x, dim::Union{StaticSymbol,Symbol}) = _to_dim(dimnames(x), dim)
209210
function to_dims(x, dims::Tuple{Vararg{Any,N}}) where {N}
210211
eachop(_to_dims, nstatic(Val(N)), dimnames(x), dims)
@@ -257,4 +258,3 @@ An error is thrown if any keywords are used which do not occur in `nda`'s names.
257258
end
258259
end
259260
end
260-

src/indexing.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ This implementation differs from that of `Base.to_indices` in the following ways
6060
* Specializing by dispatch through method definitions like this:
6161
`to_indices(::ArrayType, ::Tuple{AxisType,Vararg{Any}}, ::Tuple{::IndexType,Vararg{Any}})`
6262
require an excessive number of hand written methods to avoid ambiguities. Furthermore, if
63-
`AxisType` is wrapping another axis that should have unique behavior, then unique parametric
63+
`AxisType` is wrapping another axis that should have unique behavior, then unique parametric
6464
types need to also be explicitly defined.
6565
* `to_index(axes(A, dim), index)` is called, as opposed to `Base.to_index(A, index)`. The
6666
`IndexStyle` of the resulting axis is used to allow indirect dispatch on nested axis types
@@ -228,7 +228,7 @@ indices calling [`to_axis`](@ref).
228228
end
229229
end
230230
# drop this dimension
231-
to_axes(A, a::Tuple, i::Tuple{<:Integer,Vararg{Any}}) = to_axes(A, tail(a), tail(i))
231+
to_axes(A, a::Tuple, i::Tuple{<:CanonicalInt,Vararg{Any}}) = to_axes(A, tail(a), tail(i))
232232
to_axes(A, a::Tuple, i::Tuple{I,Vararg{Any}}) where {I} = _to_axes(StaticInt(ndims_index(I)), A, a, i)
233233
function _to_axes(::StaticInt{1}, A, axs::Tuple, inds::Tuple)
234234
return (to_axis(first(axs), first(inds)), to_axes(A, tail(axs), tail(inds))...)
@@ -353,7 +353,7 @@ function unsafe_get_collection(A, inds)
353353
end
354354
return dest
355355
end
356-
_ints2range(x::Integer) = x:x
356+
_ints2range(x::CanonicalInt) = x:x
357357
_ints2range(x::AbstractRange) = x
358358
@inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N}
359359
if (Base.length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False()

src/ranges.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt} <: AbstractUni
1414
function OptionallyStaticUnitRange(start::CanonicalInt, stop::CanonicalInt)
1515
new{typeof(start),typeof(stop)}(start, stop)
1616
end
17-
function OptionallyStaticUnitRange(start::Integer, stop::Integer)
17+
function OptionallyStaticUnitRange(start, stop)
1818
OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop))
1919
end
2020
function OptionallyStaticUnitRange(x::AbstractRange)
@@ -60,7 +60,7 @@ struct OptionallyStaticStepRange{F<:CanonicalInt,S<:CanonicalInt,L<:CanonicalInt
6060
lst = _steprange_last(start, step, stop)
6161
new{typeof(start),typeof(step),typeof(lst)}(start, step, lst)
6262
end
63-
function OptionallyStaticStepRange(start::Integer, step::Integer, stop::Integer)
63+
function OptionallyStaticStepRange(start, step, stop)
6464
OptionallyStaticStepRange(canonicalize(start), canonicalize(step), canonicalize(stop))
6565
end
6666
function OptionallyStaticStepRange(x::AbstractRange)
@@ -72,15 +72,15 @@ end
7272
@inline function _steprange_last(start::StaticInt, step::StaticInt, stop::StaticInt)
7373
return StaticInt(_steprange_last(Int(start), Int(step), Int(stop)))
7474
end
75-
@inline function _steprange_last(start::Integer, step::StaticInt, stop::StaticInt)
75+
@inline function _steprange_last(start, step::StaticInt, stop::StaticInt)
7676
if step === one(step)
7777
# we don't need to check the `stop` if we know it acts like a unit range
7878
return stop
7979
else
8080
return _steprange_last(start, Int(step), Int(stop))
8181
end
8282
end
83-
@inline function _steprange_last(start::Integer, step::Integer, stop::Integer)
83+
@inline function _steprange_last(start, step, stop)
8484
z = zero(step)
8585
if step === z
8686
throw(ArgumentError("step cannot be zero"))
@@ -415,4 +415,3 @@ end
415415
function Base.similar(::Type{<:Array{T}}, axes::Tuple{Base.OneTo,OptionallyStaticUnitRange{StaticInt{1}},Vararg{Union{Base.OneTo,OptionallyStaticUnitRange{StaticInt{1}}}}}) where {T}
416416
Array{T}(undef, map(last, axes))
417417
end
418-

src/size.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ size(x::Iterators.Pairs) = size(getfield(x, :itr))
5454
end
5555

5656
size(a, dim) = size(a, to_dims(a, dim))
57-
size(a::Array, dim::Integer) = Base.arraysize(a, convert(Int, dim))
58-
function size(a::A, dim::Integer) where {A}
57+
size(a::Array, dim::CanonicalInt) = Base.arraysize(a, convert(Int, dim))
58+
function size(a::A, dim::CanonicalInt) where {A}
5959
if parent_type(A) <: A
6060
len = known_size(A, dim)
6161
if len === nothing
@@ -67,7 +67,7 @@ function size(a::A, dim::Integer) where {A}
6767
return size(a)[dim]
6868
end
6969
end
70-
function size(A::SubArray, dim::Integer)
70+
function size(A::SubArray, dim::CanonicalInt)
7171
pdim = to_parent_dims(A, dim)
7272
if pdim > ndims(parent_type(A))
7373
return size(parent(A), pdim)
@@ -170,4 +170,3 @@ _prod_or_nothing(_) = nothing
170170

171171
_maybe_known_length(::Base.HasShape, ::Type{T}) where {T} = _prod_or_nothing(known_size(T))
172172
_maybe_known_length(::Base.IteratorSize, ::Type) = nothing
173-

src/stridelayout.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ not known at compile time `nothing` is returned its position.
4444
"""
4545
known_offsets(x, dim) = known_offsets(typeof(x), dim)
4646
known_offsets(::Type{T}, dim) where {T} = known_offsets(T, to_dims(T, dim))
47-
function known_offsets(::Type{T}, dim::Integer) where {T}
47+
function known_offsets(::Type{T}, dim::CanonicalInt) where {T}
4848
if ndims(T) < dim
4949
return 1
5050
else
@@ -187,7 +187,7 @@ function _contiguous_axis(::Type{A}, c::StaticInt{C}) where {T,N,P,I,A<:SubArray
187187
return from_parent_dims(A)[C]
188188
elseif field_type(I, c) <: AbstractArray
189189
return -One()
190-
elseif field_type(I, c) <: Integer
190+
elseif field_type(I, c) <: CanonicalInt
191191
return -One()
192192
else
193193
return nothing
@@ -489,7 +489,7 @@ compile time are represented by `nothing`.
489489
"""
490490
known_strides(x, dim) = known_strides(typeof(x), dim)
491491
known_strides(::Type{T}, dim) where {T} = known_strides(T, to_dims(T, dim))
492-
function known_strides(::Type{T}, dim::Integer) where {T}
492+
function known_strides(::Type{T}, dim::CanonicalInt) where {T}
493493
# see https://github.com/JuliaLang/julia/blob/6468dcb04ea2947f43a11f556da9a5588de512a0/base/reinterpretarray.jl#L148
494494
if ndims(T) < dim
495495
return known_length(T)
@@ -663,7 +663,7 @@ maybe_static_step(_) = nothing
663663
end
664664

665665
strides(a, dim) = strides(a, to_dims(a, dim))
666-
function strides(a::A, dim::Integer) where {A}
666+
function strides(a::A, dim::CanonicalInt) where {A}
667667
if parent_type(A) <: A
668668
return Base.stride(a, Int(dim))
669669
else
@@ -674,4 +674,3 @@ end
674674
@inline stride(A::AbstractArray, ::StaticInt{N}) where {N} = strides(A)[N]
675675
@inline stride(A::AbstractArray, ::Val{N}) where {N} = strides(A)[N]
676676
stride(A, i) = Base.stride(A, i) # for type stability
677-

0 commit comments

Comments
 (0)