Skip to content

Commit a311e1a

Browse files
committed
Interpret lonely loads in if/else statements as being conditional loads.
1 parent d238ab2 commit a311e1a

File tree

4 files changed

+72
-35
lines changed

4 files changed

+72
-35
lines changed

src/add_ifelse.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
## Currently, if/else will create its own local scope
33
## Assignments will not register in the loop's main scope
44
## although stores and return values will.
5-
5+
negateop!(ls::LoopSet, condop::Operation, elementbytes::Int) = add_compute!(ls, gensym(:negated_mask), :~, [condop], elementbytes)
66

77
function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, position::Int, mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing)
88
# for now, just simple 1-liners
@@ -14,18 +14,24 @@ function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int, positio
1414
add_operation!(ls, gensym(:mask), condition, mpref, elementbytes, position)
1515
end
1616
iftrue = RHS.args[2]
17-
trueop = if iftrue isa Expr
18-
(iftrue isa Expr && iftrue.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
19-
add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
17+
if iftrue isa Expr
18+
trueop = add_operation!(ls, Symbol(:iftrue), iftrue, elementbytes, position)
19+
if iftrue.head === :ref && all(ld -> ld ∈ loopdependencies(trueop), loopdependencies(condop))
20+
trueop.instruction = Instruction(:conditionalload)
21+
push!(parents(trueop), condop)
22+
end
2023
else
21-
getop(ls, iftrue, elementbytes)
24+
trueop = getop(ls, iftrue, elementbytes)
2225
end
2326
iffalse = RHS.args[3]
24-
falseop = if iffalse isa Expr
25-
(iffalse isa Expr && iffalse.head !== :call) && throw("Only calls or constant expressions are currently supported in if/else blocks.")
26-
add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
27+
if iffalse isa Expr
28+
falseop = add_operation!(ls, Symbol(:iffalse), iffalse, elementbytes, position)
29+
if iffalse.head === :ref && all(ld -> ld ∈ loopdependencies(falseop), loopdependencies(condop))
30+
falseop.instruction = Instruction(:conditionalload)
31+
push!(parents(falseop), negateop!(ls, condop, elementbytes))
32+
end
2733
else
28-
getop(ls, iffalse, elementbytes)
34+
falseop = getop(ls, iffalse, elementbytes)
2935
end
3036
add_compute!(ls, LHS, :vifelse, [condop, trueop, falseop], elementbytes)
3137
end
@@ -67,7 +73,7 @@ function add_andblock!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
6773
end
6874

6975
function add_orblock!(ls::LoopSet, condop::Operation, LHS, rhsop::Operation, elementbytes::Int, position::Int)
70-
negatedcondop = add_compute!(ls, gensym(:negated_mask), :~, [condop], elementbytes)
76+
negatedcondop = negateop!(ls, condop, elementbytes)
7177
if LHS isa Symbol
7278
altop = getop(ls, LHS, elementbytes)
7379
# return add_compute!(ls, LHS, :vifelse, [condop, altop, rhsop], elementbytes)

src/lower_load.jl

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,63 @@
11
function lower_load_scalar!(
2-
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
3-
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing, umin::Int = 0
2+
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol,
3+
tiled::Symbol, U::Int, suffix::Union{Nothing,Int}, umin::Int = 0
44
)
55
loopdeps = loopdependencies(op)
66
@assert vectorized ∉ loopdeps
77
var = variable_name(op, suffix)
88
ptr = refname(op)
99
isunrolled = unrolled ∈ loopdeps
1010
U = isunrolled ? U : 1
11-
for u ∈ umin:U-1
12-
varname = varassignname(var, u, isunrolled)
13-
td = UnrollArgs(u, unrolled, tiled, suffix)
14-
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))))
11+
if instruction(op).instr !== :conditionalload
12+
for u ∈ umin:U-1
13+
varname = varassignname(var, u, isunrolled)
14+
td = UnrollArgs(u, unrolled, tiled, suffix)
15+
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))))
16+
end
17+
else
18+
condop = last(parents(op))
19+
condvar = variable_name(condop, suffix)
20+
condunrolled = any(isequal(unrolled), loopdependencies(condop))
21+
for u ∈ umin:U-1
22+
condsym = condunrolled ? Symbol(condvar, u) : condvar
23+
varname = varassignname(var, u, isunrolled)
24+
td = UnrollArgs(u, unrolled, tiled, suffix)
25+
load = Expr(:call, lv(:vload), ptr, mem_offset_u(op, td))
26+
cload = Expr(:if, condsym, load, Expr(:call, :zero, Expr(:call, :eltype, ptr)))
27+
push!(q.args, Expr(:(=), varname, cload))
28+
end
1529
end
1630
nothing
1731
end
18-
function pushvectorload!(q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, mask, vecnotunrolled::Bool)
19-
@unpack u, unrolled = td
32+
function pushvectorload!(
33+
q::Expr, op::Operation, var::Symbol, td::UnrollArgs, U::Int, W::Symbol, vectorized::Symbol, mask
34+
)
35+
@unpack u, unrolled, suffix = td
2036
ptr = refname(op)
37+
vecnotunrolled = vectorized !== unrolled
2138
name, mo = name_memoffset(var, op, td, W, vecnotunrolled)
2239
instrcall = Expr(:call, lv(:vload), ptr, mo)
23-
if mask !== nothing && (vecnotunrolled || u == U - 1)
40+
41+
iscondstore = instruction(op).instr === :conditionalload
42+
maskend = mask !== nothing && (vecnotunrolled || u == U - 1)
43+
if iscondstore
44+
condop = last(parents(op))
45+
# @show condop
46+
condsym = variable_name(condop, suffix)
47+
condsym = any(isequal(unrolled), loopdependencies(condop)) ? Symbol(condsym, u) : condsym
48+
if vectorized ∈ loopdependencies(condop)
49+
if maskend
50+
push!(instrcall.args, Expr(:call, :&, condsym, mask))
51+
else
52+
push!(instrcall.args, condsym)
53+
end
54+
else
55+
if maskend
56+
push!(instrcall.args, mask)
57+
end
58+
instrcall = Expr(:if, condsym, instrcall, Expr(:call, lv(:vzero), W, Expr(:call, :eltype, ptr)))
59+
end
60+
elseif maskend
2461
push!(instrcall.args, mask)
2562
end
2663
push!(q.args, Expr(:(=), name, instrcall))
@@ -40,10 +77,9 @@ function lower_load_vectorized!(
4077
end
4178
# Urange = unrolled ∈ loopdeps ? 0:U-1 : 0
4279
var = variable_name(op, suffix)
43-
vecnotunrolled = vectorized !== unrolled
4480
for u ∈ umin:U-1
4581
td = UnrollArgs(u, unrolled, tiled, suffix)
46-
pushvectorload!(q, op, var, td, U, W, mask, vecnotunrolled)
82+
pushvectorload!(q, op, var, td, U, W, vectorized, mask)
4783
end
4884
nothing
4985
end
@@ -73,6 +109,6 @@ function lower_load!(
73109
if vectorized ∈ loopdependencies(op)
74110
lower_load_vectorized!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, umin)
75111
else
76-
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, umin)
112+
lower_load_scalar!(q, op, vectorized, W, unrolled, tiled, U, suffix, umin)
77113
end
78114
end

