Skip to content

Commit 358d3fc

Browse files
committed
Fixed before/after loop placement, hopefully improved performance through reducing calls/arguments passed to , and tweaked solve_tilesize.
1 parent f55123e commit 358d3fc

File tree

6 files changed

+112
-111
lines changed

6 files changed

+112
-111
lines changed

src/broadcast.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,9 @@ end
134134
function add_broadcast!(
135135
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int = 8
136136
) where {T<:Union{Integer,Float32,Float64}}
137-
pushpreamble!(ls, Expr(:(=), Symbol("##", destname), bcname))
138-
add_constant!(ls, destname, elementbytes) # or replace elementbytes with sizeof(T) ? u
137+
op = add_constant!(ls, destname, elementbytes) # or replace elementbytes with sizeof(T) ? u
138+
pushpreamble!(ls, Expr(:(=), mangledvar(op), bcname))
139+
op
139140
end
140141
function add_broadcast!(
141142
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},

src/determinestrategy.jl

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ end
236236
function solve_tilesize(X, R, Umax, Tmax)
237237
first(R) == 0 && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
238238
U, T, cost = solve_tilesize(X, R)
239-
T -= T & 1
240-
U = min(U, T)
239+
# T -= T & 1
240+
# U = min(U, T)
241241
U_too_large = U > Umax
242242
T_too_large = T > Tmax
243243
if U_too_large
@@ -259,20 +259,21 @@ function solve_tilesize(
259259
cost_vec::AbstractVector{Float64} = @view(ls.cost_vec[:,1]),
260260
reg_pressure::AbstractVector{Int} = @view(ls.reg_pres[:,1])
261261
)
262-
maxT = 8
263-
maxU = 8
262+
maxT = 4#8
263+
maxU = 4#8
264264
if isstaticloop(ls, tiled)
265-
maxT = min(maxT, looprangehint(ls, tiled))
265+
maxT = min(2maxT, looprangehint(ls, tiled))
266266
end
267267
if isstaticloop(ls, unrolled)
268-
maxU = min(maxU, looprangehint(ls, unrolled))
268+
maxU = min(2maxU, looprangehint(ls, unrolled))
269269
end
270270
solve_tilesize(cost_vec, reg_pressure, maxU, maxT)
271271
end
272272

273-
function set_for_each_parent!(adal::Vector{T}, op::Operation, val::T) where {T}
274-
@inbounds for opp parents(op)
275-
adal[identifier(opp)] = val
273+
function set_upstream_family!(adal::Vector{T}, op::Operation, val::T) where {T}
274+
adal[identifier(op)] = val
275+
for opp parents(op)
276+
set_upstream_family!(adal, opp, val)
276277
end
277278
end
278279

