Skip to content

Commit 91795a6

Browse files
authored
clean up range code and documentation (#176)
1 parent deccaf8 commit 91795a6

File tree

3 files changed

+65
-104
lines changed

3 files changed

+65
-104
lines changed

src/indexing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ Changing indexing based on a given argument from `args` should be done through,
393393
"""
394394
@propagate_inbounds getindex(A, args...) = unsafe_get_index(A, to_indices(A, args))
395395
@propagate_inbounds function getindex(A; kwargs...)
396-
return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), kwargs.data)))
396+
return unsafe_get_index(A, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
397397
end
398398
@propagate_inbounds getindex(x::Tuple, i::Int) = getfield(x, i)
399399
@propagate_inbounds getindex(x::Tuple, ::StaticInt{i}) where {i} = getfield(x, i)
@@ -512,7 +512,7 @@ Store the given values at the given key or index within a collection.
512512
end
513513
end
514514
@propagate_inbounds function setindex!(A, val; kwargs...)
515-
return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), kwargs.data)))
515+
return unsafe_set_index!(A, val, to_indices(A, order_named_inds(dimnames(A), values(kwargs))))
516516
end
517517

518518
unsafe_set_index!(A, v, inds::Tuple) = _unsafe_set_index!(_is_element_index(inds), A, v, inds)

src/ranges.jl

Lines changed: 63 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,13 @@
55
If `first` of an instance of type `T` is known at compile time, return it.
66
Otherwise, return `nothing`.
77
8-
@test isnothing(known_first(typeof(1:4)))
9-
@test isone(known_first(typeof(Base.OneTo(4))))
8+
```julia
9+
julia> ArrayInterface.known_first(typeof(1:4))
10+
nothing
11+
12+
julia> ArrayInterface.known_first(typeof(Base.OneTo(4)))
13+
1
14+
```
1015
"""
1116
known_first(x) = known_first(typeof(x))
1217
function known_first(::Type{T}) where {T}
@@ -24,9 +29,14 @@ known_first(::Type{Base.OneTo{T}}) where {T} = one(T)
2429
If `last` of an instance of type `T` is known at compile time, return it.
2530
Otherwise, return `nothing`.
2631
27-
@test isnothing(known_last(typeof(1:4)))
28-
using StaticArrays
29-
@test known_last(typeof(SOneTo(4))) == 4
32+
```julia
33+
julia> ArrayInterface.known_last(typeof(1:4))
34+
nothing
35+
36+
julia> ArrayInterface.known_first(typeof(static(1):static(4)))
37+
4
38+
39+
```
3040
"""
3141
known_last(x) = known_last(typeof(x))
3242
function known_last(::Type{T}) where {T}
@@ -43,8 +53,14 @@ end
4353
If `step` of an instance of type `T` is known at compile time, return it.
4454
Otherwise, return `nothing`.
4555
46-
@test isnothing(known_step(typeof(1:0.2:4)))
47-
@test isone(known_step(typeof(1:4)))
56+
```julia
57+
julia> ArrayInterface.known_step(typeof(1:2:8))
58+
nothing
59+
60+
julia> ArrayInterface.known_step(typeof(1:4))
61+
1
62+
63+
```
4864
"""
4965
known_step(x) = known_step(typeof(x))
5066
function known_step(::Type{T}) where {T}
@@ -65,20 +81,15 @@ at compile time. An `OptionallyStaticUnitRange` is intended to be constructed in
6581
from other valid indices. Therefore, users should not expect the same checks are used
6682
to ensure construction of a valid `OptionallyStaticUnitRange` as a `UnitRange`.
6783
"""
68-
struct OptionallyStaticUnitRange{F<:Integer,L<:Integer} <: AbstractUnitRange{Int}
84+
struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt} <: AbstractUnitRange{Int}
6985
start::F
7086
stop::L
7187

72-
function OptionallyStaticUnitRange(start, stop)
73-
if eltype(start) <: Int
74-
if eltype(stop) <: Int
75-
return new{typeof(start),typeof(stop)}(start, stop)
76-
else
77-
return OptionallyStaticUnitRange(start, Int(stop))
78-
end
79-
else
80-
return OptionallyStaticUnitRange(Int(start), stop)
81-
end
88+
function OptionallyStaticUnitRange(start::CanonicalInt, stop::CanonicalInt)
89+
new{typeof(start),typeof(stop)}(start, stop)
90+
end
91+
function OptionallyStaticUnitRange(start::Integer, stop::Integer)
92+
OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop))
8293
end
8394

8495
function OptionallyStaticUnitRange{F,L}(x::AbstractRange) where {F,L}
@@ -138,24 +149,18 @@ julia> ArrayInterface.OptionallyStaticStepRange(x, x, 10)
138149
ArrayInterface.StaticInt{2}():ArrayInterface.StaticInt{2}():10
139150
```
140151
"""
141-
struct OptionallyStaticStepRange{F<:Integer,S<:Integer,L<:Integer} <: OrdinalRange{Int,Int}
152+
struct OptionallyStaticStepRange{F<:CanonicalInt,S<:CanonicalInt,L<:CanonicalInt} <: OrdinalRange{Int,Int}
142153
start::F
143154
step::S
144155
stop::L
145156

146-
function OptionallyStaticStepRange(start, step, stop)
147-
if eltype(start) <: Int
148-
if eltype(stop) <: Int
149-
lst = _steprange_last(start, step, stop)
150-
return new{typeof(start),typeof(step),typeof(lst)}(start, step, lst)
151-
else
152-
return OptionallyStaticStepRange(start, step, Int(stop))
153-
end
154-
else
155-
return OptionallyStaticStepRange(Int(start), step, stop)
156-
end
157+
function OptionallyStaticStepRange(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt)
158+
lst = _steprange_last(start, step, stop)
159+
new{typeof(start),typeof(step),typeof(lst)}(start, step, lst)
160+
end
161+
function OptionallyStaticStepRange(start::Integer, step::Integer, stop::Integer)
162+
OptionallyStaticStepRange(canonicalize(start), canonicalize(step), canonicalize(stop))
157163
end
158-
159164
function OptionallyStaticStepRange(x::AbstractRange)
160165
return OptionallyStaticStepRange(static_first(x), static_step(x), static_last(x))
161166
end
@@ -277,22 +282,6 @@ end
277282
unsafe_isempty_one_to(lst) = lst <= zero(lst)
278283
unsafe_isempty_unit_range(fst, lst) = fst > lst
279284

280-
unsafe_length_one_to(lst::Int) = lst
281-
unsafe_length_one_to(::StaticInt{L}) where {L} = L
282-
283-
# TODO this should probably be renamed because the point is that it is safe
284-
@inline function unsafe_length_step_range(start::Int, step::Int, stop::Int)
285-
if step > 1
286-
return Base.checked_add(Int(div(unsigned(stop - start), step)), 1)
287-
elseif step < -1
288-
return Base.checked_add(Int(div(unsigned(start - stop), -step)), 1)
289-
elseif step > 0
290-
return Base.checked_add(Int(div(Base.checked_sub(stop, start), step)), 1)
291-
else
292-
return Base.checked_add(Int(div(Base.checked_sub(start, stop), -step)), 1)
293-
end
294-
end
295-
296285
@propagate_inbounds function Base.getindex(
297286
r::OptionallyStaticUnitRange,
298287
s::AbstractUnitRange{<:Integer},
@@ -347,81 +336,54 @@ end
347336
return x
348337
end
349338

350-
###
351-
### length
352-
###
339+
## length
353340
@inline function known_length(::Type{T}) where {T<:OptionallyStaticUnitRange}
354-
fst = known_first(T)
355-
lst = known_last(T)
356-
if fst === nothing || lst === nothing
357-
return nothing
358-
else
359-
if fst === oneunit(eltype(T))
360-
return unsafe_length_one_to(lst)
361-
else
362-
return unsafe_length_unit_range(fst, lst)
363-
end
364-
end
341+
return _range_length(known_first(T), known_last(T))
365342
end
366343

367344
@inline function known_length(::Type{T}) where {T<:OptionallyStaticStepRange}
368-
fst = known_first(T)
369-
stp = known_step(T)
370-
lst = known_last(T)
371-
if fst === nothing || stp === nothing || lst === nothing
372-
return nothing
373-
else
374-
if stp === 1
375-
if fst === oneunit(eltype(T))
376-
return unsafe_length_one_to(lst)
377-
else
378-
return unsafe_length_unit_range(fst, lst)
379-
end
380-
else
381-
return unsafe_length_step_range(fst, stp, lst)
382-
end
383-
end
345+
_range_length(known_first(T), known_step(T), known_last(T))
384346
end
385347

386-
function Base.length(r::OptionallyStaticUnitRange)
348+
@inline function Base.length(r::OptionallyStaticUnitRange)
387349
if isempty(r)
388350
return 0
389351
else
390-
if known_first(r) === 1
391-
return unsafe_length_one_to(last(r))
392-
else
393-
return unsafe_length_unit_range(first(r), last(r))
394-
end
352+
return _range_length(static_first(r), static_last(r))
395353
end
396354
end
397355

398-
function Base.length(r::OptionallyStaticStepRange)
356+
@inline function Base.length(r::OptionallyStaticStepRange)
399357
if isempty(r)
400358
return 0
401359
else
402-
if known_step(r) === 1
403-
if known_first(r) === 1
404-
return unsafe_length_one_to(last(r))
405-
else
406-
return unsafe_length_unit_range(first(r), last(r))
407-
end
408-
else
409-
return unsafe_length_step_range(Int(first(r)), Int(step(r)), Int(last(r)))
410-
end
360+
return _range_length(static_first(r), static_step(r), static_last(r))
411361
end
412362
end
413363

414-
unsafe_length_unit_range(start::Integer, stop::Integer) = Int((stop - start) + 1)
364+
_range_length(::StaticInt{1}, stop::Integer) = Int(stop)
365+
_range_length(start::Integer, stop::Integer) = Int((stop - start) + 1)
366+
_range_length(start, stop) = nothing
367+
_range_length(start::Integer, ::StaticInt{1}, stop::Integer) = _range_length(start, stop)
368+
@inline function _range_length(start::Integer, step::Integer, stop::Integer)
369+
if step > 1
370+
return Base.checked_add(Int(div(unsigned(stop - start), step)), 1)
371+
elseif step < -1
372+
return Base.checked_add(Int(div(unsigned(start - stop), -step)), 1)
373+
elseif step > 0
374+
return Base.checked_add(Int(div(Base.checked_sub(stop, start), step)), 1)
375+
else
376+
return Base.checked_add(Int(div(Base.checked_sub(start, stop), -step)), 1)
377+
end
378+
end
379+
_range_length(start, step, stop) = nothing
415380

381+
Base.AbstractUnitRange{Int}(r::OptionallyStaticUnitRange) = r
416382
function Base.AbstractUnitRange{T}(r::OptionallyStaticUnitRange) where {T}
417-
if T <: Int
418-
return r
383+
if known_first(r) === 1 && T <: Integer
384+
return OneTo{T}(last(r))
419385
else
420-
if known_first(r) === 1 && T <: Integer
421-
return OneTo{T}(last(r))
422-
else
423-
return UnitRange{T}(first(r), last(r))
424-
end
386+
return UnitRange{T}(first(r), last(r))
425387
end
426388
end
427389

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,6 @@ end
734734
@test @inferred(ArrayInterface.strides(Ac2t)) === (StaticInt(1), 5)
735735
Ac2t_static = reinterpret(reshape, Tuple{Float64,Float64}, view(@MMatrix(rand(ComplexF64, 5, 7)), 2:4, 3:6));
736736
@test @inferred(ArrayInterface.strides(Ac2t_static)) === (StaticInt(1), StaticInt(5))
737-
738737
end
739738
end
740739

0 commit comments

Comments
 (0)