Skip to content

Commit 3bd540f

Browse files
committed
Few minor tweaks.
1 parent a3373e8 commit 3bd540f

File tree

8 files changed

+119
-67
lines changed

8 files changed

+119
-67
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.8.5"
4+
version = "0.8.6"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/determinestrategy.jl

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ function solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range)
299299
for u₂temp u₂range
300300
RR u₁temp*u₂temp*R₁ + u₁temp*R₂ + u₂temp*R₅ || continue
301301
tempcost = unroll_cost(X, u₁temp, u₂temp, u₁L, u₂L)
302-
if tempcost < bestcost
302+
if tempcost bestcost
303303
bestcost = tempcost
304304
u₁best, u₂best = u₁temp, u₂temp
305305
end
@@ -329,6 +329,11 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
329329
u₂low = max(u₂step, floor(Int, u₂float)) # must be at least 1
330330
u₁high = solve_unroll_constT(R, u₂low) + u₁step
331331
u₂high = solve_unroll_constU(R, u₁low) + u₂step
332+
maxunroll = REGISTER_COUNT == 32 ? 10 : 6
333+
u₁low = min(u₁low, maxunroll)
334+
u₂low = min(u₂low, maxunroll)
335+
u₁high = min(u₁high, maxunroll)
336+
u₂high = min(u₂high, maxunroll)
332337
solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
333338
end
334339

