Skip to content

Commit cf412d2

Browse files
committed
Tweaks/improvements to lowering of statically sized loops.
1 parent 3bd540f commit cf412d2

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

src/determinestrategy.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,7 @@ function solve_unroll(X, R, u₁max, u₂max, u₁L, u₂L, u₁step, u₂step)
378378
u₁, u₂, cost
379379
end
380380
function maybedemotesize(U::Int, N::Int)
381-
# U > 1 || return 1
382-
Um1 = U - 1
383-
urep = num_iterations(N, U)
384-
um1rep = num_iterations(N, Um1)
385-
um1rep > urep ? U : Um1
381+
num_iterations(N, num_iterations(N, U))
386382
end
387383
function maybedemotesize(u₂::Int, N::Int, U::Int, Uloop::Loop, maxu₂base::Int)
388384
u₂ > 1 || return 1

src/lower_compute.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,23 @@ function load_constrained(op, u₁loop, u₂loop, forprefetch = false)
1919
dependsonu₂ && push!(unrolleddeps, u₂loop)
2020
any(opp -> isload(opp) && all(in(loopdependencies(opp)), unrolleddeps), parents(op))
2121
end
22+
# Report whether the unroll factors carried by `ua` differ from the loop set's
# stored unroll specification (`ls.unrollspecification[]`) on a statically
# sized loop.
#
# Returns `true` when either
#   * the u₁ loop is static and `usorig.u₁ != u₁`, or
#   * the u₂ loop is static and `usorig.u₂ != u₂max`;
# otherwise returns `false`.
# NOTE(review): the name suggests this detects the "remainder first" lowering
# path of static loops — confirm against `lower_unrolled_dynamic`.
function check_if_remfirst(ls, ua)
23+
usorig = ls.unrollspecification[]
24+
@unpack u₁, u₁loopsym, u₂loopsym, u₂max = ua
25+
u₁loop = getloop(ls, u₁loopsym)
26+
u₂loop = getloop(ls, u₂loopsym)
27+
if isstaticloop(u₁loop) && (usorig.u₁ != u₁)
28+
return true
29+
end
30+
if isstaticloop(u₂loop) && (usorig.u₂ != u₂max)
31+
return true
32+
end
33+
false
34+
end
35+
# Predicate used by `lower_compute!` to decide whether a generic `*_fast` FMA
# instruction may be replaced by one of the specific FMA variants.
#
# Returns `true` only when `op` is not load-constrained with respect to the
# u₁/u₂ loops (per `load_constrained`) and `check_if_remfirst(ls, ua)` is
# `false`.
# Note: `u₁` and `u₂max` are unpacked here but not used in this function.
function sub_fmas(ls::LoopSet, op::Operation, ua::UnrollArgs)
36+
@unpack u₁, u₁loopsym, u₂loopsym, u₂max = ua
37+
!(load_constrained(op, u₁loopsym, u₂loopsym) || check_if_remfirst(ls, ua))
38+
end
2239

2340
struct FalseCollection end
2441
Base.getindex(::FalseCollection, i...) = false
@@ -106,7 +123,7 @@ function add_loopvalue!(instrcall::Expr, loopval, ua::UnrollArgs, u::Int)
106123
end
107124

