Skip to content

Commit a21d6f8

Browse files
authored
Strides fix in broadcast (#504)
* ignore oftype * broadcast fixes * delete problematic line that accidentally wasn't removed * no print * fix order
1 parent befd727 commit a21d6f8

File tree

2 files changed

+30
-27
lines changed

2 files changed

+30
-27
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.162"
4+
version = "0.12.163"
5+
56

67
[deps]
78
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/broadcast.jl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ for f ∈ [ # groupedstridedpointer support
189189
:(ArrayInterface.contiguous_axis),
190190
:(ArrayInterface.contiguous_batch_size),
191191
:(ArrayInterface.device),
192+
:(ArrayInterface.dense_dims),
192193
:(ArrayInterface.stride_rank),
193194
:(VectorizationBase.val_dense_dims),
194195
:(ArrayInterface.offsets),
@@ -204,7 +205,9 @@ function is_column_major(x)
204205
true
205206
end
206207
is_row_major(x) = is_column_major(reverse(x))
207-
# @inline _bytestrides(s,paren) = VectorizationBase.bytestrides(paren)
208+
_find_arg_least_greater(r::Vector{Int}, i) =
209+
findmin(x -> x > i ? x : typemax(Int), r)
210+
208211
function _strides_expr(
209212
@nospecialize(s),
210213
@nospecialize(x),
@@ -214,20 +217,19 @@ function _strides_expr(
214217
N = length(R)
215218
q = Expr(:block, Expr(:meta, :inline))
216219
strd_tup = Expr(:tuple)
220+
resize!(strd_tup.args, N)
217221
ifel = GlobalRef(Core, :ifelse)
218-
Nrange = 1:1:N # type stability w/ respect to reverse
222+
Nrange = 1:N # type stability w/ respect to reverse
223+
# Nrange = 1:1:N # type stability w/ respect to reverse
219224
use_stride_acc = true
220225
stride_acc::Int = 1
221-
if is_column_major(R)
222-
# elseif is_row_major(R)
223-
# Nrange = reverse(Nrange)
224-
else # not worth my time optimizing this case at the moment...
225-
# will write something generic stride-rank agnostic eventually
226+
next, n = _find_arg_least_greater(R, 0)
227+
if !D[n]
226228
use_stride_acc = false
227229
stride_acc = 0
228230
end
229231
sₙ_value::Int = 0
230-
for n Nrange
232+
for _n Nrange
231233
xₙ_type = x[n]
232234
xₙ_static = xₙ_type <: StaticInt
233235
xₙ_value::Int = xₙ_static ? (xₙ_type.parameters[1])::Int : 0
@@ -236,38 +238,38 @@ function _strides_expr(
236238
if sₙ_static
237239
sₙ_value = s_type.parameters[1]
238240
if s_type === One
239-
push!(strd_tup.args, Expr(:call, lv(:Zero)))
241+
strd_tup.args[n] = Expr(:call, lv(:Zero))
240242
elseif stride_acc 0
241-
push!(strd_tup.args, staticexpr(stride_acc))
243+
strd_tup.args[n] = staticexpr(stride_acc)
242244
else
243-
push!(strd_tup.args, :($getfield(x, $n)))
245+
strd_tup.args[n] = :($getfield(x, $n))
244246
end
245247
else
246248
if xₙ_static
247-
push!(strd_tup.args, staticexpr(xₙ_value))
249+
strd_tup.args[n] = staticexpr(xₙ_value)
248250
elseif stride_acc 0
249-
push!(strd_tup.args, staticexpr(stride_acc))
251+
strd_tup.args[n] = staticexpr(stride_acc)
250252
else
251-
push!(
252-
strd_tup.args,
253+
strd_tup.args[n] =
253254
:($ifel(isone($getfield(s, $n)), zero($xₙ_type), $getfield(x, $n)))
254-
)
255255
end
256256
end
257-
if (n last(Nrange)) && use_stride_acc
258-
nnext = n + step(Nrange)
259-
if D[nnext]
260-
if xₙ_static & sₙ_static
261-
stride_acc = xₙ_value * sₙ_value
262-
elseif sₙ_static
263-
if stride_acc 0
264-
stride_acc *= sₙ_value
257+
if (_n N)
258+
next, n = _find_arg_least_greater(R, next)
259+
if use_stride_acc
260+
if D[n]
261+
if xₙ_static & sₙ_static
262+
stride_acc = xₙ_value * sₙ_value
263+
elseif sₙ_static
264+
if stride_acc 0
265+
stride_acc *= sₙ_value
266+
end
267+
else
268+
stride_acc = 0
265269
end
266270
else
267271
stride_acc = 0
268272
end
269-
else
270-
stride_acc = 0
271273
end
272274
end
273275
end

0 commit comments

Comments
 (0)