Skip to content

Commit f86ade8

Browse files
authored
Merge pull request #74 from Tokazama/steprange
OptionallyStaticStepRange
2 parents 22c64ef + 693be0a commit f86ade8

File tree

4 files changed

+218
-33
lines changed

4 files changed

+218
-33
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "2.13.3"
3+
version = "2.13.4"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/ranges.jl

Lines changed: 176 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
4343
# add methods to support ArrayInterface
4444

4545
"""
46-
OptionallyStaticUnitRange{T<:Integer}(start, stop) <: OrdinalRange{T,T}
46+
OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int}
4747
4848
This range permits diverse representations of arrays to comunicate common information
4949
about their indices. Each field may be an integer or `Val(<:Integer)` if it is known
@@ -67,21 +67,15 @@ struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRang
6767
end
6868
end
6969

70-
function OptionallyStaticUnitRange(x::AbstractRange)
70+
function OptionallyStaticUnitRange(x::AbstractRange)
7171
if step(x) == 1
72-
fst = static_first(x)
73-
lst = static_last(x)
74-
return OptionallyStaticUnitRange(fst, lst)
72+
return OptionallyStaticUnitRange(static_first(x), static_last(x))
7573
else
7674
throw(ArgumentError("step must be 1, got $(step(r))"))
7775
end
7876
end
7977
end
8078

81-
Base.:(:)(L::Integer, ::StaticInt{U}) where {U} = OptionallyStaticUnitRange(L, StaticInt(U))
82-
Base.:(:)(::StaticInt{L}, U::Integer) where {L} = OptionallyStaticUnitRange(StaticInt(L), U)
83-
Base.:(:)(::StaticInt{L}, ::StaticInt{U}) where {L,U} = OptionallyStaticUnitRange(StaticInt(L), StaticInt(U))
84-
8579
Base.first(r::OptionallyStaticUnitRange) = r.start
8680
Base.step(::OptionallyStaticUnitRange) = StaticInt(1)
8781
Base.last(r::OptionallyStaticUnitRange) = r.stop
@@ -90,6 +84,110 @@ known_first(::Type{<:OptionallyStaticUnitRange{StaticInt{F}}}) where {F} = F
9084
known_step(::Type{<:OptionallyStaticUnitRange}) = 1
9185
known_last(::Type{<:OptionallyStaticUnitRange{<:Any,StaticInt{L}}}) where {L} = L
9286

87+
"""
88+
OptionallyStaticStepRange(start, step, stop) <: OrdinalRange{Int,Int}
89+
90+
Similar to [`OptionallyStaticUnitRange`](@ref), `OptionallyStaticStepRange` permits
91+
a combination of static and standard primitive `Int`s to construct a range. It
92+
specifically enables the use of ranges without a step size of 1. It may be constructed
93+
through the use of `OptionallyStaticStepRange` directly or using static integers with
94+
the range operatore (i.e. `:`).
95+
96+
```julia
97+
julia> using ArrayInterface
98+
99+
julia> x = ArrayInterface.StaticInt(2);
100+
101+
julia> x:x:10
102+
ArrayInterface.StaticInt{2}():ArrayInterface.StaticInt{2}():10
103+
104+
julia> ArrayInterface.OptionallyStaticStepRange(x, x, 10)
105+
ArrayInterface.StaticInt{2}():ArrayInterface.StaticInt{2}():10
106+
```
107+
"""
108+
struct OptionallyStaticStepRange{F <: Integer, S <: Integer, L <: Integer} <: OrdinalRange{Int,Int}
109+
start::F
110+
step::S
111+
stop::L
112+
113+
function OptionallyStaticStepRange(start, step, stop)
114+
if eltype(start) <: Int
115+
if eltype(stop) <: Int
116+
lst = _steprange_last(start, step, stop)
117+
return new{typeof(start),typeof(step),typeof(lst)}(start, step, lst)
118+
else
119+
return OptionallyStaticStepRange(start, step, Int(stop))
120+
end
121+
else
122+
return OptionallyStaticStepRange(Int(start), step, stop)
123+
end
124+
end
125+
126+
function OptionallyStaticStepRange(x::AbstractRange)
127+
return OptionallyStaticStepRange(static_first(x), static_step(x), static_last(x))
128+
end
129+
end
130+
131+
# to make StepRange constructor inlineable, so optimizer can see `step` value
132+
@inline function _steprange_last(start::StaticInt, step::StaticInt, stop::StaticInt)
133+
return StaticInt(_steprange_last(Int(start), Int(step), Int(stop)))
134+
end
135+
@inline function _steprange_last(start::Integer, step::StaticInt, stop::StaticInt)
136+
if step === one(step)
137+
# we don't need to check the `stop` if we know it acts like a unit range
138+
return stop
139+
else
140+
return _steprange_last(start, Int(step), Int(stop))
141+
end
142+
end
143+
@inline function _steprange_last(start::Integer, step::Integer, stop::Integer)
144+
z = zero(step)
145+
if step === z
146+
throw(ArgumentError("step cannot be zero"))
147+
else
148+
if stop == start
149+
return Int(stop)
150+
else
151+
if step > z
152+
if stop > start
153+
return stop - Int(unsigned(stop - start) % step)
154+
else
155+
return Int(start - one(start))
156+
end
157+
else
158+
if stop > start
159+
return Int(start + one(start))
160+
else
161+
return stop + Int(unsigned(start - stop) % -step)
162+
end
163+
end
164+
end
165+
end
166+
end
167+
Base.first(r::OptionallyStaticStepRange) = r.start
168+
Base.step(r::OptionallyStaticStepRange) = r.step
169+
Base.last(r::OptionallyStaticStepRange) = r.stop
170+
171+
known_first(::Type{<:OptionallyStaticStepRange{StaticInt{F}}}) where {F} = F
172+
known_step(::Type{<:OptionallyStaticStepRange{<:Any,StaticInt{S}}}) where {S} = S
173+
known_last(::Type{<:OptionallyStaticStepRange{<:Any,<:Any,StaticInt{L}}}) where {L} = L
174+
175+
Base.:(:)(L::Integer, ::StaticInt{U}) where {U} = OptionallyStaticUnitRange(L, StaticInt(U))
176+
Base.:(:)(::StaticInt{L}, U::Integer) where {L} = OptionallyStaticUnitRange(StaticInt(L), U)
177+
Base.:(:)(::StaticInt{L}, ::StaticInt{U}) where {L,U} = OptionallyStaticUnitRange(StaticInt(L), StaticInt(U))
178+
Base.:(:)(::StaticInt{F}, ::StaticInt{S}, ::StaticInt{L}) where {F,S,L} = OptionallyStaticStepRange(StaticInt(F), StaticInt(S), StaticInt(L))
179+
Base.:(:)(start::Integer, ::StaticInt{S}, ::StaticInt{L}) where {S,L} = OptionallyStaticStepRange(start, StaticInt(S), StaticInt(L))
180+
Base.:(:)(::StaticInt{F}, ::StaticInt{S}, stop::Integer) where {F,S} = OptionallyStaticStepRange(StaticInt(F), StaticInt(S), stop)
181+
Base.:(:)(::StaticInt{F}, step::Integer, ::StaticInt{L}) where {F,L} = OptionallyStaticStepRange(StaticInt(F), step, StaticInt(L))
182+
Base.:(:)(start::Integer, step::Integer, ::StaticInt{L}) where {L} = OptionallyStaticStepRange(start, step, StaticInt(L))
183+
Base.:(:)(start::Integer, ::StaticInt{S}, stop::Integer) where {S} = OptionallyStaticStepRange(start, StaticInt(S), stop)
184+
Base.:(:)(::StaticInt{F}, step::Integer, stop::Integer) where {F} = OptionallyStaticStepRange(StaticInt(F), step, stop)
185+
Base.:(:)(::StaticInt{F}, ::StaticInt{1}, ::StaticInt{L}) where {F,L} = OptionallyStaticUnitRange(StaticInt(F), StaticInt(L))
186+
Base.:(:)(start::Integer, ::StaticInt{1}, ::StaticInt{L}) where {L} = OptionallyStaticUnitRange(start, StaticInt(L))
187+
Base.:(:)(::StaticInt{F}, ::StaticInt{1}, stop::Integer) where {F} = OptionallyStaticUnitRange(StaticInt(F), stop)
188+
Base.:(:)(start::Integer, ::StaticInt{1}, stop::Integer) = OptionallyStaticUnitRange(start, stop)
189+
190+
93191
function Base.isempty(r::OptionallyStaticUnitRange)
94192
if known_first(r) === oneunit(eltype(r))
95193
return unsafe_isempty_one_to(last(r))
@@ -98,13 +196,29 @@ function Base.isempty(r::OptionallyStaticUnitRange)
98196
end
99197
end
100198

199+
function Base.isempty(r::OptionallyStaticStepRange)
200+
return (r.start != r.stop) & ((r.step > zero(r.step)) != (r.stop > r.start))
201+
end
202+
101203
unsafe_isempty_one_to(lst) = lst <= zero(lst)
102204
unsafe_isempty_unit_range(fst, lst) = fst > lst
103205

104206
unsafe_length_one_to(lst::Int) = lst
105-
unsafe_length_one_to(::StaticInt{L}) where {L} = lst
207+
unsafe_length_one_to(::StaticInt{L}) where {L} = L
208+
209+
@inline function unsafe_length_step_range(start::Int, step::Int, stop::Int)
210+
if step > 1
211+
return Base.checked_add(Int(div(unsigned(stop - start), step)), 1)
212+
elseif step < -1
213+
return Base.checked_add(Int(div(unsigned(start - stop), -step)), 1)
214+
elseif step > 0
215+
return Base.checked_add(Int(div(Base.checked_sub(stop, start), step)), 1)
216+
else
217+
return Base.checked_add(Int(div(Base.checked_sub(rtart, stop), -step)), 1)
218+
end
219+
end
106220

107-
Base.@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
221+
@propagate_inbounds function Base.getindex(r::OptionallyStaticUnitRange, i::Integer)
108222
if known_first(r) === oneunit(eltype(r))
109223
return get_index_one_to(r, i)
110224
else
@@ -121,7 +235,7 @@ end
121235

122236
@inline function get_index_unit_range(r, i)
123237
val = first(r) + (i - 1)
124-
@boundscheck if (i < 1) || (val > last(r) && val < first(r))
238+
@boundscheck if (i < 1) || val > last(r)
125239
throw(BoundsError(r, i))
126240
end
127241
return convert(eltype(r), val)
@@ -130,28 +244,28 @@ end
130244
@inline _try_static(::StaticInt{N}, ::StaticInt{N}) where {N} = StaticInt{N}()
131245
@inline _try_static(::StaticInt{M}, ::StaticInt{N}) where {M, N} = @assert false "Unequal Indices: StaticInt{$M}() != StaticInt{$N}()"
132246
@propagate_inbounds function _try_static(::StaticInt{N}, x) where {N}
133-
@boundscheck begin
134-
@assert N == x "Unequal Indices: StaticInt{$N}() != x == $x"
135-
end
136-
return StaticInt{N}()
247+
@boundscheck begin
248+
@assert N == x "Unequal Indices: StaticInt{$N}() != x == $x"
249+
end
250+
return StaticInt{N}()
137251
end
138252
@propagate_inbounds function _try_static(x, ::StaticInt{N}) where {N}
139-
@boundscheck begin
140-
@assert N == x "Unequal Indices: x == $x != StaticInt{$N}()"
141-
end
142-
return StaticInt{N}()
253+
@boundscheck begin
254+
@assert N == x "Unequal Indices: x == $x != StaticInt{$N}()"
255+
end
256+
return StaticInt{N}()
143257
end
144258
@propagate_inbounds function _try_static(x, y)
145-
@boundscheck begin
146-
@assert x == y "Unequal Indicess: x == $x != $y == y"
147-
end
148-
return x
259+
@boundscheck begin
260+
@assert x == y "Unequal Indicess: x == $x != $y == y"
261+
end
262+
return x
149263
end
150264

151265
###
152266
### length
153267
###
154-
@inline function known_length(::Type{T}) where {T<:AbstractUnitRange}
268+
@inline function known_length(::Type{T}) where {T<:OptionallyStaticUnitRange}
155269
fst = known_first(T)
156270
lst = known_last(T)
157271
if fst === nothing || lst === nothing
@@ -165,6 +279,25 @@ end
165279
end
166280
end
167281

282+
@inline function known_length(::Type{T}) where {T<:OptionallyStaticStepRange}
283+
fst = known_first(T)
284+
stp = known_step(T)
285+
lst = known_last(T)
286+
if fst === nothing || stp === nothing || lst === nothing
287+
return nothing
288+
else
289+
if stp === 1
290+
if fst === oneunit(eltype(T))
291+
return unsafe_length_one_to(lst)
292+
else
293+
return unsafe_length_unit_range(fst, lst)
294+
end
295+
else
296+
return unsafe_length_step_range(fst, stp, lst)
297+
end
298+
end
299+
end
300+
168301
function Base.length(r::OptionallyStaticUnitRange)
169302
if isempty(r)
170303
return 0
@@ -177,6 +310,23 @@ function Base.length(r::OptionallyStaticUnitRange)
177310
end
178311
end
179312

313+
function Base.length(r::OptionallyStaticStepRange)
314+
if isempty(r)
315+
return 0
316+
else
317+
if known_step(r) === 1
318+
if known_first(r) === 1
319+
return unsafe_length_one_to(last(r))
320+
else
321+
return unsafe_length_unit_range(first(r), last(r))
322+
end
323+
else
324+
return unsafe_length_step_range(Int(first(r)), Int(step(r)), Int(last(r)))
325+
end
326+
end
327+
end
328+
329+
180330
unsafe_length_unit_range(start::Integer, stop::Integer) = Int((stop - start) + 1)
181331

182332
"""
@@ -219,3 +369,4 @@ end
219369
lst = _try_static(static_last(x), static_last(y))
220370
return Base.Slice(OptionallyStaticUnitRange(fst, lst))
221371
end
372+

