Skip to content

Commit 5228630

Browse files
authored
Safe functional indexing (#316)
This allows common comparison operators be parsed into indices (`<`, `<=`, `>`, `>=`). So instead of doing `x[3:end]` you can do `x[>=(3)]` (if `x` uses `ArrayInterface.getindex`. The main utility right now is that we can ensure that the indices are always inbounds using this syntax. I'm still working out the best way to communicate this info to `checkbounds` so that bounds checking is passed over (similar to how `:` becomes `Slice` which is known to be inbounds); but that will probably be more involved so I'm going to make that a separate effort.
1 parent b922313 commit 5228630

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

src/ArrayInterface.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ import Compat
3636

3737
n_of_x(::StaticInt{N}, x::X) where {N,X} = ntuple(Compat.Returns(x), Val{N}())
3838

39+
_add1(@nospecialize x) = x + oneunit(x)
40+
_sub1(@nospecialize x) = x - oneunit(x)
41+
3942
@generated function merge_tuple_type(::Type{X}, ::Type{Y}) where {X<:Tuple,Y<:Tuple}
4043
Tuple{X.parameters...,Y.parameters...}
4144
end

src/indexing.jl

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ function Base.last(x::AbstractVector, n::StaticInt)
2323
end
2424

2525
"""
26-
to_indices(A, I::Tuple) -> Tuple
26+
ArrayInterface.to_indices(A, I::Tuple) -> Tuple
2727
2828
Converts the tuple of indexing arguments, `I`, into an appropriate form for indexing into `A`.
2929
Typically, each index should be an `Int`, `StaticInt`, a collection with values of `Int`, or a collection with values of `CartesianIndex`
@@ -183,16 +183,50 @@ end
183183
end
184184

185185
"""
186-
to_index([::IndexStyle, ]axis, arg) -> index
186+
ArrayInterface.to_index([::IndexStyle, ]axis, arg) -> index
187187
188-
Convert the argument `arg` that was originally passed to `getindex` for the dimension
189-
corresponding to `axis` into a form for native indexing (`Int`, Vector{Int}, etc.). New
190-
axis types with unique behavior should use an `IndexStyle` trait:
188+
Convert the argument `arg` that was originally passed to `ArrayInterface.getindex` for the
189+
dimension corresponding to `axis` into a form for native indexing (`Int`, Vector{Int}, etc.).
191190
191+
`ArrayInterface.to_index` supports passing a function as an index. This function-index is
192+
transformed into a proper index.
193+
194+
```julia
195+
julia> using ArrayInterface, Static
196+
197+
julia> ArrayInterface.to_index(static(1):static(10), 5)
198+
5
199+
200+
julia> ArrayInterface.to_index(static(1):static(10), <(5))
201+
static(1):4
202+
203+
julia> ArrayInterface.to_index(static(1):static(10), <=(5))
204+
static(1):5
205+
206+
julia> ArrayInterface.to_index(static(1):static(10), >(5))
207+
6:static(10)
208+
209+
julia> ArrayInterface.to_index(static(1):static(10), >=(5))
210+
5:static(10)
211+
212+
```
213+
214+
Use of a function-index helps ensure that indices are inbounds
215+
216+
```julia
217+
julia> ArrayInterface.to_index(static(1):static(10), <(12))
218+
static(1):10
219+
220+
julia> ArrayInterface.to_index(static(1):static(10), >(-1))
221+
1:static(10)
222+
```
223+
224+
New axis types with unique behavior should use an `IndexStyle` trait:
192225
```julia
193226
to_index(axis::MyAxisType, arg) = to_index(IndexStyle(axis), axis, arg)
194227
to_index(::MyIndexStyle, axis, arg) = ...
195228
```
229+
196230
"""
197231
to_index(x, i::Slice) = i
198232
to_index(x, ::Colon) = indices(x)
@@ -207,6 +241,18 @@ to_index(x::LinearIndices, i::AbstractArray{Bool}) = LogicalIndex{Int}(i)
207241
@inline to_index(x, i::CartesianIndex) = Tuple(i)
208242
@inline to_index(x, i::NDIndex) = Tuple(i)
209243
@inline to_index(x, i::AbstractArray{<:AbstractCartesianIndex}) = i
244+
@inline function to_index(x, i::Base.Fix2{<:Union{typeof(<),typeof(isless)},<:Union{Base.BitInteger,StaticInt}})
245+
offset1(x):min(_sub1(canonicalize(i.x)), static_lastindex(x))
246+
end
247+
@inline function to_index(x, i::Base.Fix2{typeof(<=),<:Union{Base.BitInteger,StaticInt}})
248+
offset1(x):min(canonicalize(i.x), static_lastindex(x))
249+
end
250+
@inline function to_index(x, i::Base.Fix2{typeof(>=),<:Union{Base.BitInteger,StaticInt}})
251+
max(canonicalize(i.x), offset1(x)):static_lastindex(x)
252+
end
253+
@inline function to_index(x, i::Base.Fix2{typeof(>),<:Union{Base.BitInteger,StaticInt}})
254+
max(_add1(canonicalize(i.x)), offset1(x)):static_lastindex(x)
255+
end
210256
# integer indexing
211257
to_index(x, i::AbstractArray{<:Integer}) = i
212258
to_index(x, @nospecialize(i::StaticInt)) = i

test/indexing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,8 @@ end
202202
@test @inferred(ArrayInterface.getindex(cartesian, vec(cartesian))) == vec(cartesian)
203203
@test @inferred(ArrayInterface.getindex(linear, 2:3)) === 2:3
204204
@test @inferred(ArrayInterface.getindex(linear, 3:-1:1)) === 3:-1:1
205+
@test @inferred(ArrayInterface.getindex(linear, >(1), <(3))) == linear[(begin+1):end, 1:(end-1)]
206+
@test @inferred(ArrayInterface.getindex(linear, >=(1), <=(3))) == linear[begin:end, 1:end]
205207
@test_throws BoundsError ArrayInterface.getindex(linear, 4:13)
206208
end
207209

0 commit comments

Comments
 (0)