@@ -189,6 +189,7 @@ for f ∈ [ # groupedstridedpointer support
189
189
:(ArrayInterface. contiguous_axis),
190
190
:(ArrayInterface. contiguous_batch_size),
191
191
:(ArrayInterface. device),
192
+ :(ArrayInterface. dense_dims),
192
193
:(ArrayInterface. stride_rank),
193
194
:(VectorizationBase. val_dense_dims),
194
195
:(ArrayInterface. offsets),
@@ -204,7 +205,9 @@ function is_column_major(x)
204
205
true
205
206
end
206
207
# A rank ordering counts as row-major exactly when its reversal passes the
# column-major check.
function is_row_major(x)
  reversed = reverse(x)
  return is_column_major(reversed)
end
207
- # @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
208
# Locate the smallest entry of `r` that is strictly greater than `i`.
# Returns a `(value, index)` tuple as produced by `findmin`. Entries that are
# not greater than `i` are scored as `typemax(Int)`, so when no entry
# qualifies the result is `(typemax(Int), 1)`.
function _find_arg_least_greater(r::Vector{Int}, i)
  score = function (candidate)
    if candidate > i
      return candidate
    end
    return typemax(Int)
  end
  return findmin(score, r)
end
210
+
208
211
function _strides_expr (
209
212
@nospecialize (s),
210
213
@nospecialize (x),
@@ -214,20 +217,19 @@ function _strides_expr(
214
217
N = length (R)
215
218
q = Expr (:block , Expr (:meta , :inline ))
216
219
strd_tup = Expr (:tuple )
220
+ resize! (strd_tup. args, N)
217
221
ifel = GlobalRef (Core, :ifelse )
218
- Nrange = 1 : 1 : N # type stability w/ respect to reverse
222
+ Nrange = 1 : N # type stability w/ respect to reverse
223
+ # Nrange = 1:1:N # type stability w/ respect to reverse
219
224
use_stride_acc = true
220
225
stride_acc:: Int = 1
221
- if is_column_major (R)
222
- # elseif is_row_major(R)
223
- # Nrange = reverse(Nrange)
224
- else # not worth my time optimizing this case at the moment...
225
- # will write something generic stride-rank agnostic eventually
226
+ next, n = _find_arg_least_greater (R, 0 )
227
+ if ! D[n]
226
228
use_stride_acc = false
227
229
stride_acc = 0
228
230
end
229
231
sₙ_value:: Int = 0
230
- for n ∈ Nrange
232
+ for _n ∈ Nrange
231
233
xₙ_type = x[n]
232
234
xₙ_static = xₙ_type <: StaticInt
233
235
xₙ_value:: Int = xₙ_static ? (xₙ_type. parameters[1 ]):: Int : 0
@@ -236,38 +238,38 @@ function _strides_expr(
236
238
if sₙ_static
237
239
sₙ_value = s_type. parameters[1 ]
238
240
if s_type === One
239
- push! ( strd_tup. args, Expr (:call , lv (:Zero ) ))
241
+ strd_tup. args[n] = Expr (:call , lv (:Zero ))
240
242
elseif stride_acc ≠ 0
241
- push! ( strd_tup. args, staticexpr (stride_acc) )
243
+ strd_tup. args[n] = staticexpr (stride_acc)
242
244
else
243
- push! ( strd_tup. args, :($ getfield (x, $ n) ))
245
+ strd_tup. args[n] = :($ getfield (x, $ n))
244
246
end
245
247
else
246
248
if xₙ_static
247
- push! ( strd_tup. args, staticexpr (xₙ_value) )
249
+ strd_tup. args[n] = staticexpr (xₙ_value)
248
250
elseif stride_acc ≠ 0
249
- push! ( strd_tup. args, staticexpr (stride_acc) )
251
+ strd_tup. args[n] = staticexpr (stride_acc)
250
252
else
251
- push! (
252
- strd_tup. args,
253
+ strd_tup. args[n] =
253
254
:($ ifel (isone ($ getfield (s, $ n)), zero ($ xₙ_type), $ getfield (x, $ n)))
254
- )
255
255
end
256
256
end
257
- if (n ≠ last (Nrange)) && use_stride_acc
258
- nnext = n + step (Nrange)
259
- if D[nnext]
260
- if xₙ_static & sₙ_static
261
- stride_acc = xₙ_value * sₙ_value
262
- elseif sₙ_static
263
- if stride_acc ≠ 0
264
- stride_acc *= sₙ_value
257
+ if (_n ≠ N)
258
+ next, n = _find_arg_least_greater (R, next)
259
+ if use_stride_acc
260
+ if D[n]
261
+ if xₙ_static & sₙ_static
262
+ stride_acc = xₙ_value * sₙ_value
263
+ elseif sₙ_static
264
+ if stride_acc ≠ 0
265
+ stride_acc *= sₙ_value
266
+ end
267
+ else
268
+ stride_acc = 0
265
269
end
266
270
else
267
271
stride_acc = 0
268
272
end
269
- else
270
- stride_acc = 0
271
273
end
272
274
end
273
275
end
0 commit comments