Skip to content

Commit 2fdb50f

Browse files
committed
Add loopdeps and reduced children to new parents when removing ifelse.
1 parent 876ac51 commit 2fdb50f

File tree

7 files changed

+55
-22
lines changed

7 files changed

+55
-22
lines changed

src/codegen/lower_constant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ function lower_constant!(
103103
!opu₂ && suffix > 0 && return
104104
instr = instruction(op)
105105
instr.mod === GLOBALCONSTANT && return
106-
constsym = instr.instr
106+
constsym = constantopname(op)# instr.instr
107107
reducedchildvectorized = vloopsym ∈ reducedchildren(op)
108108
if reducedchildvectorized || isvectorized(op) || vloopsym ∈ reduceddependencies(op) || should_broadcast_op(op)
109109
# call = Expr(:call, lv(:vbroadcast), W, Expr(:call, lv(:maybeconvert), typeT, constsym))

src/condense_loopset.jl

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -483,12 +483,26 @@ function split_ifelse!(
483483
falseops = operations(lsfalse)
484484
true_op = parents(true_ops[k])[2]
485485
falseop = parents_op[3]
486+
true_op.dependencies = loopdependencies(op)
487+
falseop.dependencies = loopdependencies(op)
488+
true_op.reduced_children = reducedchildren(op)
489+
falseop.reduced_children = reducedchildren(op)
486490
condop_count = 0
487491
for i ∈ eachindex(falseops)
488492
fop = falseops[i]
489493
parents_false = parents(fop)
490494
for (j,opp) ∈ enumerate(parents_false)
491495
if opp === op # then ops[i]'s jth parent is the ifelse
496+
# These reduction to scalar instructions are added for non-outer reductions initialized with non-constant ops
497+
# So we check if now
498+
# if (j == 2) && (Base.sym_in(instruction(fop).instr, (:reduced_add, :reduced_prod, :reduced_max, :reduced_min, :reduced_all, :reduced_any)))
499+
# if isconstantop(true_op)
500+
# (true_ops[i]).instruction = Instruction(:identity)
501+
# end
502+
# if isconstantop(falseop)
503+
# fop.instruction = Instruction(:identity)
504+
# end
505+
# end
492506
parents(true_ops[i])[j] = true_op
493507
parents_false[j] = falseop
494508
end
@@ -513,12 +527,10 @@ end
513527
function generate_call_split(
514528
ls::LoopSet, preserve::Vector{Symbol}, shouldindbyind::Vector{Bool}, roots::Vector{Bool}, extra_args::Expr, inlineu₁u₂::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool
515529
)
516-
if !debug
517-
for (k,op) ∈ enumerate(operations(ls))
518-
parents_op = parents(op)
519-
if (iscompute(op) && (instruction(op).instr === :ifelse)) && (length(parents_op) == 3) && isconstantop(first(parents_op))
520-
return split_ifelse!(ls, preserve, shouldindbyind, roots, extra_args, k, inlineu₁u₂, thread, debug)
521-
end
530+
for (k,op) ∈ enumerate(operations(ls))
531+
parents_op = parents(op)
532+
if (iscompute(op) && (instruction(op).instr === :ifelse)) && (length(parents_op) == 3) && isconstantop(first(parents_op))
533+
return split_ifelse!(ls, preserve, shouldindbyind, roots, extra_args, k, inlineu₁u₂, thread, debug)
522534
end
523535
end
524536
return generate_call_types(ls, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug)

src/modeling/costs.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,23 @@ function reduction_zero(x::Float64)
474474
throw("Reduction not found.")
475475
end
476476
end
477+
# Map the name of an identity-element constructor to a reduction-class constant.
#
# `x` is compared by identity (`===`) against `:one`, `:typemin`, `:typemax`,
# `:max_mask`, `:zero_mask` and `:zero`; the matching branch yields
# `MULTIPLICATIVE_IN_REDUCTIONS`, `MAX`, `MIN`, `ALL`, `ANY` or
# `ADDITIVE_IN_REDUCTIONS` respectively. The `::Float64` return annotation
# converts the selected constant to `Float64`.
# Any other symbol throws the string "Reduction not found.".
# NOTE(review): `:typemin` maps to `MAX` and `:typemax` to `MIN`, presumably because
# those are the starting values of max/min reductions — confirm against the constants.
function reduction_zero_class(x::Symbol)::Float64
478+
if x === :one
479+
MULTIPLICATIVE_IN_REDUCTIONS
480+
elseif x === :typemin
481+
MAX
482+
elseif x === :typemax
483+
MIN
484+
elseif x === :max_mask
485+
ALL
486+
elseif x === :zero_mask
487+
ANY
488+
elseif x === :zero#sorted last, as should go into preamble_zeros
489+
ADDITIVE_IN_REDUCTIONS
490+
else
491+
# No known class for this symbol: signal it by throwing a plain string.
throw("Reduction not found.")
492+
end
493+
end
477494
reduction_zero(x) = reduction_zero(reduction_instruction_class(x))
478495

479496
function isreductcombineinstr(instr::Symbol)

src/modeling/graphs.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,13 +1028,13 @@ function add_operation!(
10281028
f = first(RHS.args)
10291029
if f === :getindex
10301030
add_load_getindex!(ls, LHS, RHS, elementbytes)
1031-
elseif f === :zero || f === :one
1031+
elseif f isa Symbol && Base.sym_in(f, (:zero, :one, :typemin, :typemax))
10321032
c = gensym!(ls, f)
10331033
op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS, elementbytes, :numericconstant)
10341034
if f === :zero
10351035
push!(ls.preamble_zeros, (identifier(op), IntOrFloat))
10361036
else
1037-
push!(ls.preamble_funcofeltypes, (identifier(op), MULTIPLICATIVE_IN_REDUCTIONS))
1037+
push!(ls.preamble_funcofeltypes, (identifier(op), reduction_zero_class(f)))
10381038
end
10391039
op
10401040
else
@@ -1070,14 +1070,14 @@ function add_operation!(
10701070
f = first(RHS.args)
10711071
if f === :getindex
10721072
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
1073-
elseif f === :zero || f === :one
1073+
elseif f isa Symbol && Base.sym_in(f, (:zero, :one, :typemin, :typemax))
10741074
c = gensym!(ls, f)
10751075
op = add_constant!(ls, c, ls.loopsymbols[1:position], LHS_sym, elementbytes, :numericconstant)
10761076
# op = add_constant!(ls, c, Symbol[], LHS_sym, elementbytes, :numericconstant)
10771077
if f === :zero
10781078
push!(ls.preamble_zeros, (identifier(op), IntOrFloat))
10791079
else
1080-
push!(ls.preamble_funcofeltypes, (identifier(op), MULTIPLICATIVE_IN_REDUCTIONS))
1080+
push!(ls.preamble_funcofeltypes, (identifier(op), reduction_zero_class(f)))
10811081
end
10821082
op
10831083
else

test/offsetarrays.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,19 @@ using LoopVectorization: Static
196196
end
197197
out
198198
end
199-
function avxgeneric2!(out, A, kern)
200-
@avx for I in CartesianIndices(out)
201-
tmp = zero(eltype(out))
202-
for J in CartesianIndices(kern)
203-
tmp += A[I+J]*kern[J]
204-
end
205-
out[I] = tmp
199+
function avxgeneric2!(out, A, kern, keep = nothing)
200+
@avx for I in CartesianIndices(out)
201+
tmp = if keep === nothing
202+
zero(eltype(out))
203+
else
204+
out[I]
206205
end
207-
out
206+
for J in CartesianIndices(kern)
207+
tmp += A[I+J]*kern[J]
208+
end
209+
out[I] = tmp
210+
end
211+
out
208212
end
209213
function pparent(x) # go through nested parents
210214
px = parent(x)

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
include("testsetup.jl")
22

3-
import InteractiveUtils
3+
import InteractiveUtils, Aqua
44

55
InteractiveUtils.versioninfo(stdout; verbose = true)
66

@@ -12,7 +12,7 @@ const START_TIME = time()
1212

1313
@time @testset "LoopVectorization.jl" begin
1414

15-
# @time Aqua.test_all(LoopVectorization)
15+
@time Aqua.test_all(LoopVectorization)
1616
# @test isempty(detect_unbound_args(LoopVectorization))
1717

1818
@time include("printmethods.jl")

test/testsetup.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using Test, Aqua
1+
using Test
22
using LoopVectorization
33

44
using LinearAlgebra

0 commit comments

Comments
 (0)