@@ -658,6 +663,17 @@ function add_constant_offset_load_elmination_cost!(
658663
end
659664
end
660665

666+
function update_costs!(costs, cost, u₁reduces, u₂reduces)
667+
if u₁reduces & u₂reduces
668+
costs[4] += cost
669+
elseif u₂reduces # cost decreased by unrolling u₂loop
670+
costs[2] += cost
671+
elseif u₁reduces # cost decreased by unrolling u₁loop
672+
costs[3] += cost
673+
else # no cost decrease; cost must be repeated
674+
costs[1] += cost
675+
end
676+
end
661677

662678
# Just tile outer two loops?
663679
# But optimal order within tile must still be determined
@@ -674,7 +690,7 @@ function evaluate_cost_tile(
674690
ops = operations(ls)
675691
nops = length(ops)
676692
included_vars = fill!(resize!(ls.included_vars, nops), false)
677-
reduced_by_unrolling = fill(false, 2, nops)
693+
reduced_by_unrolling = fill(false, 2, 2, nops)
678694
descendentsininnerloop = fill!(resize!(ls.place_after_loop, nops), false)
679695
innerloop = last(order)
680696
iters = fill(-99.9, nops)
@@ -718,13 +734,19 @@ function evaluate_cost_tile(
718734
rd = reduceddependencies(op)
719735
hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)])) && return 0,0,Inf,false
720736
included_vars[id] = true
721-
depends_on_u₁ = isu₁unrolled(op)
722-
depends_on_u₂ = isu₂unrolled(op)
737+
if isconstant(op)
738+
depends_on_u₁, depends_on_u₂ = isunrolled_sym(op, u₁loopsym, u₂loopsym)
739+
reduced_by_unrolling[1,1,id] = !depends_on_u₁
740+
reduced_by_unrolling[2,1,id] = !depends_on_u₂
741+
else
742+
depends_on_u₁ = isu₁unrolled(op)
743+
depends_on_u₂ = isu₂unrolled(op)
744+
reduced_by_unrolling[1,1,id] = (u₁reached) & !depends_on_u₁
745+
reduced_by_unrolling[2,1,id] = (u₂reached) & !depends_on_u₂
746+
end
723747
# cost is reduced by unrolling u₁ if it is interior to u₁loop (true if either u₁reached, or if depends on u₂ [or u₁]) and doesn't depend on u₁
724-
# reduced_by_unrolling[1,id] = (u₁reached | depends_on_u₂) & !depends_on_u₁
725-
# reduced_by_unrolling[2,id] = (u₂reached | depends_on_u₁) & !depends_on_u₂
726-
reduced_by_unrolling[1,id] = (u₁reached) & !depends_on_u₁
727-
reduced_by_unrolling[2,id] = (u₂reached) & !depends_on_u₂
748+
reduced_by_unrolling[1,2,id] = (u₁reached | depends_on_u₂) & !depends_on_u₁
749+
reduced_by_unrolling[2,2,id] = (u₂reached | depends_on_u₁) & !depends_on_u₂
728750
iters[id] = iter
729751
innerloop loopdependencies(op) && set_upstream_family!(descendentsininnerloop, op, true)
730752
end
@@ -733,13 +755,15 @@ function evaluate_cost_tile(
733755
iters[id] == -99.9 && continue
734756
opisininnerloop = descendentsininnerloop[id]
735757

736-
u₁reduces, u₂reduces = reduced_by_unrolling[1,id], reduced_by_unrolling[2,id]
758+
u₁reducesrt, u₂reducesrt = reduced_by_unrolling[1,1,id], reduced_by_unrolling[2,1,id]
759+
u₁reducesrp, u₂reducesrp = reduced_by_unrolling[1,2,id], reduced_by_unrolling[2,2,id]
737760
if isload(op)
738-
if add_constant_offset_load_elmination_cost!(cost_vec, reg_pressure, choose_to_inline, ls, op, iters[id], unrollsyms, u₁reduces, u₂reduces, Wshift, size_T, opisininnerloop)
761+
if add_constant_offset_load_elmination_cost!(cost_vec, reg_pressure, choose_to_inline, ls, op, iters[id], unrollsyms, u₁reducesrp, u₂reducesrp, Wshift, size_T, opisininnerloop)
739762
continue
740763
elseif load_elimination_cost_factor!(cost_vec, reg_pressure, choose_to_inline, ls, op, iters[id], unrollsyms, Wshift, size_T)
741764
continue
742-
end
765+
end
766+
elseif isconstant(op)
743767
end
744768
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
745769
if isload(op) && !iszero(prefetchisagoodidea(ls, op, UnrollArgs(4, unrollsyms, 4, 0)))
@@ -750,19 +774,8 @@ function evaluate_cost_tile(
750774
# rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
751775
rto = rt
752776
rt *= iters[id]
753-
if u₁reduces & u₂reduces
754-
cost_vec[4] += rt
755-
reg_pressure[4] += rp
756-
elseif u₂reduces # cost decreased by unrolling u₂loop
757-
cost_vec[2] += rt
758-
reg_pressure[2] += rp
759-
elseif u₁reduces # cost decreased by unrolling u₁loop
760-
cost_vec[3] += rt
761-
reg_pressure[3] += rp
762-
else # no cost decrease; cost must be repeated
763-
cost_vec[1] += rt
764-
reg_pressure[1] += rp
765-
end
777+
update_costs!(cost_vec, rt, u₁reducesrt, u₂reducesrt)
778+
update_costs!(reg_pressure, rp, u₁reducesrp, u₂reducesrp)
766779
end
767780
# @inbounds ((cost_vec[4] > 0) || ((cost_vec[2] > 0) & (cost_vec[3] > 0))) || return 0,0,Inf,false
768781
costpenalty = (sum(reg_pressure) > REGISTER_COUNT) ? 2 : 1

src/loopstartstopmanager.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,19 @@ function defpointermax(ls::LoopSet, ar::ArrayReferenceMeta, n::Int, sub::Int, is
201201
Expr(:(=), maxsym(vptr(ar), sub), pointermax(ls, ar, n, sub, isvectorized))
202202
end
203203

204-
function startloop(ls::LoopSet, us::UnrollSpecification, n::Int)
204+
function maxunroll(us::UnrollSpecification, n)
205+
@unpack u₁loopnum, u₂loopnum, u₁, u₂ = us
206+
if n == u₁loopnum# && u₁ > 1
207+
u₁
208+
elseif n == u₂loopnum# && u₂ > 1
209+
u₂
210+
else
211+
1
212+
end
213+
end
214+
215+
216+
function startloop(ls::LoopSet, us::UnrollSpecification, n::Int, submax = maxunroll(us, n))
205217
@unpack u₁loopnum, u₂loopnum, vectorizedloopnum, u₁, u₂ = us
206218
lssm = ls.lssm[]
207219
termind = lssm.terminators[n]
@@ -217,17 +229,6 @@ function startloop(ls::LoopSet, us::UnrollSpecification, n::Int)
217229
push!(loopstart.args, startloop(getloop(ls, loopsym), loopsym))
218230
else
219231
isvectorized = n == vectorizedloopnum
220-
submax = if n == u₁loopnum# && u₁ > 1
221-
# push!(loopstart.args, defpointermax(ls, ptrdefs[termind], n, u₁ - 1, isvectorized))
222-
u₁
223-
elseif n == u₂loopnum# && u₂ > 1
224-
u₂
225-
# push!(loopstart.args, defpointermax(ls, ptrdefs[termind], n, u₂ - 1, isvectorized))
226-
# elseif isvectorized
227-
# push!(loopstart.args, defpointermax(ls, ptrdefs[termind], n, 1, isvectorized))
228-
else
229-
1
230-
end
231232
for sub 0:submax
232233
push!(loopstart.args, defpointermax(ls, ptrdefs[termind], n, sub, isvectorized))
233234
end

src/lower_load.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,7 @@ function prefetchisagoodidea(ls::LoopSet, op::Operation, td::UnrollArgs)
9999
0
100100
end
101101
function add_prefetches!(q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, prefetchind::Int, umin::Int)
102-
103102
@unpack u₁, u₁loopsym, u₂loopsym, vectorized, u₂max = td
104-
105103
dontskip = (64 ÷ VectorizationBase.REGISTER_SIZE) - 1
106104
ptr = vptr(op)
107105
innermostloopsym = first(names(ls))
@@ -117,8 +115,8 @@ function add_prefetches!(q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, pr
117115
offsets[prefetchind] = inner_offset
118116
ptr = vptr(op)
119117
gptr = Symbol(ptr, "##GESPEDPREFETCH##")
120-
for i eachindex(gespinds.args)
121-
gespinds.args[i] = Expr(:call, lv(:extract_data), gespinds.args[i])
118+
for i eachindex(gespinds.args)
119+
gespinds.args[i] = Expr(:call, lv(:extract_data), gespinds.args[i])
122120
end
123121
push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))
124122

src/lowering.jl

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,48 @@ end
185185
# Expr(:block, loopiteratesatleastonce(loop), q)
186186
# end
187187
# end
188-
# function lower_unroll_for_throughput(ls::LoopSet, us::UnrollSpecification, n::Int, loop::Loop, loopsym::Symbol)
189-
# sl = startloop(loop, false, loopsym)
190-
# UF = 4
191-
# tcc = terminatecondition(loop, us, n, loopsym, false, 1)
192-
# tcu = terminatecondition(loop, us, n, loopsym, false, UF)
193-
# body = lower_block(ls, us, n, false, 1)
194-
# # loopisstatic = isstaticloop(loop)
195-
# unrolledlabel = gensym(:unrolled)
196-
# cleanuplabel = gensym(:cleanup)
197-
# gotounrolled = Expr(:macrocall, Symbol("@goto"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), unrolledlabel)
198-
# gotocleanup = Expr(:macrocall, Symbol("@goto"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), cleanuplabel)
199-
# branch = Expr(:if, tcu, gotounrolled, gotocleanup)
200-
# unrolled = Expr(:block, Expr(:macrocall, Symbol("@label"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), unrolledlabel))
201-
# foreach(_ -> push!(unrolled.args, body), 1:UF)
202-
# push!(unrolled.args, Expr(:if, tcu, gotounrolled, Expr(:if, tcc, gotocleanup)))
203-
# cleanup = Expr(:block, Expr(:macrocall, Symbol("@label"), LineNumberNode(@__LINE__, Symbol(@__FILE__)), cleanuplabel), body, Expr(:if, tcc, gotocleanup))
204-
# Expr(:let, sl, Expr(:block, branch, unrolled, cleanup))
205-
# end
188+
function lower_unroll_for_throughput(ls::LoopSet, us::UnrollSpecification, loop::Loop, loopsym::Symbol)
189+
UF = 4
190+
sl = startloop(ls, us, 1, UF)
191+
tcc = terminatecondition(ls, us, 1, false, 1)
192+
tcu = terminatecondition(ls, us, 1, false, UF)
193+
body = lower_block(ls, us, 1, false, 1)
194+
loopisstatic = isstaticloop(loop)
195+
tcu = loopisstatic ? tcu : expect(tcu)
196+
termcondu = gensym(:maybetermu)
197+
unrolledbody = Expr(:block)
198+
foreach(_ -> push!(unrolledbody.args, body), 1:UF)
199+
200+
# q = Expr(
201+
# :block,
202+
# Expr(:while, tcu, unrolledbody),
203+
# Expr(:while, tcc, body)
204+
# )
205+
# return Expr(:let, sl, q)
206+
207+
push!(unrolledbody.args, Expr(:(=), termcondu, tcu))
208+
209+
unrolledloop = Expr(
210+
:block,
211+
Expr(:while, termcondu, unrolledbody),
212+
Expr(:while, tcc, body)
213+
)
214+
215+
termcond = gensym(:maybeterm)
216+
singleloop = Expr(
217+
:block,
218+
Expr(:(=), termcond, true),
219+
Expr(:while, termcond, Expr(:block, body, Expr(:(=), termcond, tcc)))
220+
)
221+
222+
q = Expr(
223+
:block,
224+
assume(tcc),
225+
Expr(:(=), termcondu, tcu),
226+
Expr(:if, termcondu, unrolledloop, singleloop)
227+
)
228+
Expr(:let, sl, q)
229+
end
206230

207231
function assume(ex)
208232
Expr(:call, Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:SIMDPirates)), QuoteNode(:assume)), ex)
@@ -228,25 +252,29 @@ function lower_no_unroll(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask:
228252
nisvectorized = isvectorized(us, n)
229253
loopsym = names(ls)[n]
230254
loop = getloop(ls, loopsym)
231-
# if !nisvectorized && !inclmask && isone(n) && !ls.loadelimination[] && (us.u₁ > 1) && (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && length(loop) > 7
232-
# # return lower_unroll_for_throughput(ls, us, n, loop, loopsym)
233-
# return lower_llvm_unroll(ls, us, n, loop)
255+
# if !nisvectorized && !inclmask && isone(n) && !iszero(ls.lssm[].terminators[1]) && !ls.loadelimination[] && (us.u₁ > 1) && (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && length(loop) > 15
256+
# return lower_unroll_for_throughput(ls, us, loop, loopsym)
257+
# # return lower_llvm_unroll(ls, us, n, loop)
234258
# end
235259
# sl = startloop(loop, nisvectorized, loopsym)
236260
sl = startloop(ls, us, n)
237261
tc = terminatecondition(ls, us, n, inclmask, 1)
238262
body = lower_block(ls, us, n, inclmask, 1)
239263
isstatic = isstaticloop(loop)
264+
265+
if !isstatic && (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && !inclmask
266+
tc = expect(tc)
267+
end
240268
q = if nisvectorized
241269
# Expr(:block, loopiteratesatleastonce(loop, true), Expr(:while, expect(tc), body))
242-
Expr(:block, Expr(:while, isstatic ? tc : expect(tc), body))
270+
Expr(:block, Expr(:while, tc, body))
243271
elseif isstatic && length(loop) 8
244272
bodyq = Expr(:block)
245273
foreach(_ -> push!(bodyq.args, body), 1:length(loop))
246274
bodyq
247275
else
248276
termcond = gensym(:maybeterm)
249-
push!(body.args, Expr(:(=), termcond, isstatic ? tc : expect(tc)))
277+
push!(body.args, Expr(:(=), termcond, tc))
250278
Expr(:block, Expr(:(=), termcond, true), Expr(:while, termcond, body))
251279
# Expr(:block, Expr(:while, expect(tc), body))
252280
# Expr(:block, assume(tc), Expr(:while, tc, body))

test/gemm.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,19 @@
578578
# end)
579579
# lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
580580

581+
lsAtmulBt8 = :(for m 1:8, n 1:8
582+
ΔCₘₙ = zero(eltype(C))
583+
for k 1:8
584+
ΔCₘₙ += A[k,m] * B[n,k]
585+
end
586+
C[m,n] = ΔCₘₙ
587+
end) |> LoopVectorization.LoopSet;
588+
if LoopVectorization.VectorizationBase.REGISTER_COUNT == 32
589+
@test LoopVectorization.choose_order(lsAtmulBt8) == ([:n, :m, :k], :m, :n, :m, 1, 8)
590+
else
591+
# @test LoopVectorization.choose_order(lsAtmulBt8) == ([:n, :m, :k], :m, :n, :m, 2, 4)
592+
end
593+
581594
struct SizedMatrix{M,N,T} <: DenseMatrix{T}
582595
data::Matrix{T}
583596
end

test/gemv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
# T = Float32
44
@testset "GEMV" begin
55
# Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 6)
6-
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (1, 6) : (1, 10)
6+
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (2, 6) : (2, 10)
77
gemvq = :(for i eachindex(y)
88
yᵢ = 0.0
99
for j eachindex(x)

test/miscellaneous.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ using Test
44

55
@testset "Miscellaneous" begin
66

7-
# Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
8-
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (1, 6) : (1, 10)
7+
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (2, 6) : (2, 10)
98
dot3q = :(for m 1:M, n 1:N
109
s += x[m] * A[m,n] * y[n]
1110
end);

0 commit comments

Comments
 (0)