Skip to content

Commit a73a101

Browse files
committed
Implemented better loop outlining. Should implement an even better version later...
1 parent 54fbcb6 commit a73a101

File tree

9 files changed

+97
-33
lines changed

9 files changed

+97
-33
lines changed

src/add_compute.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,11 @@ function add_reduction_update_parent!(
106106
elseif reduct_zero === :one
107107
push!(ls.preamble_ones, identifier(reductinit))
108108
else
109-
pushpreamble!(ls, Expr(:(=), name(reductinit), reductzero))
109+
if reductzero === :true || reductzero === :false
110+
pushpreamble!(ls, Expr(:(=), name(reductinit), reductzero))
111+
else
112+
pushpreamble!(ls, Expr(:(=), name(reductinit), Expr(:call, reductzero, ls.T)))
113+
end
110114
pushpreamble!(ls, op, name, reductinit)
111115
end
112116
if isconstant(parent) && reduct_zero === parent.instruction.mod #we can use parent op as initialization.
@@ -166,7 +170,7 @@ function add_compute!(
166170
elseif arg ∈ ls.loopsymbols
167171
loopsym = gensym(arg)
168172
pushpreamble!(ls, Expr(:(=), loopsym, LoopValue()))
169-
loopsymop = add_simple_load!(ls, gensym(loopsym), ArrayReference(loopsym, [arg]), elementbytes)
173+
loopsymop = add_simple_load!(ls, gensym(loopsym), ArrayReference(loopsym, [arg]), elementbytes, false)
170174
push!(ls.syms_aliasing_refs, name(loopsymop))
171175
push!(ls.refs_aliasing_syms, loopsymop.ref)
172176
pushparent!(parents, deps, reduceddeps, loopsymop)

src/add_loads.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ end
2727

2828
# for use with broadcasting
2929
function add_simple_load!(
30-
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int
30+
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int, actualarray::Bool = true
3131
)
3232
loopdeps = Symbol[s for s ∈ ref.indices]
3333
mref = ArrayReferenceMeta(
@@ -38,7 +38,7 @@ function add_simple_load!(
3838
:getindex, memload, loopdeps,
3939
NODEPENDENCY, NOPARENTS, mref
4040
)
41-
add_vptr!(ls, op)
41+
add_vptr!(ls, op.ref.ref.array, vptr(op.ref), actualarray)
4242
pushop!(ls, op, var)
4343
end
4444
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int)

src/condense_loopset.jl

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,41 @@ function generate_call(ls::LoopSet, IUT)
203203
q
204204
end
205205

206-
function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
206+
function setup_call_noinline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
207+
call = generate_call(ls, (inline,U,T))
208+
hasouterreductions = length(ls.outer_reductions) > 0
209+
q = ls.preamble
210+
if hasouterreductions
211+
outer_reducts = Expr(:local)
212+
for or ∈ ls.outer_reductions
213+
op = ls.operations[or]
214+
var = name(op)
215+
mvar = mangledvar(op)
216+
out = Symbol(mvar, 0)
217+
push!(outer_reducts.args, out)
218+
end
219+
push!(q.args, outer_reducts)
220+
retv = loopset_return_value(ls, Val(false))
221+
call = Expr(:(=), retv, call)
222+
push!(q.args, gc_preserve(ls, call))
223+
push!(q.args, Expr(:return, retv))
224+
q = Expr(:block, Expr(:(=), retv, Expr(:call, Expr(:(->), Expr(:tuple, ls.includedactualarrays...), q), ls.includedactualarrays...)))
225+
for or ∈ ls.outer_reductions
226+
op = ls.operations[or]
227+
var = name(op)
228+
mvar = mangledvar(op)
229+
instr = instruction(op)
230+
out = Symbol(mvar, 0)
231+
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), out, var)))
232+
end
233+
else
234+
push!(q.args, gc_preserve(ls, call))
235+
push!(q.args, Expr(:return, :nothing))
236+
q = Expr(:call, Expr(:(->), Expr(:tuple, ls.includedactualarrays...), q), ls.includedactualarrays...)
237+
end
238+
q
239+
end
240+
function setup_call_inline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
207241
call = generate_call(ls, (inline,U,T))
208242
hasouterreductions = length(ls.outer_reductions) > 0
209243
if hasouterreductions
@@ -219,12 +253,22 @@ function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8
219253
instr = instruction(op)
220254
out = Symbol(mvar, 0)
221255
push!(outer_reducts.args, out)
222-
# push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Expr(:call, lv(:SVec), out), var)))
223256
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), out, var)))
224257
end
225258
hasouterreductions && pushpreamble!(ls, outer_reducts)
226259
append!(ls.preamble.args, q.args)
227260
ls.preamble
228261
end
229-
262+
function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
263+
# We outline/inline at the macro level by creating/not creating an anonymous function.
264+
# The old API was instead based on inlining or not inlining the generated function, but
265+
# the generated function must be inlined into the initial loop preamble for performance reasons.
266+
# Creating an anonymous function and calling it also achieves the outlining, while still
267+
# inlining the generated function into the loop preamble.
268+
if inline == Int8(2)
269+
setup_call_inline(ls, Int8(2), U, T)
270+
else
271+
setup_call_noinline(ls, Int8(2), U, T)
272+
end
273+
end
230274

