Skip to content

Commit 09eeb2f

Browse files
authored
Use static indices for deletat and safer param access (#287)
Since we created our own ArrayInterface.length we were calling OneTo(static_length) indirectly. We might as well just use the known indices so that if the compiler ever gets smart enough it can unroll that. We also shouldn't be doing T.parameters[i] as much so this fixes some of that. I also threw in some very minor changes that reduced invalidations from over 100 to 27 when loading ArrayInterface. All changes were locally tested to ensure we wouldn't get performance regressions.
1 parent f67325d commit 09eeb2f

File tree

5 files changed

+9
-11
lines changed

5 files changed

+9
-11
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.6"
3+
version = "6.0.7"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/ArrayInterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function Base.axes(A::AbstractArray2)
5151
!(parent_type(A) <: typeof(A)) && return ArrayInterface.axes(parent(A))
5252
throw(ArgumentError("Subtypes of `AbstractArray2` must define an axes method"))
5353
end
54-
Base.axes(A::AbstractArray2, dim) = ArrayInterface.axes(A, dim)
54+
Base.axes(A::AbstractArray2, dim::Union{Symbol,StaticSymbol}) = Base.axes(A, to_dims(A, dim))
5555

5656
function Base.strides(A::AbstractArray2)
5757
defines_strides(A) && return map(Int, ArrayInterface.strides(A))
@@ -256,7 +256,7 @@ end
256256
@inline function unsafe_deleteat(src::Tuple, inds::AbstractVector)
257257
dst = Vector{eltype(src)}(undef, length(src) - length(inds))
258258
dst_index = firstindex(dst)
259-
@inbounds for src_index in OneTo(length(src))
259+
@inbounds for src_index in static(1):length(src)
260260
if !in(src_index, inds)
261261
dst[dst_index] = src[src_index]
262262
dst_index += one(dst_index)

src/axes.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717
if dim > ndims(x)
1818
return SOneTo{1}
1919
else
20-
return axes_types(x).parameters[dim]
20+
return fieldtype(axes_types(x), dim)
2121
end
2222
end
2323
axes_types(x) = axes_types(typeof(x))
@@ -230,7 +230,6 @@ end
230230
end
231231
ArrayInterfaceCore.known_last(::Type{LazyAxis{N,P}}) where {N,P} = known_last(axes_types(P, static(N)))
232232
ArrayInterfaceCore.known_last(::Type{LazyAxis{:,P}}) where {P} = known_length(P)
233-
Base.lastindex(x::LazyAxis) = last(x)
234233
Base.last(x::LazyAxis) = _last(known_last(x), x)
235234
_last(::Nothing, x) = last(parent(x))
236235
_last(N::Int, x) = N

src/dimensions.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(I
5050
@generated function _from_sub_dims(::Type{I}) where {I<:Tuple}
5151
out = Expr(:tuple)
5252
dim_i = 1
53-
for i in 1:length(I.parameters)
54-
p = I.parameters[i]
53+
for i in 1:fieldcount(I)
54+
p = fieldtype(I, i)
5555
if p <: CanonicalInt
5656
push!(out.args, :(StaticInt(0)))
5757
else
@@ -106,7 +106,8 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(I)
106106
@generated function _to_sub_dims(::Type{I}) where {I<:Tuple}
107107
out = Expr(:tuple)
108108
n = 1
109-
for p in I.parameters
109+
for i in 1:fieldcount(I)
110+
p = fieldtype(I, i)
110111
if !(p <: CanonicalInt)
111112
push!(out.args, :(StaticInt($n)))
112113
end

src/ranges.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,6 @@ end
239239
end
240240

241241
## length
242-
Base.lastindex(x::OptionallyStaticRange) = length(x)
243242
@inline function Base.length(r::OptionallyStaticUnitRange)
244243
if isempty(r)
245244
return 0
@@ -274,7 +273,6 @@ function Base.AbstractUnitRange{T}(r::OptionallyStaticUnitRange) where {T}
274273
end
275274
end
276275

277-
Base.eachindex(r::OptionallyStaticRange) = One():length(r)
278276
@inline function Base.iterate(r::OptionallyStaticRange)
279277
isempty(r) && return nothing
280278
fi = Int(first(r));
@@ -302,7 +300,7 @@ Base.axes(S::Slice{<:OptionallyStaticUnitRange{One}}) = (S.indices,)
302300
Base.axes(S::Slice{<:OptionallyStaticRange}) = (Base.IdentityUnitRange(S.indices),)
303301

304302
Base.axes(x::OptionallyStaticRange) = (Base.axes1(x),)
305-
Base.axes1(x::OptionallyStaticRange) = eachindex(x)
303+
Base.axes1(x::OptionallyStaticRange) = static(1):length(x)
306304
Base.axes1(x::Slice{<:OptionallyStaticUnitRange{One}}) = x.indices
307305
Base.axes1(x::Slice{<:OptionallyStaticRange}) = Base.IdentityUnitRange(x.indices)
308306

0 commit comments

Comments
 (0)