Skip to content

Commit 5987bae

Browse files
authored
SUnitRange and SOneTo aliases (#221)
Help out StaticArrays by providing a more comprehensive version of SUnitRange and SOneTo by only adding two new aliases here. All other changes here are improvements based on comparison to code found in StaticArrays.jl. I also fixed a bug where known_length gave errors when representing an empty range.
1 parent 98b8622 commit 5987bae

File tree

3 files changed

+81
-85
lines changed

3 files changed

+81
-85
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "3.1.35"
3+
version = "3.1.36"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/ranges.jl

Lines changed: 78 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -101,33 +101,31 @@ struct OptionallyStaticUnitRange{F<:CanonicalInt,L<:CanonicalInt} <: AbstractUni
101101
OptionallyStaticUnitRange(canonicalize(start), canonicalize(stop))
102102
end
103103

104-
function OptionallyStaticUnitRange{F,L}(x::AbstractRange) where {F,L}
105-
if step(x) == 1
106-
return OptionallyStaticUnitRange(static_first(x), static_last(x))
107-
else
108-
throw(ArgumentError("step must be 1, got $(step(x))"))
109-
end
110-
end
111-
112104
function OptionallyStaticUnitRange(x::AbstractRange)
113-
if step(x) == 1
114-
return OptionallyStaticUnitRange(static_first(x), static_last(x))
115-
else
116-
throw(ArgumentError("step must be 1, got $(step(x))"))
117-
end
105+
step(x) == 1 && return OptionallyStaticUnitRange(static_first(x), static_last(x))
106+
107+
errmsg(x) = throw(ArgumentError("step must be 1, got $(step(x))")) # avoid GC frame
108+
errmsg(x)
109+
end
110+
OptionallyStaticUnitRange{F,L}(x::AbstractRange) where {F,L} = OptionallyStaticUnitRange(x)
111+
function OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}}() where {F,L}
112+
new{StaticInt{F},StaticInt{L}}()
118113
end
119114
end
120115

116+
const SUnitRange{F,L} = OptionallyStaticUnitRange{StaticInt{F},StaticInt{L}}
117+
const SOneTo{L} = SUnitRange{1,L}
118+
121119
function Base.first(r::OptionallyStaticUnitRange)::Int
122120
if known_first(r) === nothing
123-
return r.start
121+
return getfield(r, :start)
124122
else
125123
return known_first(r)
126124
end
127125
end
128126
function Base.last(r::OptionallyStaticUnitRange)::Int
129127
if known_last(r) === nothing
130-
return r.stop
128+
return getfield(r, :stop)
131129
else
132130
return known_last(r)
133131
end
@@ -214,21 +212,21 @@ end
214212
end
215213
function Base.first(r::OptionallyStaticStepRange)::Int
216214
if known_first(r) === nothing
217-
return r.start
215+
return getfield(r, :start)
218216
else
219217
return known_first(r)
220218
end
221219
end
222220
function Base.step(r::OptionallyStaticStepRange)::Int
223221
if known_step(r) === nothing
224-
return r.step
222+
return getfield(r, :step)
225223
else
226224
return known_step(r)
227225
end
228226
end
229227
function Base.last(r::OptionallyStaticStepRange)::Int
230228
if known_last(r) === nothing
231-
return r.stop
229+
return getfield(r, :stop)
232230
else
233231
return known_last(r)
234232
end
@@ -264,33 +262,27 @@ end
264262
function Base.:(:)(::StaticInt{F}, step::Integer, stop::Integer) where {F}
265263
return OptionallyStaticStepRange(StaticInt(F), step, stop)
266264
end
267-
function Base.:(:)(::StaticInt{F}, ::StaticInt{1}, ::StaticInt{L}) where {F,L}
268-
return OptionallyStaticUnitRange(StaticInt(F), StaticInt(L))
269-
end
270-
function Base.:(:)(start::Integer, ::StaticInt{1}, ::StaticInt{L}) where {L}
271-
return OptionallyStaticUnitRange(start, StaticInt(L))
272-
end
273-
function Base.:(:)(::StaticInt{F}, ::StaticInt{1}, stop::Integer) where {F}
274-
return OptionallyStaticUnitRange(StaticInt(F), stop)
275-
end
265+
Base.:(:)(start::StaticInt{F}, ::StaticInt{1}, stop::StaticInt{L}) where {F,L} = start:stop
266+
Base.:(:)(start::Integer, ::StaticInt{1}, stop::StaticInt{L}) where {L} = start:stop
267+
Base.:(:)(start::StaticInt{F}, ::StaticInt{1}, stop::Integer) where {F} = start:stop
276268
function Base.:(:)(start::Integer, ::StaticInt{1}, stop::Integer)
277-
return OptionallyStaticUnitRange(start, stop)
278-
end
279-
280-
function Base.isempty(r::OptionallyStaticUnitRange)
281-
if known_first(r) === oneunit(eltype(r))
282-
return unsafe_isempty_one_to(last(r))
283-
else
284-
return unsafe_isempty_unit_range(first(r), last(r))
285-
end
269+
OptionallyStaticUnitRange(start, stop)
286270
end
287271