src/constructors.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,15 @@ function substitute_broadcast(q::Expr, mod::Symbol)
4141
ex
4242
end
4343

44+
4445
function LoopSet(q::Expr, mod::Symbol = :LoopVectorization)
4546
q = SIMDPirates.contract_pass(q)
4647
ls = LoopSet(mod)
4748
copyto!(ls, q)
4849
resize!(ls.loop_order, num_loops(ls))
4950
ls
5051
end
51-
52+
LoopSet(q::Expr, m::Module) = LoopSet(macroexpand(m, q), Symbol(m))
5253

5354
"""
5455
@avx
@@ -84,11 +85,10 @@ true
8485
8586
"""
8687
macro avx(q)
87-
mod = Symbol(__module__)
8888
q2 = if q.head === :for
89-
setup_call(LoopSet(q, mod))
89+
setup_call(LoopSet(q, __module__))
9090
else# assume broadcast
91-
substitute_broadcast(q, mod)
91+
substitute_broadcast(q, Symbol(__module__))
9292
end
9393
esc(q2)
9494
end
@@ -130,24 +130,24 @@ macro avx(arg, q)
130130
@assert q.head === :for
131131
@assert arg.head === :(=)
132132
inline, U, T = check_macro_kwarg(arg)
133-
esc(setup_call(LoopSet(q, Symbol(__module__)), inline, U, T))
133+
esc(setup_call(LoopSet(q, __module__), inline, U, T))
134134
end
135135
macro avx(arg1, arg2, q)
136136
@assert q.head === :for
137137
inline, U, T = check_macro_kwarg(arg1)
138138
inline, U, T = check_macro_kwarg(arg2, inline, U, T)
139-
esc(setup_call(LoopSet(q, Symbol(__module__)), inline, U, T))
139+
esc(setup_call(LoopSet(q, __module__), inline, U, T))
140140
end
141141

142142

143143

144144
macro _avx(q)
145-
esc(lower(LoopSet(q, Symbol(__module__))))
145+
esc(lower(LoopSet(q, __module__)))
146146
end
147147
macro _avx(arg, q)
148148
@assert q.head === :for
149149
inline, U, T = check_macro_kwarg(arg)
150-
esc(lower(LoopSet(q, Symbol(__module__)), U, T))
150+
esc(lower(LoopSet(q, __module__), U, T))
151151
end
152152

153153