src/static.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ StaticInt(::StaticInt{N}) where {N} = StaticInt{N}()
1616
StaticInt(::Val{N}) where {N} = StaticInt{N}()
1717
# Base.Val(::StaticInt{N}) where {N} = Val{N}()
1818
Base.convert(::Type{T}, ::StaticInt{N}) where {T<:Number,N} = convert(T, N)
19+
Base.Bool(x::StaticInt{N}) where {N} = Bool(N)
20+
Base.BigInt(x::StaticInt{N}) where {N} = BigInt(N)
21+
Base.Integer(x::StaticInt{N}) where {N} = x
22+
(::Type{T})(x::StaticInt{N}) where {T<:Integer,N} = T(N)
1923
(::Type{T})(x::Int) where {T<:StaticInt} = StaticInt(x)
2024
Base.convert(::Type{StaticInt{N}}, ::StaticInt{N}) where {N} = StaticInt{N}()
2125

test/runtests.jl

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,21 @@ using ArrayInterface: parent_type
190190
end
191191

192192
@testset "Range Interface" begin
193+
@testset "Range Constructors" begin
194+
@test @inferred(StaticInt(1):StaticInt(10)) == 1:10
195+
@test @inferred(StaticInt(1):StaticInt(2):StaticInt(10)) == 1:2:10
196+
@test @inferred(1:StaticInt(2):StaticInt(10)) == 1:2:10
197+
@test @inferred(StaticInt(1):StaticInt(2):10) == 1:2:10
198+
@test @inferred(StaticInt(1):2:StaticInt(10)) == 1:2:10
199+
@test @inferred(1:2:StaticInt(10)) == 1:2:10
200+
@test @inferred(1:StaticInt(2):10) == 1:2:10
201+
@test @inferred(StaticInt(1):2:10) == 1:2:10
202+
203+
@test @inferred(StaticInt(1):StaticInt(1):StaticInt(10)) === ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), StaticInt(10))
204+
@test @inferred(StaticInt(1):StaticInt(1):10) === ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10)
205+
@test @inferred(1:StaticInt(1):10) === ArrayInterface.OptionallyStaticUnitRange(1, 10)
206+
end
207+
193208
@test isnothing(@inferred(ArrayInterface.known_first(typeof(1:4))))
194209
@test isone(@inferred(ArrayInterface.known_first(Base.OneTo(4))))
195210
@test isone(@inferred(ArrayInterface.known_first(typeof(Base.OneTo(4)))))
@@ -201,14 +216,24 @@ end
201216
@test isone(@inferred(ArrayInterface.known_step(1:4)))
202217
@test isone(@inferred(ArrayInterface.known_step(typeof(1:4))))
203218

