Skip to content

Commit b308bb0

Browse files
committed
Updated cost modeling for twice unrolled-loops. It should now be correctly assigning costs based on position in loop nests (with respect to whether or not costs are reduced by unrolling), and also handles cost reductions caused by repeated loads with small constant offsets correctly, encouraging unrolling of those specific loops.
1 parent f0f1309 commit b308bb0

File tree

4 files changed

+151
-46
lines changed

4 files changed

+151
-46
lines changed

benchmark/driver.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,18 @@
22
# const LOOPVECBENCHDIR = joinpath(pkgdir("LoopVectorization"), "benchmarks")
33
# includet(joinpath(LOOPVECBENCHDIR, "driver.jl"))
44

5-
using Distributed
5+
using Distributed, LoopVectorization
66

7-
pkgdir(pkg::String) = abspath(joinpath(dirname(Base.find_package(pkg)), ".."))
8-
const LOOPVECBENCHDIR = joinpath(pkgdir("LoopVectorization"), "benchmark")
7+
const LOOPVECBENCHDIR = joinpath(pkgdir(LoopVectorization), "benchmark")
98
include(joinpath(LOOPVECBENCHDIR, "benchmarkflops.jl"))
109
include(joinpath(LOOPVECBENCHDIR, "plotbenchmarks.jl"))
1110

1211

1312
addprocs((Sys.CPU_THREADS >> 1)-1); nworkers()
1413

1514
@everywhere begin
16-
pkgdir(pkg::String) = abspath(joinpath(dirname(Base.find_package(pkg)), ".."))
17-
const LOOPVECBENCHDIR = joinpath(pkgdir("LoopVectorization"), "benchmark")
15+
using LoopVectorization
16+
const LOOPVECBENCHDIR = joinpath(pkgdir(LoopVectorization), "benchmark")
1817
include(joinpath(LOOPVECBENCHDIR, "benchmarkflops.jl"))
1918
# BenchmarkTools.DEFAULT_PARAMETERS.seconds = 1
2019
end

benchmark/loadsharedlibs.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
1-
using LinearAlgebra
1+
using LinearAlgebra, LoopVectorization
22
using LoopVectorization.VectorizationBase: REGISTER_SIZE
33

4-
pkgdir(pkg::String) = abspath(joinpath(dirname(Base.find_package(pkg)), ".."))
5-
const LOOPVECBENCHDIR = joinpath(pkgdir("LoopVectorization"), "benchmark")
4+
# const LOOPVECBENCHDIR = joinpath(pkgdir(LoopVectorization), "benchmark")
65
include(joinpath(LOOPVECBENCHDIR, "looptests.jl"))
76

87
const LIBCTEST = joinpath(LOOPVECBENCHDIR, "libctests.so")

src/determinestrategy.jl

