@@ -189,6 +189,7 @@ for f ∈ [ # groupedstridedpointer support
189
189
:(ArrayInterface. contiguous_axis),
190
190
:(ArrayInterface. contiguous_batch_size),
191
191
:(ArrayInterface. device),
192
+ :(ArrayInterface. dense_dims),
192
193
:(ArrayInterface. stride_rank),
193
194
:(VectorizationBase. val_dense_dims),
194
195
:(ArrayInterface. offsets),
@@ -204,6 +205,8 @@ function is_column_major(x)
204
205
true
205
206
end
206
207
is_row_major (x) = is_column_major (reverse (x))
208
+ _find_arg_least_greater (r:: Vector{Int} , i) =
209
+ findmin (x -> x > i ? x : typemax (Int), r)
207
210
# @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
208
211
function _strides_expr (
209
212
@nospecialize (s),
@@ -215,19 +218,18 @@ function _strides_expr(
215
218
q = Expr (:block , Expr (:meta , :inline ))
216
219
strd_tup = Expr (:tuple )
217
220
ifel = GlobalRef (Core, :ifelse )
218
- Nrange = 1 : 1 : N # type stability w/ respect to reverse
221
+ Nrange = 1 : N # type stability w/ respect to reverse
222
+ # Nrange = 1:1:N # type stability w/ respect to reverse
219
223
use_stride_acc = true
220
224
stride_acc:: Int = 1
221
- if is_column_major (R)
222
- # elseif is_row_major(R)
223
- # Nrange = reverse(Nrange)
224
- else # not worth my time optimizing this case at the moment...
225
- # will write something generic stride-rank agnostic eventually
225
+ next, n = _find_arg_least_greater (R, 0 )
226
+ n = findfirst (== (1 ), R)
227
+ if ! D[n]
226
228
use_stride_acc = false
227
229
stride_acc = 0
228
230
end
229
231
sₙ_value:: Int = 0
230
- for n ∈ Nrange
232
+ for _n ∈ Nrange
231
233
xₙ_type = x[n]
232
234
xₙ_static = xₙ_type <: StaticInt
233
235
xₙ_value:: Int = xₙ_static ? (xₙ_type. parameters[1 ]):: Int : 0
@@ -254,20 +256,22 @@ function _strides_expr(
254
256
)
255
257
end
256
258
end
257
- if (n ≠ last (Nrange)) && use_stride_acc
258
- nnext = n + step (Nrange)
259
- if D[nnext]
260
- if xₙ_static & sₙ_static
261
- stride_acc = xₙ_value * sₙ_value
262
- elseif sₙ_static
263
- if stride_acc ≠ 0
264
- stride_acc *= sₙ_value
259
+ if (n ≠ N)
260
+ next, n = _find_arg_least_greater (R, next)
261
+ if use_stride_acc
262
+ if D[n]
263
+ if xₙ_static & sₙ_static
264
+ stride_acc = xₙ_value * sₙ_value
265
+ elseif sₙ_static
266
+ if stride_acc ≠ 0
267
+ stride_acc *= sₙ_value
268
+ end
269
+ else
270
+ stride_acc = 0
265
271
end
266
272
else
267
273
stride_acc = 0
268
274
end
269
- else
270
- stride_acc = 0
271
275
end
272
276
end
273
277
end
675
679
:: Val{UNROLL} ,
676
680
:: Val{dontbc}
677
681
) where {T<: NativeTypes ,N,BC<: Union{Broadcasted,Product} ,Mod,UNROLL,dontbc}
682
+ @show (dest) (BC)
678
683
vmaterialize_fun (sizeof (T), N, BC, Mod, UNROLL, dontbc, false )
679
684
end
680
685
@generated function vmaterialize! (
0 commit comments