204-
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 0))) == 0
205-
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 10))) == 10
206-
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10))) == 10
207-
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(0), 10))) == 11
219+
@testset "length" begin
220+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 0))) == 0
221+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(1, 10))) == 10
222+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10))) == 10
223+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(0), 10))) == 11
224+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), StaticInt(10)))) == 10
225+
@test @inferred(length(ArrayInterface.OptionallyStaticUnitRange(StaticInt(0), StaticInt(10)))) == 11
226+
227+
@test @inferred(length(StaticInt(1):StaticInt(2):StaticInt(0))) == 0
228+
@test @inferred(length(StaticInt(0):StaticInt(-2):StaticInt(1))) == 0
229+
end
208230
@test @inferred(getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10), 1)) == 1
209231
@test @inferred(getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(0), 10), 1)) == 0
210232
@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10), 0)
211-
@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10), 0)
233+
@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 2, 10), 0)
234+
@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticUnitRange(StaticInt(1), 10), 11)
235+
@test_throws BoundsError getindex(ArrayInterface.OptionallyStaticStepRange(StaticInt(1), 2, 10), 11)
236+
212237
end
213238

214239
@testset "Memory Layout" begin
@@ -287,12 +312,12 @@ using OffsetArrays
287312
M = @MArray zeros(2,3,4); Mp = @view(PermutedDimsArray(M,(3,1,2))[:,2,:])';
288313
Sp2 = @view(PermutedDimsArray(S,(3,2,1))[2:3,:,:]);
289314
Mp2 = @view(PermutedDimsArray(M,(3,1,2))[2:3,:,2])';
290-
315+
291316
@test @inferred(ArrayInterface.size(A)) === (3,4,5)
292317
@test @inferred(ArrayInterface.size(Ap)) === (2,5)
293318
@test @inferred(ArrayInterface.size(A)) === size(A)
294319
@test @inferred(ArrayInterface.size(Ap)) === size(Ap)
295-
320+
296321
@test @inferred(ArrayInterface.size(S)) === (StaticInt(2), StaticInt(3), StaticInt(4))
297322
@test @inferred(ArrayInterface.size(Sp)) === (2, 2, StaticInt(3))
298323
@test @inferred(ArrayInterface.size(Sp2)) === (2, StaticInt(3), StaticInt(2))
@@ -405,6 +430,11 @@ end
405430
@test @inferred(one(StaticInt)) === StaticInt(1)
406431
@test @inferred(zero(StaticInt)) === StaticInt(0)
407432
@test eltype(one(StaticInt)) <: Int
433+
434+
x = StaticInt(1)
435+
@test @inferred(Bool(x)) isa Bool
436+
@test @inferred(BigInt(x)) isa BigInt
437+
@test @inferred(Integer(x)) === x
408438
# test for ambiguities and correctness
409439
for i [StaticInt(0), StaticInt(1), StaticInt(2), 3]
410440
for j [StaticInt(0), StaticInt(1), StaticInt(2), 3]

0 commit comments

Comments
 (0)