Lines changed: 143 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ function unroll_cost(X, u₁, u₂, u₁L, u₂L)
240240
u₂factor = (num_iterations(u₂L, u₂)/u₂L)
241241
u₁factor = (num_iterations(u₁L, u₁)/u₁L)
242242
# X[1]*u₂factor*u₁factor + X[4] + X[2] * u₂factor + X[3] * u₁factor
243-
X[1] + X[4] + X[2] * ufactor + X[3] * u₁factor
243+
X[1] + X[2] * u₂factor + X[3] * ufactor + X[4] * u₁factor * u₂factor
244244
end
245245
# function itertilesize(X, u₁L, u₂L)
246246
# cb = Inf
@@ -256,14 +256,63 @@ end
256256
# u₁b, u₂b, cb
257257
# end
258258
function solve_unroll(X, R, u₁L, u₂L)
259+
X₁, X₂, X₃, X₄ = X[1], X[2], X[3], X[4]
260+
R₁, R₂, R₃, R₄ = R[1], R[2], R[3], R[4]
261+
RR = REGISTER_COUNT - R₃ - R₄
262+
a = R₂^2*X₃ -R₁*X₄ * R₂ - R₁*X₂*RR
263+
b = R₁ * X₄ * RR - R₁ * X₄ * RR - 2X₃*RR*R₂
264+
c = X₃*RR^2
265+
discriminant = b^2 - 4a*c
266+
discriminant < 0 && return -1,-1,Inf
267+
u₁float = max(1.0, (sqrt(discriminant) + b) / (-2a)) # must be at least 1
268+
u₂float = (RR - u₁float*R₂)/(u₁float*R₁)
269+
if !(isfinite(u₂float) && isfinite(u₁float))
270+
return 4, 4, unroll_cost(X, 4, 4, u₁L, u₂L)
271+
# return itertilesize(X, u₁L, u₂L)
272+
end
273+
u₁low = floor(Int, u₁float)
274+
u₂low = max(1, floor(Int, u₂float)) # must be at least 1
275+
u₁high = u₁low + 1 #ceil(Int, u₁float)
276+
u₂high = u₂low + 1 #ceil(Int, u₂float)
277+
278+
# RR = REGISTER_COUNT - R[3] - R[4]
279+
u₁, u₂ = u₁low, u₂low
280+
ucost = unroll_cost(X, u₁low, u₂low, u₁L, u₂L)
281+
# @show u₁low*u₂high*R[1] + u₁low*R[2]
282+
if RR u₁low*u₂high*R[1] + u₁low*R[2]
283+
ucost_temp = unroll_cost(X, u₁low, u₂high, u₁L, u₂L)
284+
# @show ucost_temp, ucost
285+
if ucost_temp < ucost
286+
ucost = ucost_temp
287+
u₁, u₂ = u₁low, u₂high
288+
end
289+
end
290+
# The RR + 1 is a hack to get it to favor u₁high in more scenarios
291+
u₂l = u₂low
292+
while RR < u₁high*u₂l*R[1] + u₁high*R[2] && u₂l > 1
293+
u₂l -= 1
294+
end
295+
ucost_temp = unroll_cost(X, u₁high, u₂l, u₁L, u₂L)
296+
if ucost_temp < ucost
297+
ucost = ucost_temp
298+
u₁, u₂ = u₁high, u₂l
299+
end
300+
if RR > u₁high*u₂high*R[1] + u₁high*R[2]
301+
throw("Something went wrong when solving for u₂float and u₁float.")
302+
end
303+
u₁, u₂, ucost
304+
end
305+
306+
function solve_unrollold(X, R, u₁L, u₂L)
259307
# @inbounds any(iszero, (R[1],R[2],R[3])) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
260-
first(iszero(R)) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
308+
# (iszero(X[2]) || iszero(X[3])) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
309+
iszero(first(R)) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
261310
# @inbounds any(iszero, (R[1],R[2],R[3])) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
262311
# We use a lagrange multiplier to find floating point values for u₁ and u₂
263312
# first solving for u₁ via quadratic formula
264313
# X is vector of costs, and R is of register pressures
265314
RR = REGISTER_COUNT - R[3] - R[4] # RR ≡ RemainingRegisters
266-
R[1] + R[2] > 0.5RR && return 1,1, unroll_cost(X, 1, 1, u₁L, u₂L)
315+
R[1] + R[2] > 0.5RR && return 1, 1, unroll_cost(X, 1, 1, u₁L, u₂L)
267316
a = (R[1])^2*X[2] - (R[2])^2*R[1]*X[3]/RR
268317
b = 2*R[1]*R[2]*X[3]
269318
c = -RR*R[1]*X[3]
@@ -493,15 +542,24 @@ function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol)
493542
end
494543
function maxnegativeoffset(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
495544
mno = typemin(Int)
545+
i = 0
496546
if u1 !== v
497-
mno = first(maxnegativeoffset(ls, op, u1))
547+
mnou₁ = first(maxnegativeoffset(ls, op, u1))
548+
if mnou₁ > mno
549+
i = 1
550+
mno = mnou₁
551+
end
498552
end
499553
if u2 !== v
500-
mno = max(mno, first(maxnegativeoffset(ls, op, u2)))
554+
mnou₂ = first(maxnegativeoffset(ls, op, u2))
555+
if mnou₂ > mno
556+
i = 2
557+
mno = mnou₂
558+
end
501559
end
502-
mno
560+
mno, i
503561
end
504-
function loadelimination_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
562+
function load_elimination_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2::Symbol, v::Symbol)
505563
if first(isoptranslation(ls, op, u1, u2, v))
506564
for loop ls.loops
507565
# If another loop is short, assume that LLVM will unroll it, in which case
@@ -516,14 +574,50 @@ function loadelimination_cost_factor(ls::LoopSet, op::Operation, u1::Symbol, u2:
516574
end
517575
(0.25, VectorizationBase.REGISTER_COUNT == 32 ? 1.2 : 1.0)
518576
else
519-
offset = maxnegativeoffset(ls, op, u1, u2, v)
520-
if -5 < offset < 0
521-
(-0.25offset, 1.0)
577+
(1.0, 1.0)
578+
end
579+
end
580+
function add_constant_offset_load_elmination_cost!(
581+
X, R, ls::LoopSet, op::Operation, iters, u₁loop::Symbol, u₁reduces::Bool, u₂loop::Symbol, u₂reduces::Bool, v::Symbol, Wshift::Int, size_T::Int, opisininnerloop::Bool
582+
)
583+
offset, uid = maxnegativeoffset(ls, op, u₁loop, u₂loop, v)
584+
if -4 < offset < 0
585+
udependent_reduction = (-1 - offset) / 3
586+
uindependent_increase = (4 + offset) / 3
587+
rt, lat, rp = cost(ls, op, v, Wshift, size_T)
588+
rt *= iters
589+
rp = opisininnerloop ? rp : zero(rp)
590+
# u_uid is getting eliminated
591+
# we treat this as the unrolled loop getting eliminated is split into 2 parts:
592+
# 1 a non-cost-reduced part, with factor udependent_reduction
593+
# 2 a cost-reduced part, with factor uindependent_increase
594+
if uid == 1 # u₁reduces was false
595+
@assert !u₁reduces
596+
if u₂reduces
597+
r, i = 4, 2
598+
else
599+
r, i = 3, 1
600+
end
601+
elseif uid == 2 # u₂reduces was false
602+
@assert !u₂reduces
603+
if u₁reduces
604+
r, i = 4, 3
605+
else
606+
r, i = 2, 1
607+
end
522608
else
523-
(1.0, 1.0)
609+
throw("uid somehow did not return 1 or 2, even though offset > -4.")
524610
end
611+
X[r] += rt * uindependent_increase
612+
R[r] += rp * uindependent_increase
613+
X[i] += rt * udependent_reduction
614+
R[i] += rp * udependent_reduction
615+
return true
616+
else
617+
return false
525618
end
526619
end
620+
527621
# Just tile outer two loops?
528622
# But optimal order within tile must still be determined
529623
# as well as size of the tiles.
@@ -537,7 +631,7 @@ function evaluate_cost_tile(
537631
ops = operations(ls)
538632
nops = length(ops)
539633
included_vars = fill!(resize!(ls.included_vars, nops), false)
540-
unrolledtiled = fill(false, 2, nops)
634+
reduced_by_unrolling = fill(false, 2, nops)
541635
descendentsininnerloop = fill!(resize!(ls.place_after_loop, nops), false)
542636
innerloop = last(order)
543637
iters = fill(-99.9, nops)
@@ -556,8 +650,14 @@ function evaluate_cost_tile(
556650
# @inbounds reg_pressure[2] = 1
557651
# @inbounds reg_pressure[3] = 1
558652
iter::Int = 1
653+
u₁reached = u₂reached = false
559654
for n 1:N
560655
itersym = order[n]
656+
if itersym == u₁loopsym
657+
u₁reached = true
658+
elseif itersym == u₂loopsym
659+
u₂reached = true
660+
end
561661
# Add to set of defined symbles
562662
push!(nested_loop_syms, itersym)
563663
looplength = length(ls, itersym)
@@ -575,8 +675,11 @@ function evaluate_cost_tile(
575675
rd = reduceddependencies(op)
576676
hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)])) && return 0,0,Inf
577677
included_vars[id] = true
578-
unrolledtiled[1,id] = u₁loopsym loopdependencies(op)
579-
unrolledtiled[2,id] = u₂loopsym loopdependencies(op)
678+
depends_on_u₁ = u₁loopsym loopdependencies(op)
679+
depends_on_u₂ = u₂loopsym loopdependencies(op)
680+
# cost is reduced by unrolling u₁ if it is interior to u₁loop (true if either u₁reached, or if depends on u₂ [or u₁]) and doesn't depend on u₁
681+
reduced_by_unrolling[1,id] = (u₁reached | depends_on_u₂) & !depends_on_u₁
682+
reduced_by_unrolling[2,id] = (u₂reached | depends_on_u₁) & !depends_on_u₂
580683
# @show op iter, unrolledtiled[:,id]
581684
iters[id] = iter
582685
innerloop loopdependencies(op) && set_upstream_family!(descendentsininnerloop, op, true)
@@ -585,31 +688,35 @@ function evaluate_cost_tile(
585688
for (id, op) enumerate(ops)
586689
iters[id] == -99.9 && continue
587690
opisininnerloop = descendentsininnerloop[id]
588-
isunrolled₁, isunrolled₂ = unrolledtiled[1,id], unrolledtiled[2,id]
589-
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
691+
692+
u₁reduces, u₂reduces = reduced_by_unrolling[1,id], reduced_by_unrolling[2,id]
693+
# @show op, u₁reduces, u₂reduces
694+
if !(isload(op) && add_constant_offset_load_elmination_cost!(cost_vec, reg_pressure, ls, op, iters[id], u₁loopsym, u₁reduces, u₂loopsym, u₂reduces, vectorized, Wshift, size_T, opisininnerloop))
695+
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
590696
# @show op rt, lat, rp
591-
if isload(op)
592-
factor1, factor2 = loadelimination_cost_factor(ls, op, u₁loopsym, u₂loopsym, vectorized)
593-
rt *= factor1; rp *= factor2;
594-
end
595-
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
596-
rp = opisininnerloop ? rp : 0 # we only care about register pressure within the inner most loop
597-
rt *= iters[id]
598-
if isunrolled₁ && isunrolled₂ # no cost decrease; cost must be repeated
599-
cost_vec[1] += rt
600-
reg_pressure[1] += rp
601-
elseif isunrolled₁ # cost decreased by unrolling u₂loop
602-
cost_vec[2] += rt
603-
reg_pressure[2] += rp
604-
elseif isunrolled₂ # cost decreased by unrolling u₁loop
605-
cost_vec[3] += rt
606-
reg_pressure[3] += rp
607-
else# no unrolling
608-
cost_vec[4] += rt
609-
reg_pressure[4] += rp
697+
if isload(op)
698+
factor1, factor2 = load_elimination_cost_factor(ls, op, u₁loopsym, u₂loopsym, vectorized)
699+
rt *= factor1; rp *= factor2;
700+
end
701+
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
702+
rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
703+
rt *= iters[id]
704+
if u₁reduces & u₂reduces
705+
cost_vec[4] += rt
706+
reg_pressure[4] += rp
707+
elseif u₂reduces # cost decreased by unrolling u₂loop
708+
cost_vec[2] += rt
709+
reg_pressure[2] += rp
710+
elseif u₁reduces # cost decreased by unrolling u₁loop
711+
cost_vec[3] += rt
712+
reg_pressure[3] += rp
713+
else # no cost decrease; cost must be repeated
714+
cost_vec[1] += rt
715+
reg_pressure[1] += rp
716+
end
610717
end
611718
end
612-
# @show reg_pressure
719+
# @show cost_vec reg_pressure
613720
costpenalty = (sum(reg_pressure) > VectorizationBase.REGISTER_COUNT) ? 2 : 1
614721
# @show order, vectorized cost_vec reg_pressure
615722
# @show solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure)

test/gemm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,9 +300,9 @@
300300
end)
301301
lsr2amb = LoopVectorization.LoopSet(r2ambq);
302302
if LoopVectorization.VectorizationBase.REGISTER_COUNT == 32
303-
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :k, :n, :m, 3, 6)
303+
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :m, :n, :m, 3, 6)
304304
else
305-
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :k, :n, :m, 2, 4)
305+
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :m, :n, :m, 2, 4)
306306
end
307307
function rank2AmulBavx!(C, Aₘ, Aₖ, B)
308308
@avx for m 1:size(C,1), n 1:size(C,2)

0 commit comments

Comments
 (0)