Skip to content

Commit f1b5f65

Browse files
authored
Improve single element indexing (and other quality of life improvements) (#149)
* Use LazyAxis to avoid cost of materializing axes * to_indices doesn't iterate through indexers that don't need conversion * Patch bump and simplify some lazy_axes returns * add change size trait * Document canonicalize * Fix type stability of ranges * Add more axes tests
1 parent 16defd4 commit f1b5f65

File tree

10 files changed

+403
-84
lines changed

10 files changed

+403
-84
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.9"
3+
version = "3.1.10"
44

55
[deps]
66
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"

src/ArrayInterface.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretAr
1616
## utilites for internal use only ##
1717
_int_or_static_int(::Nothing) = Int
1818
_int_or_static_int(x::Int) = StaticInt{x}
19-
_int(i::Integer) = Int(i)
20-
_int(i::StaticInt) = i
2119

2220
@static if VERSION >= v"1.7.0-DEV.421"
2321
using Base: @aggressive_constprop
@@ -842,10 +840,10 @@ end
842840
end
843841

844842
include("ranges.jl")
845-
include("indexing.jl")
846843
include("dimensions.jl")
847844
include("axes.jl")
848845
include("size.jl")
846+
include("indexing.jl")
849847
include("stridelayout.jl")
850848
include("broadcast.jl")
851849

src/axes.jl

Lines changed: 152 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,9 @@ function axes_types(::Type{T}) where {T}
3737
end
3838
axes_types(::Type{LinearIndices{N,R}}) where {N,R} = R
3939
axes_types(::Type{CartesianIndices{N,R}}) where {N,R} = R
40-
function axes_types(::Type{T}) where {T<:VecAdjTrans}
41-
return Tuple{OptionallyStaticUnitRange{One,One},axes_types(parent_type(T), One())}
42-
end
43-
function axes_types(::Type{T}) where {T<:MatAdjTrans}
44-
return eachop_tuple(_get_tuple, to_parent_dims(T), axes_types(parent_type(T)))
40+
function axes_types(::Type{T}) where {T<:Union{Adjoint,Transpose}}
41+
P = parent_type(T)
42+
return Tuple{axes_types(P, static(2)), axes_types(P, static(1))}
4543
end
4644
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
4745
return eachop_tuple(_get_tuple, to_parent_dims(T), axes_types(parent_type(T)))
@@ -133,6 +131,21 @@ function axes(a::A, dim::Integer) where {A}
133131
return axes(parent(a), to_parent_dims(A, dim))
134132
end
135133
end
134+
function axes(A::CartesianIndices{N}, dim::Integer) where {N}
135+
if dim > N
136+
return static(1):static(1)
137+
else
138+
return getfield(axes(A), Int(dim))
139+
end
140+
end
141+
function axes(A::LinearIndices{N}, dim::Integer) where {N}
142+
if dim > N
143+
return static(1):static(1)
144+
else
145+
return getfield(axes(A), Int(dim))
146+
end
147+
end
148+
136149
axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
137150
axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
138151
axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
@@ -160,3 +173,137 @@ axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface vers
160173
axes(A::CartesianIndices) = A.indices
161174
axes(A::LinearIndices) = A.indices
162175

176+
"""
177+
LazyAxis{N}(parent::AbstractArray)
178+
179+
A lazy representation of `axes(parent, N)`.
180+
"""
181+
struct LazyAxis{N,P} <: AbstractUnitRange{Int}
182+
parent::P
183+
184+
LazyAxis{N}(parent::P) where {N,P} = new{N::Int,P}(parent)
185+
@inline function LazyAxis{:}(parent::P) where {P}
186+
if ndims(P) === 1
187+
return new{1,P}(parent)
188+
else
189+
return new{:,P}(parent)
190+
end
191+
end
192+
end
193+
194+
@inline Base.parent(x::LazyAxis{N,P}) where {N,P} = axes(getfield(x, :parent), static(N))
195+
@inline function Base.parent(x::LazyAxis{:,P}) where {P}
196+
return eachindex(IndexLinear(), getfield(x, :parent))
197+
end
198+
199+
@inline parent_type(::Type{LazyAxis{N,P}}) where {N,P} = axes_types(P, static(N))
200+
# TODO this approach to parent_type(::Type{LazyAxis{:}}) is a bit hacky. Something like
201+
# LabelledArrays has a linear set of symbolic keys, which could be propagated through
202+
# `to_indices` for key based indexing. However, there currently isn't a good way of handling
203+
# that when the linear indices aren't linearly accessible through a child array (e.g, adjoint)
204+
# For now we just make sure the linear elements are accurate.
205+
parent_type(::Type{LazyAxis{:,P}}) where {P<:Array} = OneTo{Int}
206+
@inline function parent_type(::Type{LazyAxis{:,P}}) where {P}
207+
if known_length(P) === nothing
208+
return OptionallyStaticUnitRange{StaticInt{1},Int}
209+
else
210+
return OptionallyStaticUnitRange{StaticInt{1},StaticInt{known_length(P)}}
211+
end
212+
end
213+
214+
Base.keys(x::LazyAxis) = keys(parent(x))
215+
216+
Base.IndexStyle(::Type{T}) where {T<:LazyAxis} = IndexStyle(parent_type(T))
217+
218+
can_change_size(::Type{LazyAxis{N,P}}) where {N,P} = can_change_size(P)
219+
220+
known_first(::Type{T}) where {T<:LazyAxis} = known_first(parent_type(T))
221+
222+
known_length(::Type{LazyAxis{N,P}}) where {N,P} = known_size(P, N)
223+
known_length(::Type{LazyAxis{:,P}}) where {P} = known_length(P)
224+
225+
@inline function known_last(::Type{T}) where {T<:LazyAxis}
226+
return _lazy_axis_known_last(known_first(T), known_length(T))
227+
end
228+
_lazy_axis_known_last(start::Int, length::Int) = (length + start) - 1
229+
_lazy_axis_known_last(::Any, ::Any) = nothing
230+
231+
@inline function Base.first(x::LazyAxis{N})::Int where {N}
232+
if known_first(x) === nothing
233+
return offsets(getfield(x, :parent), static(N))
234+
else
235+
return known_first(x)
236+
end
237+
end
238+
@inline function Base.first(x::LazyAxis{:})::Int
239+
if known_first(x) === nothing
240+
return firstindex(getfield(x, :parent))
241+
else
242+
return known_first(x)
243+
end
244+
end
245+
246+
@inline function Base.length(x::LazyAxis{N})::Int where {N}
247+
if known_length(x) === nothing
248+
return size(getfield(x, :parent), static(N))
249+
else
250+
return known_length(x)
251+
end
252+
end
253+
@inline function Base.length(x::LazyAxis{:})::Int
254+
if known_length(x) === nothing
255+
return lastindex(getfield(x, :parent))
256+
else
257+
return known_length(x)
258+
end
259+
end
260+
261+
@inline function Base.last(x::LazyAxis)::Int
262+
if known_last(x) === nothing
263+
if known_first(x) === 1
264+
return length(x)
265+
else
266+
return (static_length(x) + static_first(x)) - 1
267+
end
268+
else
269+
return known_last(x)
270+
end
271+
end
272+
273+
Base.to_shape(x::LazyAxis) = length(x)
274+
275+
@inline function Base.checkindex(::Type{Bool}, x::LazyAxis, i::Integer)
276+
if known_first(x) === nothing || known_last(x) === nothing
277+
return checkindex(Bool, parent(x), i)
278+
else # everything is static so we don't have to retrieve the axis
279+
return (!(known_first(x) > i) || !(known_last(x) < i))
280+
end
281+
end
282+
283+
@propagate_inbounds function Base.getindex(x::LazyAxis, i::Integer)
284+
@boundscheck checkindex(Bool, x, i) || throw(BoundsError(x, i))
285+
return Int(i)
286+
end
287+
@propagate_inbounds Base.getindex(x::LazyAxis, i::StepRange{T}) where {T<:Integer} = parent(x)[i]
288+
@propagate_inbounds Base.getindex(x::LazyAxis, i::AbstractUnitRange{<:Integer}) = parent(x)[i]
289+
290+
Base.show(io::IO, x::LazyAxis{N}) where {N} = print(io, "LazyAxis{$N}($(parent(x))))")
291+
292+
"""
293+
lazy_axes(x)
294+
295+
Produces a tuple of axes where each axis is constructed lazily. If an axis of `x` is already
296+
constructed or it is simply retrieved.
297+
"""
298+
@generated function lazy_axes(x::X) where {X}
299+
Expr(:block,
300+
Expr(:meta, :inline),
301+
Expr(:tuple, [:(LazyAxis{$dim}(x)) for dim in 1:ndims(X)]...)
302+
)
303+
end
304+
lazy_axes(x::LinearIndices) = axes(x)
305+
lazy_axes(x::CartesianIndices) = axes(x)
306+
@inline lazy_axes(x::MatAdjTrans) = reverse(lazy_axes(parent(x)))
307+
@inline lazy_axes(x::VecAdjTrans) = (LazyAxis{1}(x), first(lazy_axes(parent(x))))
308+
@inline lazy_axes(x::PermutedDimsArray) = permute(lazy_axes(parent(x)), to_parent_dims(A))
309+

src/dimensions.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,16 +190,13 @@ function dimnames(::Type{T}) where {T<:SubArray}
190190
return eachop(dimnames, to_parent_dims(T), parent_type(T))
191191
end
192192

193-
_to_int(x::Integer) = Int(x)
194-
_to_int(x::StaticInt) = x
195-
196193
"""
197194
to_dims(::Type{T}, dim) -> Integer
198195
199196
This returns the dimension(s) of `x` corresponding to `d`.
200197
"""
201198
to_dims(x, dim) = to_dims(typeof(x), dim)
202-
to_dims(::Type{T}, dim::Integer) where {T} = _to_int(dim)
199+
to_dims(::Type{T}, dim::Integer) where {T} = canonicalize(dim)
203200
to_dims(::Type{T}, dim::Colon) where {T} = dim
204201
function to_dims(::Type{T}, dim::StaticSymbol) where {T}
205202
i = find_first_eq(dim, dimnames(T))

0 commit comments

Comments
 (0)