@@ -333,7 +334,7 @@ function evaluate_cost_tile(
333334
unrolledtiled[1,id] = unrolled loopdependencies(op)
334335
unrolledtiled[2,id] = tiled loopdependencies(op)
335336
iters[id] = iter
336-
innerloop loopdependencies(op) && set_for_each_parent!(descendentsininnerloop, op, true)
337+
innerloop loopdependencies(op) && set_upstream_family!(descendentsininnerloop, op, true)
337338
end
338339
end
339340
for (id, op) enumerate(ops)

src/graphs.jl

Lines changed: 15 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,10 @@ end
299299
function add_loop!(ls::LoopSet, loop::Loop)
300300
ls.loops[loop.itersym] = loop
301301
end
302-
function add_vptr!(ls::LoopSet, indexed::Symbol, id::Int)
302+
function add_vptr!(ls::LoopSet, indexed::Symbol, id::Int, ptr::Symbol = Symbol("##vptr##_", indexed))
303303
if includesarray(ls, indexed) < 0
304304
push!(ls.includedarrays, (indexed, id))
305-
pushpreamble!(ls, Expr(:(=), Symbol("##vptr##_", indexed), Expr(:call, lv(:stridedpointer), indexed)))
305+
pushpreamble!(ls, Expr(:(=), ptr, Expr(:call, lv(:stridedpointer), indexed)))
306306
end
307307
nothing
308308
end
@@ -339,7 +339,7 @@ function add_load!(
339339
:getindex, memload, loopdependencies(ref, ls),
340340
NODEPENDENCY, NOPARENTS, ref
341341
)
342-
add_vptr!(ls, ref.array, identifier(op))
342+
add_vptr!(ls, ref.array, identifier(op), ref.ptr)
343343
pushop!(ls, op, var)
344344
end
345345
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
@@ -390,8 +390,9 @@ function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
390390
end
391391
function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
392392
sym = gensym(:temp)
393-
pushpreamble!(ls, Expr(:(=), Symbol("##", sym), var))
394-
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS), sym)
393+
op = Operation(length(operations(ls)), sym, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
394+
pushpreamble!(ls, Expr(:(=), mangledvar(op), var))
395+
pushop!(ls, op, sym)
395396
end
396397
# This version has loop dependencies. var gets assigned to sym when lowering.
397398
function add_constant!(ls::LoopSet, var::Symbol, deps::Vector{Symbol}, sym::Symbol = gensym(:constant), elementbytes::Int = 8)
@@ -450,8 +451,9 @@ function add_parent!(
450451
parent = if var isa Symbol
451452
get!(ls.opdict, var) do
452453
# might add constant
453-
pushpreamble!(ls, Expr(:(=), Symbol("##", var), var))
454-
add_constant!(ls, var, elementbytes)
454+
op = add_constant!(ls, var, elementbytes)
455+
pushpreamble!(ls, Expr(:(=), mangledvar(op), var))
456+
op
455457
end
456458
elseif var isa Expr #CSE candidate
457459
maybe_cse_load!(ls, var, elementbytes)
@@ -511,7 +513,7 @@ function add_store!(
511513
)
512514
parent = getop(ls, var)
513515
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, loopdependencies(ref), reduceddependencies(parent), [parent], ref )
514-
add_vptr!(ls, ref.array, identifier(op))
516+
add_vptr!(ls, ref.array, identifier(op), ref.ptr)
515517
pushop!(ls, op, ref.array)
516518
end
517519
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
@@ -603,18 +605,6 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
603605
end
604606
end
605607

606-
function place_after_loop!(adal::Vector{Bool}, op::Operation)
607-
pal = if isload(op) || length(reduceddependencies(op)) == 0
608-
1
609-
elseif length(reduceddependencies(op)) > 1
610-
2
611-
else
612-
rd = first(reduceddependencies(op))
613-
any(d -> d === rd, loopdependencies(op)) ? 1 : 2
614-
end
615-
pal == 1 && set_for_each_parent!(adal, op, false)
616-
pal
617-
end
618608

619609
function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
620610
lo = ls.loop_order
@@ -633,9 +623,9 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
633623
ops = operations(ls)
634624
nops = length(ops)
635625
included_vars = fill(false, nops)
636-
all_descendents_after_loop = fill(true, nops)
637-
positions = fill((-1,-1,-1,-1,-1), nops)#Vector{NTuple{5,Int}}(undef, nops)
626+
place_after_loop = fill(true, nops)
638627
# to go inside out, we just have to include all those not-yet included depending on the current sym
628+
empty!(lo)
639629
for _n 1:nloops
640630
n = 1 + nloops - _n
641631
ro[_n] = loopsym = order[n]
@@ -647,20 +637,11 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
647637
isunrolled = (unrolled loopdependencies(op)) + 1
648638
istiled = (loopistiled ? (tiled loopdependencies(op)) : false) + 1
649639
optype = Int(op.node_type) + 1
650-
after_loop = place_after_loop!(all_descendents_after_loop, op)
651-
positions[id] = (optype,isunrolled,istiled,after_loop,_n)
652-
end
653-
end
654-
empty!(lo)
655-
for id 1:nops
656-
optype,isunrolled,istiled,after_loop,_n = positions[id]
657-
optype == -1 && continue#@show ops[id]
658-
if all_descendents_after_loop[id]
659-
after_loop = 2
640+
after_loop = place_after_loop[id] + 1
641+
push!(lo[optype,isunrolled,istiled,after_loop,_n], ops[id])
642+
set_upstream_family!(place_after_loop, op, false) # parents that have already been included are not moved, so no need to check included_vars to filter
660643
end
661-
push!(lo[optype,isunrolled,istiled,after_loop,_n], ops[id])
662644
end
663-
# 3, ro, order
664645
end
665646

666647

src/lowering.jl

Lines changed: 37 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,6 @@
11

2-
# function unitstride(op::Operation, sym::Symbol)
3-
# (first(op.symbolic_metadata) === sym) && (first(op.numerical_metadata) == 1)
4-
# end
5-
6-
function variable_name(op::Operation, suffix)
7-
var = op.variable
8-
suffix === nothing ? var : Symbol(var, :_, suffix)
9-
end
10-
2+
variable_name(op::Operation, ::Nothing) = mangledvar(op)
3+
variable_name(op::Operation, suffix) = Symbol(mangledvar(op), suffix, :_)
114

125
function append_inds!(ret, indices, deps)
136
start = (first(indices) === Symbol("##DISCONTIGUOUSSUBARRAY##")) + 1# && return append_inds!(ret, @view(indices[2:end]), deps)
@@ -86,15 +79,15 @@ end
8679
# end
8780
# end
8881
function varassignname(var::Symbol, u::Int, isunrolled::Bool)
89-
isunrolled ? Symbol("##", var, :_, u) : Symbol("##", var)
82+
isunrolled ? Symbol(var, u) : var
9083
end
9184
# name_mo only gets called when vectorized
9285
function name_mo(var::Symbol, op::Operation, u::Int, W::Symbol, vecnotunrolled::Bool, unrolled::Symbol)
9386
if u < 0 # sentinel value meaning not unrolled
94-
name = Symbol("##",var)
87+
name = var
9588
mo = mem_offset(op)
9689
else
97-
name = Symbol("##",var,:_,u)
90+
name = Symbol(var, u)
9891
mo = vecnotunrolled ? mem_offset(op, u, unrolled) : mem_offset(op, W, u, unrolled)
9992
end
10093
name, mo
@@ -184,18 +177,18 @@ function lower_load!(
184177
end
185178
function reduce_range!(q::Expr, toreduct::Symbol, instr::Instruction, Uh::Int, Uh2::Int)
186179
for u 0:Uh-1
187-
tru = Symbol("##",toreduct,:_,u)
188-
push!(q.args, Expr(:(=), tru, Expr(instr, tru, Symbol("##",toreduct,:_,u + Uh))))
180+
tru = Symbol(toreduct, u)
181+
push!(q.args, Expr(:(=), tru, Expr(instr, tru, Symbol(toreduct, u + Uh))))
189182
end
190183
for u 2Uh:Uh2-1
191-
tru = Symbol("##",toreduct,:_, u + 1 - 2Uh)
192-
push!(q.args, Expr(:(=), tru, Expr(instr, tru, Symbol("##",toreduct,:_,u))))
184+
tru = Symbol(toreduct, u + 1 - 2Uh)
185+
push!(q.args, Expr(:(=), tru, Expr(instr, tru, Symbol(toreduct, u))))
193186
end
194187
end
195188
function reduce_range!(q::Expr, ls::LoopSet, Ulow::Int, Uhigh::Int)
196189
for or ls.outer_reductions
197190
op = ls.operations[or]
198-
var = op.variable
191+
var = mangledvar(op)
199192
temp = gensym(var)
200193
instr = op.instruction
201194
instr = get(REDUCTION_TRANSLATION, instr, instr)
@@ -221,10 +214,9 @@ function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, U::Int)
221214
nothing
222215
end
223216

224-
function pvariable_name(op::Operation, suffix)
225-
var = first(parents(op)).variable
226-
suffix === nothing ? var : Symbol(var, :_, suffix)
227-
end
217+
pvariable_name(op::Operation, ::Nothing) = mangledvar(first(parents(op)))
218+
pvariable_name(op::Operation, suffix) = Symbol(pvariable_name(op, nothing), suffix, :_)
219+
228220
function reduce_unroll!(q, op, U, unrolled)
229221
loopdeps = loopdependencies(op)
230222
isunrolled = unrolled loopdeps
@@ -332,13 +324,16 @@ function lower_compute_unrolled!(
332324
)
333325
lower_compute!(q, op, vectorized, W, unrolled, tiled, U, suffix, mask, true)
334326
end
327+
struct FalseCollection end
328+
Base.getindex(::FalseCollection, i...) = false
335329
function lower_compute!(
336330
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, tiled::Symbol, U::Int,
337331
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing,
338332
opunrolled = unrolled loopdependencies(op)
339333
)
340334

341335
var = op.variable
336+
mvar = mangledvar(op)
342337
parents_op = parents(op)
343338
nparents = length(parents_op)
344339
if opunrolled
@@ -355,11 +350,12 @@ function lower_compute!(
355350
parentstiled = if suffix === nothing
356351
optiled = false
357352
tiledouterreduction = -1
358-
fill(false, nparents)
353+
FalseCollection()
359354
else
360355
tiledouterreduction = isouterreduction(op)
356+
suffix_ = Symbol(suffix, :_)
361357
if tiledouterreduction == -1
362-
var = Symbol(var, :_, suffix)
358+
mvar = Symbol(mvar, suffix_)
363359
end
364360
optiled = true
365361
[tiled loopdependencies(opp) for opp parents_op]
@@ -384,25 +380,25 @@ function lower_compute!(
384380
instrcall = Expr(instr) # Expr(:call, instr)
385381
varsym = if tiledouterreduction > 0 # then suffix !== nothing
386382
modsuffix = ((u + suffix*U) & 3)
387-
Symbol("##",var,:_, modsuffix)
383+
Symbol(mvar, modsuffix)
388384
elseif opunrolled
389-
Symbol("##",var,:_,u)
385+
Symbol(mvar, u)
390386
else
391-
Symbol("##",var)
387+
mvar
392388
end
393389
for n 1:nparents
394-
parent = parents_op[n].variable
390+
parent = mangledvar(parents_op[n])
395391
if n == tiledouterreduction
396-
parent = Symbol(parent,:_,modsuffix)
392+
parent = Symbol(parent, modsuffix)
397393
else
398394
if parentstiled[n]
399-
parent = Symbol(parent,:_,suffix)
395+
parent = Symbol(parent, suffix_)
400396
end
401397
if parentsunrolled[n]
402-
parent = Symbol(parent,:_,u)
398+
parent = Symbol(parent, u)
403399
end
404400
end
405-
push!(instrcall.args, Symbol("##", parent))
401+
push!(instrcall.args, parent)
406402
end
407403
if maskreduct && u == Uiter # only mask last
408404
push!(q.args, Expr(:(=), varsym, Expr(:call, lv(:vifelse), mask, instrcall, varsym)))
@@ -429,21 +425,18 @@ function lower_constant!(
429425
q::Expr, op::Operation, vectorized::Symbol, W::Symbol, unrolled::Symbol, U::Int,
430426
suffix::Union{Nothing,Int}, mask::Any = nothing
431427
)
432-
@unpack variable, instruction = op
433-
if suffix !== nothing
434-
variable = Symbol(variable, :_, suffix)
435-
end
436-
# store parent's reduction deps
437-
# @show op.instruction, loopdependencies(op), reduceddependencies(op), unrolled, unrolled ∈ loopdependencies(op)
428+
instruction = op.instruction
429+
mvar = variable_name(op, suffix)
430+
438431
constsym = instruction.instr
439432
if vectorized loopdependencies(op) || vectorized reduceddependencies(op)
440433
call = Expr(:call, lv(:vbroadcast), W, constsym)
441434
for u 0:U-1
442-
push!(q.args, Expr(:(=), Symbol("##", variable, :_, u), call))
435+
push!(q.args, Expr(:(=), Symbol(mvar, u), call))
443436
end
444437
else
445438
for u 0:U-1
446-
push!(q.args, Expr(:(=), Symbol("##", variable, :_, u), constsym))
439+
push!(q.args, Expr(:(=), Symbol(mvar, u), constsym))
447440
end
448441
end
449442
nothing
@@ -621,16 +614,13 @@ function initialize_outer_reductions!(
621614
q::Expr, op::Operation, Umin::Int, Umax::Int, W::Symbol, typeT::Symbol, unrolled::Symbol, suffix::Union{Symbol,Nothing} = nothing
622615
)
623616
# T = op.elementbytes == 8 ? :Float64 : :Float32
624-
var = op.variable
625617
z = Expr(:call, REDUCTION_ZERO[op.instruction], typeT)
626618
if unrolled reduceddependencies(op)
627619
z = Expr(:call, lv(:vbroadcast), W, z)
628620
end
629-
if suffix !== nothing
630-
var = Symbol(var, :_, suffix)
631-
end
621+
mvar = variable_name(op, suffix)
632622
for u Umin:Umax-1
633-
push!(q.args, Expr(:(=), Symbol("##", var, :_, u), z))
623+
push!(q.args, Expr(:(=), Symbol(mvar, u), z))
634624
end
635625
nothing
636626
end
@@ -658,9 +648,10 @@ function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
658648
for or ls.outer_reductions
659649
op = ls.operations[or]
660650
var = op.variable
651+
mvar = mangledvar(op)
661652
instr = op.instruction
662-
reduce_expr!(q, var, instr, U)
663-
push!(q.args, Expr(:(=), var, Expr(:call, REDUCTION_SCALAR_COMBINE[instr], var, Symbol("##", var, :_0))))
653+
reduce_expr!(q, mvar, instr, U)
654+
push!(q.args, Expr(:(=), var, Expr(:call, REDUCTION_SCALAR_COMBINE[instr], var, Symbol(mvar, 0))))
664655
end
665656
end
666657
function gc_preserve(ls::LoopSet, q::Expr)

0 commit comments

Comments
 (0)