Skip to content

Commit df1424c

Browse files
committed
broadcast fixes
1 parent 6a02c86 commit df1424c

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

src/broadcast.jl

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ for f ∈ [ # groupedstridedpointer support
189189
:(ArrayInterface.contiguous_axis),
190190
:(ArrayInterface.contiguous_batch_size),
191191
:(ArrayInterface.device),
192+
:(ArrayInterface.dense_dims),
192193
:(ArrayInterface.stride_rank),
193194
:(VectorizationBase.val_dense_dims),
194195
:(ArrayInterface.offsets),
@@ -204,6 +205,8 @@ function is_column_major(x)
204205
true
205206
end
206207
is_row_major(x) = is_column_major(reverse(x))
208+
_find_arg_least_greater(r::Vector{Int}, i) =
209+
findmin(x -> x > i ? x : typemax(Int), r)
207210
# @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
208211
function _strides_expr(
209212
@nospecialize(s),
@@ -215,19 +218,18 @@ function _strides_expr(
215218
q = Expr(:block, Expr(:meta, :inline))
216219
strd_tup = Expr(:tuple)
217220
ifel = GlobalRef(Core, :ifelse)
218-
Nrange = 1:1:N # type stability w/ respect to reverse
221+
Nrange = 1:N # type stability w/ respect to reverse
222+
# Nrange = 1:1:N # type stability w/ respect to reverse
219223
use_stride_acc = true
220224
stride_acc::Int = 1
221-
if is_column_major(R)
222-
# elseif is_row_major(R)
223-
# Nrange = reverse(Nrange)
224-
else # not worth my time optimizing this case at the moment...
225-
# will write something generic stride-rank agnostic eventually
225+
next, n = _find_arg_least_greater(R, 0)
226+
n = findfirst(==(1), R)
227+
if !D[n]
226228
use_stride_acc = false
227229
stride_acc = 0
228230
end
229231
sₙ_value::Int = 0
230-
for n Nrange
232+
for _n Nrange
231233
xₙ_type = x[n]
232234
xₙ_static = xₙ_type <: StaticInt
233235
xₙ_value::Int = xₙ_static ? (xₙ_type.parameters[1])::Int : 0
@@ -254,20 +256,22 @@ function _strides_expr(
254256
)
255257
end
256258
end
257-
if (n last(Nrange)) && use_stride_acc
258-
nnext = n + step(Nrange)
259-
if D[nnext]
260-
if xₙ_static & sₙ_static
261-
stride_acc = xₙ_value * sₙ_value
262-
elseif sₙ_static
263-
if stride_acc 0
264-
stride_acc *= sₙ_value
259+
if (n N)
260+
next, n = _find_arg_least_greater(R, next)
261+
if use_stride_acc
262+
if D[n]
263+
if xₙ_static & sₙ_static
264+
stride_acc = xₙ_value * sₙ_value
265+
elseif sₙ_static
266+
if stride_acc 0
267+
stride_acc *= sₙ_value
268+
end
269+
else
270+
stride_acc = 0
265271
end
266272
else
267273
stride_acc = 0
268274
end
269-
else
270-
stride_acc = 0
271275
end
272276
end
273277
end
@@ -675,6 +679,7 @@ end
675679
::Val{UNROLL},
676680
::Val{dontbc}
677681
) where {T<:NativeTypes,N,BC<:Union{Broadcasted,Product},Mod,UNROLL,dontbc}
682+
@show (dest) (BC)
678683
vmaterialize_fun(sizeof(T), N, BC, Mod, UNROLL, dontbc, false)
679684
end
680685
@generated function vmaterialize!(

0 commit comments

Comments
 (0)