Skip to content

Commit 8f0931e

Browse files
committed
Add support for reducing through ? and ifelse
1 parent 5ba0d18 commit 8f0931e

File tree

5 files changed

+123
-8
lines changed

5 files changed

+123
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.9.11"
4+
version = "0.9.12"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/add_compute.jl

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,13 @@ function add_reduction_update_parent!(
162162
)
163163
var = name(parent)
164164
isouterreduction = parent.instruction === LOOPCONSTANT
165-
instrclass = reduction_instruction_class(instr) # key allows for faster lookups
165+
# @show instr, vparents, parent, reduction_ind
166+
if instr.instr === :ifelse
167+
@assert length(vparents) == 2
168+
instrclass = reduction_instruction_class(instruction(vparents[2])) # key allows for faster lookups
169+
else
170+
instrclass = reduction_instruction_class(instr) # key allows for faster lookups
171+
end
166172
reduct_zero = reduction_zero(instrclass)
167173
# if parent is not an outer reduction...
168174
# if !isouterreduction && !isreductzero(parent, ls, reduct_zero)
@@ -295,14 +301,48 @@ function add_compute!(
295301
end
296302

297303
function add_compute!(
298-
ls::LoopSet, LHS::Symbol, instr, vparents::Vector{Operation}, elementbytes
304+
ls::LoopSet, LHS::Symbol, instr, vparents::Vector{Operation}, elementbytes::Int
299305
)
300306
deps = Symbol[]
301307
reduceddeps = Symbol[]
302-
foreach(parent -> update_deps!(deps, reduceddeps, parent), vparents)
308+
for parent ∈ vparents
309+
update_deps!(deps, reduceddeps, parent)
310+
end
303311
op = Operation(length(operations(ls)), LHS, elementbytes, instr, compute, deps, reduceddeps, vparents)
304312
pushop!(ls, op, LHS)
305313
end
314+
# checks for reductions
315+
function add_compute_ifelse!(
316+
ls::LoopSet, LHS::Symbol, cond::Operation, iftrue::Operation, iffalse::Operation, elementbytes::Int
317+
)
318+
deps = Symbol[]
319+
reduceddeps = Symbol[]
320+
update_deps!(deps, reduceddeps, cond)
321+
update_deps!(deps, reduceddeps, iftrue)
322+
update_deps!(deps, reduceddeps, iffalse)
323+
if name(iftrue) === LHS
324+
if name(iffalse) === LHS # a = ifelse(condition, a, a) # -- why??? Let's just eliminate it.
325+
return iftrue
326+
end
327+
vparents = Operation[cond, iffalse]
328+
setdiffv!(reduceddeps, deps, loopdependencies(iftrue))
329+
add_reduction_update_parent!(
330+
vparents, deps, reduceddeps, ls,
331+
iftrue, Instruction(:LoopVectorization,:ifelse), 2, elementbytes
332+
)
333+
elseif name(iffalse) === LHS
334+
vparents = Operation[cond, iftrue]
335+
setdiffv!(reduceddeps, deps, loopdependencies(iffalse))
336+
add_reduction_update_parent!(
337+
vparents, deps, reduceddeps, ls,
338+
iffalse, Instruction(:LoopVectorization,:ifelse), 3, elementbytes
339+
)
340+
else
341+
vparents = Operation[cond, iftrue, iffalse]
342+
op = Operation(length(operations(ls)), LHS, elementbytes, :ifelse, compute, deps, reduceddeps, vparents)
343+
pushop!(ls, op, LHS)
344+
end
345+
end
306346

307347
# adds x ^ (p::Real)
308348
function add_pow!(

src/add_ifelse.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
3535
else
3636
falseop = getop(ls, iffalse, elementbytes)
3737
end
38-
add_compute!(ls, LHS, :ifelse, [condop, trueop, falseop], elementbytes)
38+
add_compute_ifelse!(ls, LHS, condop, trueop, falseop, elementbytes)
3939
end
4040

4141
function add_andblock!(ls::LoopSet, condop::Operation, LHS, rhsop::Operation, elementbytes::Int, position::Int)
4242
if LHS isa Symbol
4343
altop = getop(ls, LHS, elementbytes)
44-
return add_compute!(ls, LHS, :ifelse, [condop, rhsop, altop], elementbytes)
44+
return add_compute_ifelse!(ls, LHS, condop, rhsop, altop, elementbytes)
4545
elseif LHS isa Expr && LHS.head === :ref
4646
return add_conditional_store!(ls, LHS, condop, rhsop, elementbytes)
4747
else
@@ -81,7 +81,7 @@ function add_orblock!(ls::LoopSet, condop::Operation, LHS, rhsop::Operation, ele
8181
# return add_compute!(ls, LHS, :ifelse, [condop, altop, rhsop], elementbytes)
8282
# Placing altop second seems to let LLVM fuse operations; but as of LLVM 9.0.1 it will not if altop is first
8383
# therefore, we negate the condition and switch order so that the altop is second.
84-
return add_compute!(ls, LHS, :ifelse, [negatedcondop, rhsop, altop], elementbytes)
84+
return add_compute_ifelse!(ls, LHS, negatedcondop, rhsop, altop, elementbytes)
8585
elseif LHS isa Expr && LHS.head === :ref
8686
# negatedcondop = add_compute!(ls, gensym(:negated_mask), :~, [condop], elementbytes)
8787
return add_conditional_store!(ls, LHS, negatedcondop, rhsop, elementbytes)

src/lowering.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,13 +892,16 @@ function isunrolled_sym(op::Operation, u₁loop::Symbol, u₂loop::Symbol)
892892
u₂ild = u₂loop ∈ reducedchildren(op)
893893
end
894894
end
895+
# @show op u₁ild, u₂ild
895896
(u₁ild & u₂ild) || return u₁ild, u₂ild
896897
reductops = isconstant(op) ? reducedchildren(op) : reduceddependencies(op)
898+
# @show op reductops
897899
iszero(length(reductops)) && return true, true
898900
u₁reduced = u₁loop ∈ reductops
899901
u₂reduced = u₂loop ∈ reductops
900902
# We want to only unroll one of them.
901903
# We prefer not to unroll a reduced loop
904+
# @show u₁reduced, u₂reduced
902905
if u₂reduced # if both are reduced, we unroll u₁
903906
true, false
904907
elseif u₁reduced

test/ifelsemasks.jl

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,72 @@ T = Float32
360360
f[j, d] = _x
361361
end
362362
end
363+
364+
function barycentric_weight0(X)
365+
T = eltype(X)
366+
n = length(X) - 1
367+
w = zero(X)
368+
@inbounds @fastmath for j in 0:n
369+
tmp = one(T)
370+
for k in 0:n
371+
tmp = k==j ? tmp : tmp * (X[j+1] - X[k+1])
372+
end
373+
w[j+1] = inv(tmp)
374+
end
375+
return w
376+
end
377+
function barycentric_weight1(X)
378+
T = eltype(X)
379+
n = length(X) - 1
380+
w = zero(X)
381+
@avx for j in 0:n
382+
tmp = one(T)
383+
for k in 0:n
384+
tmp = k != j ? tmp * (X[j+1] - X[k+1]) : tmp
385+
end
386+
w[j+1] = inv(tmp)
387+
end
388+
return w
389+
end
390+
function barycentric_weight2(X)
391+
T = eltype(X)
392+
n = length(X) - 1
393+
w = zero(X)
394+
@avx inline=true for j in 0:n
395+
tmp = one(T)
396+
for k in 0:n
397+
tmp = k==j ? tmp : tmp * (X[j+1] - X[k+1])
398+
end
399+
w[j+1] = inv(tmp)
400+
end
401+
return w
402+
end
403+
function barycentric_weight3(X)
404+
T = eltype(X)
405+
n = length(X) - 1
406+
w = zero(X)
407+
@avx inline=true for j in 0:n
408+
tmp = one(T)
409+
for k in 0:n
410+
tmp = ifelse(k != j, tmp * (X[j+1] - X[k+1]), tmp)
411+
end
412+
w[j+1] = inv(tmp)
413+
end
414+
return w
415+
end
416+
function barycentric_weight4(X)
417+
T = eltype(X)
418+
n = length(X) - 1
419+
w = zero(X)
420+
@avx for j in 0:n
421+
tmp = one(T)
422+
for k in 0:n
423+
tmp = ifelse(k == j, tmp, tmp * (X[j+1] - X[k+1]))
424+
end
425+
w[j+1] = inv(tmp)
426+
end
427+
return w
428+
end
363429

364430
N = 117
365431
for T ∈ (Float32, Float64, Int32, Int64)
@@ -529,5 +595,11 @@ T = Float32
529595
# fc2 = copy(f);
530596
testfunctionavx!(f, v, d, g, s, θ)
531597
@test f ≈ fc
532-
598+
599+
X = rand(4, 5)
600+
bX = barycentric_weight0(X);
601+
@test barycentric_weight1(X) ≈ bX
602+
@test barycentric_weight2(X) ≈ bX
603+
@test barycentric_weight3(X) ≈ bX
604+
@test barycentric_weight4(X) ≈ bX
533605
end

0 commit comments

Comments
 (0)