272+
Base.isempty(r::OptionallyStaticUnitRange{One}) = last(r) <= 0
273+
Base.isempty(r::OptionallyStaticUnitRange) = first(r) > last(r)
288274
function Base.isempty(r::OptionallyStaticStepRange)
289-
return (r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
275+
(r.start != r.stop) & ((r.step > 0) != (r.stop > r.start))
290276
end
291277

292-
unsafe_isempty_one_to(lst) = lst <= zero(lst)
293-
unsafe_isempty_unit_range(fst, lst) = fst > lst
278+
function Base.checkindex(
279+
::Type{Bool},
280+
::SUnitRange{F1,L1},
281+
::SUnitRange{F2,L2}
282+
) where {F1,L1,F2,L2}
283+
284+
(F1::Int <= F2::Int) && (L1::Int >= L2::Int)
285+
end
294286

295287
@propagate_inbounds function Base.getindex(
296288
r::OptionallyStaticUnitRange,
@@ -302,27 +294,14 @@ unsafe_isempty_unit_range(fst, lst) = fst > lst
302294
return (fnew+static_first(s)):(fnew+static_last(s))
303295
end
304296

305-
@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
306-
if known_first(r) === oneunit(eltype(r))
307-
return get_index_one_to(r, i)
308-
else
309-
return get_index_unit_range(r, i)
310-
end
311-
end
312-
313-
@inline function get_index_one_to(r, i)
314-
@boundscheck if ((i < 1) || (i > last(r)))
315-
throw(BoundsError(r, i))
316-
end
317-
return convert(eltype(r), i)
297+
@propagate_inbounds function Base.getindex(x::OptionallyStaticUnitRange{StaticInt{1}}, i::Int)
298+
@boundscheck checkbounds(x, i)
299+
i
318300
end
319-
320-
@inline function get_index_unit_range(r, i)
321-
val = first(r) + (i - 1)
322-
@boundscheck if (i < 1) || val > last(r)
323-
throw(BoundsError(r, i))
324-
end
325-
return convert(eltype(r), val)
301+
@propagate_inbounds function Base.getindex(x::OptionallyStaticUnitRange, i::Int)
302+
val = first(x) + (i - 1)
303+
@boundscheck ((i < 1) || val > last(x)) && throw(BoundsError(x, i))
304+
val::Int
326305
end
327306

328307
@inline _try_static(::StaticInt{N}, ::StaticInt{N}) where {N} = StaticInt{N}()
@@ -355,27 +334,25 @@ end
355334
_range_length(known_first(T), known_step(T), known_last(T))
356335
end
357336

358-
@inline function Base.length(r::OptionallyStaticUnitRange)
337+
Base.length(r::OptionallyStaticUnitRange) = _range_length(static_first(r), static_last(r))
338+
@inline function Base.length(r::OptionallyStaticStepRange)
359339
if isempty(r)
360340
return 0
361341
else
362-
return _range_length(static_first(r), static_last(r))
342+
return _range_length(static_first(r), static_step(r), static_last(r))
363343
end
364344
end
365345

366-
@inline function Base.length(r::OptionallyStaticStepRange)
367-
if isempty(r)
346+
_range_length(start, stop) = nothing
347+
function _range_length(start::CanonicalInt, stop::CanonicalInt)
348+
if start > stop
368349
return 0
369350
else
370-
return _range_length(static_first(r), static_step(r), static_last(r))
351+
return Int((stop - start) + 1)
371352
end
372353
end
373-
374-
_range_length(::StaticInt{1}, stop::Integer) = Int(stop)
375-
_range_length(start::Integer, stop::Integer) = Int((stop - start) + 1)
376-
_range_length(start, stop) = nothing
377-
_range_length(start::Integer, ::StaticInt{1}, stop::Integer) = _range_length(start, stop)
378-
@inline function _range_length(start::Integer, step::Integer, stop::Integer)
354+
_range_length(start::CanonicalInt, ::One, stop::CanonicalInt) = _range_length(start, stop)
355+
@inline function _range_length(start::CanonicalInt, step::CanonicalInt, stop::CanonicalInt)
379356
if step > 1
380357
return Base.checked_add(Int(div(unsigned(stop - start), step)), 1)
381358
elseif step < -1
@@ -405,31 +382,36 @@ Base.eachindex(r::OptionallyStaticRange) = One():static_length(r)
405382
fi = Int(first(r));
406383
fi, fi
407384
end
408-
409-
Base.to_shape(x::OptionallyStaticRange) = length(x)
410-
Base.to_shape(x::Slice{T}) where {T<:OptionallyStaticRange} = length(x)
411-
412-
@inline function Base.axes(S::Slice{T}) where {T<:OptionallyStaticRange}
413-
if known_first(T) === 1 && known_step(T) === 1
414-
return (S.indices,)
385+
function Base.iterate(::SUnitRange{F,L}) where {F,L}
386+
if L::Int < F::Int
387+
return nothing
415388
else
416-
return (Base.IdentityUnitRange(S.indices),)
389+
return (F::Int, F::Int)
417390
end
418391
end
419-
420-
@inline function Base.axes1(S::Slice{T}) where {T<:OptionallyStaticRange}
421-
if known_first(T) === 1 && known_step(T) === 1
422-
return S.indices
392+
function Base.iterate(::SOneTo{n}, s::Int) where {n}
393+
if s < n::Int
394+
s2 = s + 1
395+
return (s2, s2)
423396
else
424-
return Base.IdentityUnitRange(S.indices)
397+
return nothing
425398
end
426399
end
427400

401+
Base.to_shape(x::OptionallyStaticRange) = length(x)
402+
Base.to_shape(x::Slice{T}) where {T<:OptionallyStaticRange} = length(x)
403+
Base.axes(S::Slice{<:OptionallyStaticUnitRange{One}}) = (S.indices,)
404+
Base.axes(S::Slice{<:OptionallyStaticRange}) = (Base.IdentityUnitRange(S.indices),)
405+
406+
Base.axes1(S::Slice{<:OptionallyStaticUnitRange{One}}) = S.indices
407+
Base.axes1(S::Slice{<:OptionallyStaticRange}) = Base.IdentityUnitRange(S.indices)
408+
Base.unsafe_indices(S::Base.Slice{<:OptionallyStaticUnitRange{One}}) = (S.indices,)
409+
428410
Base.:(-)(r::OptionallyStaticRange) = -static_first(r):-static_step(r):-static_last(r)
429411

430412
Base.reverse(r::OptionallyStaticUnitRange) = static_last(r):static(-1):static_first(r)
431413
function Base.reverse(r::OptionallyStaticStepRange)
432-
return OptionallyStaticStepRange(static_last(r), -static_step(r), static_first(r))
414+
OptionallyStaticStepRange(static_last(r), -static_step(r), static_first(r))
433415
end
434416

435417
function Base.show(io::IO, ::MIME"text/plain", r::OptionallyStaticRange)
@@ -444,6 +426,18 @@ function Base.show(io::IO, ::MIME"text/plain", r::OptionallyStaticRange)
444426
print(io, static_last(r))
445427
end
446428

429+
@inline function Base.getproperty(x::OptionallyStaticRange, s::Symbol)
430+
if s === :start
431+
return first(x)
432+
elseif s === :step
433+
return step(x)
434+
elseif s === :stop
435+
return last(x)
436+
else
437+
error("$x has no property $s")
438+
end
439+
end
440+
447441
"""
448442
reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N}
449443

test/ranges.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
@testset "Range Interface" begin
33
@testset "Range Constructors" begin
44
@test @inferred(StaticInt(1):StaticInt(10)) == 1:10
5+
@test @inferred(ArrayInterface.SUnitRange{1,10}()) == 1:10
56
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
67
@test @inferred(1:StaticInt(2):StaticInt(10)) == 1:2:10
78
@test @inferred(StaticInt(1):StaticInt(2):10) == 1:2:10
@@ -84,6 +85,7 @@
8485
@test @inferred(length(StaticInt(0):StaticInt(-2):StaticInt(1))) == 0
8586

8687
@test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 2, 10)))) === nothing
88+
@test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.SOneTo{-10}()))) === 0
8789
@test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), StaticInt(1), StaticInt(10))))) === 10
8890
@test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(StaticInt(2), StaticInt(1), StaticInt(10))))) === 9
8991
@test @inferred(ArrayInterface.known_length(typeof(ArrayInterface.OptionallyStaticStepRange(StaticInt(2), StaticInt(2), StaticInt(10))))) === 5

0 commit comments

Comments
 (0)