src/costs.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ const COST = Dict{Instruction,InstructionCost}(
131131
Instruction(:>>) => InstructionCost(1, 0.5),
132132
Instruction(:>>>) => InstructionCost(1, 0.5),
133133
Instruction(:<<) => InstructionCost(1, 0.5),
134+
Instruction(:max) => InstructionCost(4,0.5),
135+
Instruction(:min) => InstructionCost(4,0.5),
134136
Instruction(:ifelse) => InstructionCost(1, 0.5),
135137
Instruction(:vifelse) => InstructionCost(1, 0.5),
136138
Instruction(:inv) => InstructionCost(13,4.0,-2.0,1),
@@ -185,6 +187,8 @@ const ADDITIVE_IN_REDUCTIONS = 1.0
185187
const MULTIPLICATIVE_IN_REDUCTIONS = 2.0
186188
const ANY = 3.0
187189
const ALL = 4.0
190+
const MAX = 5.0
191+
const MIN = 6.0
188192

189193
const REDUCTION_CLASS = Dict{Symbol,Float64}(
190194
:+ => ADDITIVE_IN_REDUCTIONS,
@@ -213,28 +217,30 @@ const REDUCTION_CLASS = Dict{Symbol,Float64}(
213217
:reduced_add => ADDITIVE_IN_REDUCTIONS,
214218
:reduced_prod => MULTIPLICATIVE_IN_REDUCTIONS,
215219
:reduced_all => ALL,
216-
:reduced_any => ANY
220+
:reduced_any => ANY,
221+
:max => MAX,
222+
:min => MIN
217223
)
218224
reduction_instruction_class(instr::Symbol) = get(REDUCTION_CLASS, instr, NaN)
219225
reduction_instruction_class(instr::Instruction) = get(REDUCTION_CLASS, instr.instr, NaN)
220226
function reduction_to_single_vector(x::Float64)
221-
x == 1.0 ? :evadd : x == 2.0 ? :evmul : x == 3.0 ? :vand : x == 4.0 ? :vor : throw("Reduction not found.")
227+
x == 1.0 ? :evadd : x == 2.0 ? :evmul : x == 3.0 ? :vor : x == 4.0 ? :vand : x == 5.0 ? :max : x == 6.0 ? :min : throw("Reduction not found.")
222228
end
223229
reduction_to_single_vector(x) = reduction_to_single_vector(reduction_instruction_class(x))
224230
function reduction_to_scalar(x::Float64)
225-
x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 3.0 ? :vany : x == 4.0 ? :vall : throw("Reduction not found.")
231+
x == 1.0 ? :vsum : x == 2.0 ? :vprod : x == 3.0 ? :vany : x == 4.0 ? :vall : x == 5.0 ? :maximum : x == 6.0 ? :minimum : throw("Reduction not found.")
226232
end
227233
reduction_to_scalar(x) = reduction_to_scalar(reduction_instruction_class(x))
228234
function reduction_scalar_combine(x::Float64)
229-
x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 3.0 ? :reduced_any : x == 4.0 ? :reduced_all : throw("Reduction not found.")
235+
x == 1.0 ? :reduced_add : x == 2.0 ? :reduced_prod : x == 3.0 ? :reduced_any : x == 4.0 ? :reduced_all : x == 5.0 ? :reduced_max : x == 6.0 ? :reduced_min : throw("Reduction not found.")
230236
end
231237
reduction_scalar_combine(x) = reduction_scalar_combine(reduction_instruction_class(x))
232238
function reduction_combine_to(x::Float64)
233-
x == 1.0 ? :reduce_to_add : x == 2.0 ? :reduce_to_prod : x == 3.0 ? :reduce_to_any : x == 4.0 ? :reduce_to_all : throw("Reduction not found.")
239+
x == 1.0 ? :reduce_to_add : x == 2.0 ? :reduce_to_prod : x == 3.0 ? :reduce_to_any : x == 4.0 ? :reduce_to_all : x == 5.0 ? :reduce_to_max : x == 6.0 ? :reduce_to_min : throw("Reduction not found.")
234240
end
235241
reduction_combine_to(x) = reduction_combine_to(reduction_instruction_class(x))
236242
function reduction_zero(x::Float64)
237-
x == 1.0 ? :zero : x == 2.0 ? :one : x == 3.0 ? :false : x == 4.0 ? :true : throw("Reduction not found.")
243+
x == 1.0 ? :zero : x == 2.0 ? :one : x == 3.0 ? :false : x == 4.0 ? :true : x == 5.0 ? :typemin : x == 6.0 ? :typemax : throw("Reduction not found.")
238244
end
239245
reduction_zero(x) = reduction_zero(reduction_instruction_class(x))
240246

@@ -291,6 +297,11 @@ const FUNCTIONSYMBOLS = Dict{Type{<:Function},Instruction}(
291297
typeof(SLEEFPirates.cos) => :cos,
292298
typeof(sincos) => :sincos,
293299
typeof(Base.FastMath.sincos_fast) => :sincos,
294-
typeof(SLEEFPirates.sincos) => :sincos
300+
typeof(SLEEFPirates.sincos) => :sincos,
301+
typeof(max) => :max,
302+
typeof(min) => :min,
303+
typeof(<<) => :<<,
304+
typeof(>>) => :>>,
305+
typeof(>>>) => :>>>
295306
)
296307

src/graphs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ struct LoopSet
157157
preamble_zeros::Vector{Int}
158158
preamble_ones::Vector{Int}
159159
includedarrays::Vector{Symbol}
160+
includedactualarrays::Vector{Symbol}
160161
syms_aliasing_refs::Vector{Symbol}
161162
refs_aliasing_syms::Vector{ArrayReferenceMeta}
162163
cost_vec::Matrix{Float64}
@@ -228,8 +229,7 @@ function LoopSet(mod::Symbol)# = :LoopVectorization)
228229
Tuple{Int,Int}[],
229230
Tuple{Int,Float64}[],
230231
Int[],Int[],
231-
Tuple{Symbol,Int}[],
232-
Symbol[],
232+
Symbol[], Symbol[], Symbol[],
233233
ArrayReferenceMeta[],
234234
Matrix{Float64}(undef, 4, 2),
235235
Matrix{Int}(undef, 4, 2),

src/lowering.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,22 +204,23 @@ function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
204204
end
205205
end
206206
function gc_preserve(ls::LoopSet, q::Expr)
207-
length(ls.includedarrays) == 0 && return q
207+
length(ls.includedactualarrays) == 0 && return q
208208
gcp = Expr(:macrocall, Expr(:(.), :GC, QuoteNode(Symbol("@preserve"))), LineNumberNode(@__LINE__, @__FILE__))
209-
for array ∈ ls.includedarrays
209+
for array ∈ ls.includedactualarrays
210210
push!(gcp.args, array)
211211
end
212212
q.head === :block && push!(q.args, nothing)
213213
push!(gcp.args, q)
214214
Expr(:block, gcp)
215215
end
216216
function determine_eltype(ls::LoopSet)
217-
# length(ls.includedarrays) == 0 && return REGISTER_SIZE >>> 3
218-
if length(ls.includedarrays) == 1
219-
return Expr(:call, :eltype, first(ls.includedarrays))
217+
if length(ls.includedactualarrays) == 0
218+
return Expr(:call, :typeof, 0)
219+
elseif length(ls.includedactualarrays) == 1
220+
return Expr(:call, :eltype, first(ls.includedactualarrays))
220221
end
221222
promote_q = Expr(:call, :promote_type)
222-
for array ∈ ls.includedarrays
223+
for array ∈ ls.includedactualarrays
223224
push!(promote_q.args, Expr(:call, :eltype, array))
224225
end
225226
promote_q

src/memory_ops_common.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
add_vptr!(ls::LoopSet, op::Operation) = add_vptr!(ls, op.ref)
22
add_vptr!(ls::LoopSet, mref::ArrayReferenceMeta) = add_vptr!(ls, mref.ref.array, vptr(mref))
3-
function add_vptr!(ls::LoopSet, array::Symbol, vptrarray::Symbol = vptr(array))
3+
function add_vptr!(ls::LoopSet, array::Symbol, vptrarray::Symbol = vptr(array), actualarray::Bool = true)
44
if !includesarray(ls, array)
55
push!(ls.includedarrays, array)
6+
actualarray && push!(ls.includedactualarrays, array)
67
pushpreamble!(ls, Expr(:(=), vptrarray, Expr(:call, lv(:stridedpointer), array)))
78
end
89
nothing

test/runtests.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ end
9494
end
9595
function dot_unroll2avx_noinline(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
9696
z = zero(T)
97-
@avx inline=true unroll=2 for i ∈ 1:length(x)
97+
@avx inline=false unroll=2 for i ∈ 1:length(x)
9898
z += x[i]*y[i]
9999
end
100100
return z
101101
end
102102
function dot_unroll3avx_inline(x::Vector{T}, y::Vector{T}) where {T<:AbstractFloat}
103103
z = zero(T)
104-
@avx unroll=3 inline=false for i ∈ 1:length(x)
104+
@avx unroll=3 inline=true for i ∈ 1:length(x)
105105
z += x[i]*y[i]
106106
end
107107
return z
@@ -245,6 +245,9 @@ end
245245
res[i] = sin(i * code_phase_delta)
246246
end
247247
end
248+
@macroexpand @avx for i ∈ eachindex(res)
249+
res[i] = sin(i * code_phase_delta)
250+
end
248251
function calc_sins_avx!(res::AbstractArray{T}) where {T}
249252
code_phase_delta = T(0.01)
250253
@_avx for i ∈ eachindex(res)

0 commit comments

Comments
 (0)