108125
function lower_compute!(
109-
q::Expr, op::Operation, ua::UnrollArgs, mask::Union{Nothing,Symbol,Unsigned} = nothing,
126+
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs, mask::Union{Nothing,Symbol,Unsigned} = nothing,
110127
)
111128
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, suffix = ua
112129
var = name(op)
@@ -176,16 +193,16 @@ function lower_compute!(
176193
instrfid = findfirst(isequal(instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
177194
# want to instcombine when parent load's deps are superset
178195
# also make sure opp is unrolled
179-
if instrfid !== nothing && (opunrolled && u₁ > 1) && !load_constrained(op, u₁loopsym, u₂loopsym)
180-
specific_fmas = Base.libllvm_version > v"11.0.0" ? (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub) : (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
196+
if !isnothing(instrfid) && (opunrolled && u₁ > 1) && sub_fmas(ls, op, ua)
197+
specific_fmas = Base.libllvm_version >= v"11.0.0" ? (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub) : (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
181198
# specific_fmas = (:vfmadd231, :vfnmadd231, :vfmsub231, :vfnmsub231)
182199
instr = Instruction(specific_fmas[instrfid])
183200
end
184201
end
185202
# @show instr.instr
186203
reduceddeps = reduceddependencies(op)
187204
vecinreduceddeps = isreduct && vectorized ∈ reduceddeps
188-
maskreduct = mask !== nothing && vecinreduceddeps #any(opp -> opp.variable === var, parents_op)
205+
maskreduct = !isnothing(mask) && vecinreduceddeps #any(opp -> opp.variable === var, parents_op)
189206
# if vecinreduceddeps && vectorized ∉ loopdependencies(op) # screen parent opps for those needing a reduction to scalar
190207
# # parents_op = reduce_vectorized_parents!(q, op, parents_op, U, u₁loopsym, u₂loopsym, vectorized, suffix)
191208
# isreducingidentity!(q, op, parents_op, U, u₁loopsym, u₂loopsym, vectorized, suffix) && return

src/lowering.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function lower!(
2121
elseif isload(op)
2222
lower_load!(q, op, ls, unrollargs, mask)
2323
elseif iscompute(op)
24-
lower_compute!(q, op, unrollargs, mask)
24+
lower_compute!(q, op, ls, unrollargs, mask)
2525
elseif isstore(op)
2626
lower_store!(q, ls, op, unrollargs, mask)
2727
# elseif isloopvalue(op)
@@ -45,7 +45,7 @@ function lower!(
4545
elseif isload(op)
4646
lower_load!(q, op, ls, unrollargs, mask)
4747
elseif iscompute(op)
48-
lower_compute!(q, op, unrollargs, mask)
48+
lower_compute!(q, op, ls, unrollargs, mask)
4949
end
5050
end
5151
end
@@ -317,13 +317,26 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
317317
1
318318
end
319319
remfirst = loopisstatic & (UFt > 0) & !(unsigned(Ureduct) < unsigned(UF))
320+
# @show remfirst, loopsym
320321
tc = terminatecondition(ls, us, n, inclmask, remfirst ? 1 : UF)
321322
usorig = ls.unrollspecification[]
322323

323324
# tc = (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && !loopisstatic && !inclmask && !ls.loadelimination[] ? expect(tc) : tc
324325

325326
body = lower_block(ls, us, n, inclmask, UF)
326-
q = Expr(:while, tc, body)
327+
328+
q = if loopisstatic
329+
iters = length(loop) ÷ UF
330+
if iters ≤ 4
331+
q = Expr(:block)
332+
foreach(_ -> push!(q.args, body), 1:iters)
333+
q
334+
else
335+
Expr(:while, tc, body)
336+
end
337+
else
338+
Expr(:while, tc, body)
339+
end
327340
remblock = init_remblock(loop, ls.lssm[], n)#loopsym)
328341
q = if unsigned(Ureduct) < unsigned(UF) # unsigned(-1) == typemax(UInt); is logic relying on twos-complement bad?
329342
UF_cleanup = UF - Ureduct
@@ -579,7 +592,7 @@ function setup_preamble!(ls::LoopSet, us::UnrollSpecification)
579592
lower_licm_constants!(ls)
580593
isone(num_loops(ls)) || pushpreamble!(ls, definemask(getloop(ls, vectorized)))#, u₁ > 1 && u₁loopnum == vectorizedloopnum))
581594
for op ∈ operations(ls)
582-
(iszero(length(loopdependencies(op))) && iscompute(op)) && lower_compute!(ls.preamble, op, UnrollArgs(u₁, u₁loopsym, u₂loopsym, vectorized, u₂, nothing), nothing)
595+
(iszero(length(loopdependencies(op))) && iscompute(op)) && lower_compute!(ls.preamble, op, ls, UnrollArgs(u₁, u₁loopsym, u₂loopsym, vectorized, u₂, nothing), nothing)
583596
end
584597
# define_remaining_ops!( ls, vectorized, W, u₁loop, u₂loop, u₁ )
585598
end

0 commit comments

Comments
 (0)