src/lower_store.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ end
4141
# const STOREOP = :vstore!
4242
variable_name(op::Operation, ::Nothing) = mangledvar(op)
4343
variable_name(op::Operation, suffix) = Symbol(mangledvar(op), suffix, :_)
44+
# variable_name(op::Operation, suffix, u::Int) = (n = variable_name(op, suffix); u < 0 ? n : Symbol(n, u))
4445
function reduce_range!(q::Expr, toreduct::Symbol, instr::Instruction, Uh::Int, Uh2::Int)
4546
for u ∈ 0:Uh-1
4647
tru = Symbol(toreduct, u)

test/miscellaneous.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,6 @@ using Test
681681
ifirst, ilast = first(irng), last(irng)
682682
ifirst > ilast && return s
683683
@avx tile=(1,1) for Ipost in Rpost
684-
# Handle all other entries
685684
for Ipre in Rpre
686685
s[Ipre, ifirst, Ipost] = x[Ipre, ifirst, Ipost]
687686
for i = ifirst+1:ilast
@@ -694,15 +693,10 @@ using Test
694693
function smoothdim_ifelse_avx!(s, x, α, Rpre, irng::AbstractUnitRange, Rpost)
695694
ifirst, ilast = first(irng), last(irng)
696695
ifirst > ilast && return s
697-
@avx tile=(1,1) for Ipost in Rpost
698-
# Handle all other entries
699-
for i = ifirst:ilast
700-
for Ipre in Rpre
701-
xi = x[Ipre, i, Ipost]
702-
xim = ifelse(i == ifirst, xi, x[Ipre, i-1, Ipost])
703-
s[Ipre, i, Ipost] = α*xi + (1-α)*xim
704-
end
705-
end
696+
@avx tile=(1,1) for Ipost in Rpost, i = ifirst:ilast, Ipre in Rpre
697+
xi = x[Ipre, i, Ipost]
698+
xim = i > ifirst ? x[Ipre, i-1, Ipost] : xi
699+
s[Ipre, i, Ipost] = α*xi + (1-α)*xim
706700
end
707701
s
708702
end
@@ -714,10 +708,10 @@ using Test
714708
# @show d
715709
Rpre = CartesianIndices(axes(x)[1:d-1]);
716710
Rpost = CartesianIndices(axes(x)[d+1:end]);
717-
smoothdim!(dest1, x, α, Rpre, axes(x, d), Rpost)
718-
smoothdim_avx!(dest2, x, α, Rpre, axes(x, d), Rpost)
711+
smoothdim!(dest1, x, α, Rpre, axes(x, d), Rpost);
712+
smoothdim_avx!(dest2, x, α, Rpre, axes(x, d), Rpost);
719713
@test dest1 ≈ dest2
720-
fill!(dest2, NaN); smoothdim_ifelse_avx!(dest2, x, α, Rpre, axes(x, d), Rpost)
714+
fill!(dest2, NaN); smoothdim_ifelse_avx!(dest2, x, α, Rpre, axes(x, d), Rpost);
721715
@test dest1 ≈ dest2
722716
end
723717
end

0 commit comments

Comments
 (0)