@@ -43,7 +43,7 @@ known_step(::Type{<:AbstractUnitRange{T}}) where {T} = one(T)
43
43
# add methods to support ArrayInterface
44
44
45
45
"""
46
- OptionallyStaticUnitRange{T<:Integer} (start, stop) <: OrdinalRange{T,T }
46
+ OptionallyStaticUnitRange(start, stop) <: AbstractUnitRange{Int }
47
47
48
48
This range permits diverse representations of arrays to comunicate common information
49
49
about their indices. Each field may be an integer or `Val(<:Integer)` if it is known
@@ -67,21 +67,15 @@ struct OptionallyStaticUnitRange{F <: Integer, L <: Integer} <: AbstractUnitRang
67
67
end
68
68
end
69
69
70
- function OptionallyStaticUnitRange (x:: AbstractRange )
70
+ function OptionallyStaticUnitRange (x:: AbstractRange )
71
71
if step (x) == 1
72
- fst = static_first (x)
73
- lst = static_last (x)
74
- return OptionallyStaticUnitRange (fst, lst)
72
+ return OptionallyStaticUnitRange (static_first (x), static_last (x))
75
73
else
76
74
throw (ArgumentError (" step must be 1, got $(step (r)) " ))
77
75
end
78
76
end
79
77
end
80
78
81
- Base.:(:)(L:: Integer , :: StaticInt{U} ) where {U} = OptionallyStaticUnitRange (L, StaticInt (U))
82
- Base.:(:)(:: StaticInt{L} , U:: Integer ) where {L} = OptionallyStaticUnitRange (StaticInt (L), U)
83
- Base.:(:)(:: StaticInt{L} , :: StaticInt{U} ) where {L,U} = OptionallyStaticUnitRange (StaticInt (L), StaticInt (U))
84
-
85
79
Base. first (r:: OptionallyStaticUnitRange ) = r. start
86
80
Base. step (:: OptionallyStaticUnitRange ) = StaticInt (1 )
87
81
Base. last (r:: OptionallyStaticUnitRange ) = r. stop
@@ -90,6 +84,110 @@ known_first(::Type{<:OptionallyStaticUnitRange{StaticInt{F}}}) where {F} = F
90
84
known_step (:: Type{<:OptionallyStaticUnitRange} ) = 1
91
85
known_last (:: Type{<:OptionallyStaticUnitRange{<:Any,StaticInt{L}}} ) where {L} = L
92
86
87
+ """
88
+ OptionallyStaticStepRange(start, step, stop) <: OrdinalRange{Int,Int}
89
+
90
+ Similar to [`OptionallyStaticUnitRange`](@ref), `OptionallyStaticStepRange` permits
91
+ a combination of static and standard primitive `Int`s to construct a range. It
92
+ specifically enables the use of ranges without a step size of 1. It may be constructed
93
+ through the use of `OptionallyStaticStepRange` directly or using static integers with
94
+ the range operatore (i.e. `:`).
95
+
96
+ ```julia
97
+ julia> using ArrayInterface
98
+
99
+ julia> x = ArrayInterface.StaticInt(2);
100
+
101
+ julia> x:x:10
102
+ ArrayInterface.StaticInt{2}():ArrayInterface.StaticInt{2}():10
103
+
104
+ julia> ArrayInterface.OptionallyStaticStepRange(x, x, 10)
105
+ ArrayInterface.StaticInt{2}():ArrayInterface.StaticInt{2}():10
106
+ ```
107
+ """
108
+ struct OptionallyStaticStepRange{F <: Integer , S <: Integer , L <: Integer } <: OrdinalRange{Int,Int}
109
+ start:: F
110
+ step:: S
111
+ stop:: L
112
+
113
+ function OptionallyStaticStepRange (start, step, stop)
114
+ if eltype (start) <: Int
115
+ if eltype (stop) <: Int
116
+ lst = _steprange_last (start, step, stop)
117
+ return new {typeof(start),typeof(step),typeof(lst)} (start, step, lst)
118
+ else
119
+ return OptionallyStaticStepRange (start, step, Int (stop))
120
+ end
121
+ else
122
+ return OptionallyStaticStepRange (Int (start), step, stop)
123
+ end
124
+ end
125
+
126
+ function OptionallyStaticStepRange (x:: AbstractRange )
127
+ return OptionallyStaticStepRange (static_first (x), static_step (x), static_last (x))
128
+ end
129
+ end
130
+
131
+ # to make StepRange constructor inlineable, so optimizer can see `step` value
132
+ @inline function _steprange_last (start:: StaticInt , step:: StaticInt , stop:: StaticInt )
133
+ return StaticInt (_steprange_last (Int (start), Int (step), Int (stop)))
134
+ end
135
+ @inline function _steprange_last (start:: Integer , step:: StaticInt , stop:: StaticInt )
136
+ if step === one (step)
137
+ # we don't need to check the `stop` if we know it acts like a unit range
138
+ return stop
139
+ else
140
+ return _steprange_last (start, Int (step), Int (stop))
141
+ end
142
+ end
143
+ @inline function _steprange_last (start:: Integer , step:: Integer , stop:: Integer )
144
+ z = zero (step)
145
+ if step === z
146
+ throw (ArgumentError (" step cannot be zero" ))
147
+ else
148
+ if stop == start
149
+ return Int (stop)
150
+ else
151
+ if step > z
152
+ if stop > start
153
+ return stop - Int (unsigned (stop - start) % step)
154
+ else
155
+ return Int (start - one (start))
156
+ end
157
+ else
158
+ if stop > start
159
+ return Int (start + one (start))
160
+ else
161
+ return stop + Int (unsigned (start - stop) % - step)
162
+ end
163
+ end
164
+ end
165
+ end
166
+ end
167
+ Base. first (r:: OptionallyStaticStepRange ) = r. start
168
+ Base. step (r:: OptionallyStaticStepRange ) = r. step
169
+ Base. last (r:: OptionallyStaticStepRange ) = r. stop
170
+
171
+ known_first (:: Type{<:OptionallyStaticStepRange{StaticInt{F}}} ) where {F} = F
172
+ known_step (:: Type{<:OptionallyStaticStepRange{<:Any,StaticInt{S}}} ) where {S} = S
173
+ known_last (:: Type{<:OptionallyStaticStepRange{<:Any,<:Any,StaticInt{L}}} ) where {L} = L
174
+
175
+ Base.:(:)(L:: Integer , :: StaticInt{U} ) where {U} = OptionallyStaticUnitRange (L, StaticInt (U))
176
+ Base.:(:)(:: StaticInt{L} , U:: Integer ) where {L} = OptionallyStaticUnitRange (StaticInt (L), U)
177
+ Base.:(:)(:: StaticInt{L} , :: StaticInt{U} ) where {L,U} = OptionallyStaticUnitRange (StaticInt (L), StaticInt (U))
178
+ Base.:(:)(:: StaticInt{F} , :: StaticInt{S} , :: StaticInt{L} ) where {F,S,L} = OptionallyStaticStepRange (StaticInt (F), StaticInt (S), StaticInt (L))
179
+ Base.:(:)(start:: Integer , :: StaticInt{S} , :: StaticInt{L} ) where {S,L} = OptionallyStaticStepRange (start, StaticInt (S), StaticInt (L))
180
+ Base.:(:)(:: StaticInt{F} , :: StaticInt{S} , stop:: Integer ) where {F,S} = OptionallyStaticStepRange (StaticInt (F), StaticInt (S), stop)
181
+ Base.:(:)(:: StaticInt{F} , step:: Integer , :: StaticInt{L} ) where {F,L} = OptionallyStaticStepRange (StaticInt (F), step, StaticInt (L))
182
+ Base.:(:)(start:: Integer , step:: Integer , :: StaticInt{L} ) where {L} = OptionallyStaticStepRange (start, step, StaticInt (L))
183
+ Base.:(:)(start:: Integer , :: StaticInt{S} , stop:: Integer ) where {S} = OptionallyStaticStepRange (start, StaticInt (S), stop)
184
+ Base.:(:)(:: StaticInt{F} , step:: Integer , stop:: Integer ) where {F} = OptionallyStaticStepRange (StaticInt (F), step, stop)
185
+ Base.:(:)(:: StaticInt{F} , :: StaticInt{1} , :: StaticInt{L} ) where {F,L} = OptionallyStaticUnitRange (StaticInt (F), StaticInt (L))
186
+ Base.:(:)(start:: Integer , :: StaticInt{1} , :: StaticInt{L} ) where {L} = OptionallyStaticUnitRange (start, StaticInt (L))
187
+ Base.:(:)(:: StaticInt{F} , :: StaticInt{1} , stop:: Integer ) where {F} = OptionallyStaticUnitRange (StaticInt (F), stop)
188
+ Base.:(:)(start:: Integer , :: StaticInt{1} , stop:: Integer ) = OptionallyStaticUnitRange (start, stop)
189
+
190
+
93
191
function Base. isempty (r:: OptionallyStaticUnitRange )
94
192
if known_first (r) === oneunit (eltype (r))
95
193
return unsafe_isempty_one_to (last (r))
@@ -98,13 +196,29 @@ function Base.isempty(r::OptionallyStaticUnitRange)
98
196
end
99
197
end
100
198
199
+ function Base. isempty (r:: OptionallyStaticStepRange )
200
+ return (r. start != r. stop) & ((r. step > zero (r. step)) != (r. stop > r. start))
201
+ end
202
+
101
203
unsafe_isempty_one_to (lst) = lst <= zero (lst)
102
204
unsafe_isempty_unit_range (fst, lst) = fst > lst
103
205
104
206
unsafe_length_one_to (lst:: Int ) = lst
105
- unsafe_length_one_to (:: StaticInt{L} ) where {L} = lst
207
+ unsafe_length_one_to (:: StaticInt{L} ) where {L} = L
208
+
209
+ @inline function unsafe_length_step_range (start:: Int , step:: Int , stop:: Int )
210
+ if step > 1
211
+ return Base. checked_add (Int (div (unsigned (stop - start), step)), 1 )
212
+ elseif step < - 1
213
+ return Base. checked_add (Int (div (unsigned (start - stop), - step)), 1 )
214
+ elseif step > 0
215
+ return Base. checked_add (Int (div (Base. checked_sub (stop, start), step)), 1 )
216
+ else
217
+ return Base. checked_add (Int (div (Base. checked_sub (rtart, stop), - step)), 1 )
218
+ end
219
+ end
106
220
107
- Base . @propagate_inbounds function Base. getindex (r:: OptionallyStaticUnitRange , i:: Integer )
221
+ @propagate_inbounds function Base. getindex (r:: OptionallyStaticUnitRange , i:: Integer )
108
222
if known_first (r) === oneunit (eltype (r))
109
223
return get_index_one_to (r, i)
110
224
else
121
235
122
236
@inline function get_index_unit_range (r, i)
123
237
val = first (r) + (i - 1 )
124
- @boundscheck if (i < 1 ) || ( val > last (r) && val < first (r) )
238
+ @boundscheck if (i < 1 ) || val > last (r)
125
239
throw (BoundsError (r, i))
126
240
end
127
241
return convert (eltype (r), val)
@@ -130,28 +244,28 @@ end
130
244
@inline _try_static (:: StaticInt{N} , :: StaticInt{N} ) where {N} = StaticInt {N} ()
131
245
@inline _try_static (:: StaticInt{M} , :: StaticInt{N} ) where {M, N} = @assert false " Unequal Indices: StaticInt{$M }() != StaticInt{$N }()"
132
246
@propagate_inbounds function _try_static (:: StaticInt{N} , x) where {N}
133
- @boundscheck begin
134
- @assert N == x " Unequal Indices: StaticInt{$N }() != x == $x "
135
- end
136
- return StaticInt {N} ()
247
+ @boundscheck begin
248
+ @assert N == x " Unequal Indices: StaticInt{$N }() != x == $x "
249
+ end
250
+ return StaticInt {N} ()
137
251
end
138
252
@propagate_inbounds function _try_static (x, :: StaticInt{N} ) where {N}
139
- @boundscheck begin
140
- @assert N == x " Unequal Indices: x == $x != StaticInt{$N }()"
141
- end
142
- return StaticInt {N} ()
253
+ @boundscheck begin
254
+ @assert N == x " Unequal Indices: x == $x != StaticInt{$N }()"
255
+ end
256
+ return StaticInt {N} ()
143
257
end
144
258
@propagate_inbounds function _try_static (x, y)
145
- @boundscheck begin
146
- @assert x == y " Unequal Indicess: x == $x != $y == y"
147
- end
148
- return x
259
+ @boundscheck begin
260
+ @assert x == y " Unequal Indicess: x == $x != $y == y"
261
+ end
262
+ return x
149
263
end
150
264
151
265
# ##
152
266
# ## length
153
267
# ##
154
- @inline function known_length (:: Type{T} ) where {T<: AbstractUnitRange }
268
+ @inline function known_length (:: Type{T} ) where {T<: OptionallyStaticUnitRange }
155
269
fst = known_first (T)
156
270
lst = known_last (T)
157
271
if fst === nothing || lst === nothing
165
279
end
166
280
end
167
281
282
+ @inline function known_length (:: Type{T} ) where {T<: OptionallyStaticStepRange }
283
+ fst = known_first (T)
284
+ stp = known_step (T)
285
+ lst = known_last (T)
286
+ if fst === nothing || stp === nothing || lst === nothing
287
+ return nothing
288
+ else
289
+ if stp === 1
290
+ if fst === oneunit (eltype (T))
291
+ return unsafe_length_one_to (lst)
292
+ else
293
+ return unsafe_length_unit_range (fst, lst)
294
+ end
295
+ else
296
+ return unsafe_length_step_range (fst, stp, lst)
297
+ end
298
+ end
299
+ end
300
+
168
301
function Base. length (r:: OptionallyStaticUnitRange )
169
302
if isempty (r)
170
303
return 0
@@ -177,6 +310,23 @@ function Base.length(r::OptionallyStaticUnitRange)
177
310
end
178
311
end
179
312
313
+ function Base. length (r:: OptionallyStaticStepRange )
314
+ if isempty (r)
315
+ return 0
316
+ else
317
+ if known_step (r) === 1
318
+ if known_first (r) === 1
319
+ return unsafe_length_one_to (last (r))
320
+ else
321
+ return unsafe_length_unit_range (first (r), last (r))
322
+ end
323
+ else
324
+ return unsafe_length_step_range (Int (first (r)), Int (step (r)), Int (last (r)))
325
+ end
326
+ end
327
+ end
328
+
329
+
180
330
unsafe_length_unit_range (start:: Integer , stop:: Integer ) = Int ((stop - start) + 1 )
181
331
182
332
"""
219
369
lst = _try_static (static_last (x), static_last (y))
220
370
return Base. Slice (OptionallyStaticUnitRange (fst, lst))
221
371
end
372
+
0 commit comments