Skip to content

Commit 9b004cc

Browse files
committed
Tile remainder with nested ifs.
1 parent f98363b commit 9b004cc

File tree

5 files changed

+90
-49
lines changed

5 files changed

+90
-49
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.1.0"
4+
version = "0.1.1"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module LoopVectorization
22

33
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
4-
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr, mask
4+
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr, mask
55
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod
66
using Base.Broadcast: Broadcasted, DefaultArrayStyle
77
using LinearAlgebra: Adjoint

src/determinestrategy.jl

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ function solve_tilesize(X, R)
167167
# X is vector of costs, and R is of register pressures
168168
# @show X
169169
# @show R
170-
RR = VectorizationBase.REGISTER_COUNT - R[3] - R[4]
170+
RR = REGISTER_COUNT - R[3] - R[4]
171171
a = (R[1])^2*X[2] - (R[2])^2*R[1]*X[3]/RR
172172
b = 2*R[1]*R[2]*X[3]
173173
c = -RR*R[1]*X[3]
@@ -180,7 +180,7 @@ function solve_tilesize(X, R)
180180
Uhigh = Ulow + 1 #ceil(Int, Ufloat)
181181
Thigh = Tlow + 1 #ceil(Int, Tfloat)
182182

183-
RR = VectorizationBase.REGISTER_COUNT - R[3] - R[4]
183+
RR = REGISTER_COUNT - R[3] - R[4]
184184
U, T = Ulow, Tlow
185185
tcost = tile_cost(X, Ulow, Tlow)
186186
# @show Ulow*Thigh*R[1] + Ulow*R[2]
@@ -208,10 +208,14 @@ function solve_tilesize(X, R)
208208
min(U,RR), min(T,RR), tcost
209209
end
210210
function solve_tilesize_constU(X, R, U)
211-
floor(Int, (VectorizationBase.REGISTER_COUNT - R[3] - R[4] - U*R[2]) / (U * R[1]))
211+
floor(Int, (REGISTER_COUNT - R[3] - R[4] - U*R[2]) / (U * R[1]))
212212
end
213213
function solve_tilesize_constT(X, R, T)
214-
floor(Int, (VectorizationBase.REGISTER_COUNT - R[3] - R[4]) / (T * R[1] + R[2]))
214+
floor(Int, (REGISTER_COUNT - R[3] - R[4]) / (T * R[1] + R[2]))
215+
end
216+
function solve_tilesize_constT(ls, T)
217+
R = @view ls.reg_pres[:,1]
218+
floor(Int, (REGISTER_COUNT - R[3] - R[4]) / (T * R[1] + R[2]))
215219
end
216220
# Tiling here is about alleviating register pressure for the UxT
217221
function solve_tilesize(X, R, Umax, Tmax)
@@ -233,7 +237,15 @@ function solve_tilesize(X, R, Umax, Tmax)
233237
end
234238
U, T, cost
235239
end
236-
240+
function solve_tilesize(
241+
ls::LoopSet, unrolled::Symbol, tiled::Symbol,
242+
cost_vec::AbstractVector{Float64} = @view(ls.cost_vec[:,1]),
243+
reg_pressure::AbstractVector{Int} = @view(ls.reg_pres[:,1])
244+
)
245+
maxT = isstaticloop(ls, tiled) ? looprangehint(ls, tiled) : REGISTER_COUNT
246+
maxU = isstaticloop(ls, unrolled) ? looprangehint(ls, unrolled) : REGISTER_COUNT
247+
solve_tilesize(cost_vec, reg_pressure, maxT, maxU)
248+
end
237249

238250
# Just tile outer two loops?
239251
# But optimal order within tile must still be determined
@@ -257,8 +269,8 @@ function evaluate_cost_tile(
257269
# cost_mat[3] / ( unrolled)
258270
# cost_mat[4]
259271
# @show order
260-
cost_vec = zeros(Float64, 4)
261-
reg_pressure = zeros(Int, 4)
272+
cost_vec = cost_vec_buf(ls)
273+
reg_pressure = reg_pres_buf(ls)
262274
# @inbounds reg_pressure[2] = 1
263275
# @inbounds reg_pressure[3] = 1
264276
for n ∈ 1:N
@@ -304,22 +316,7 @@ function evaluate_cost_tile(
304316
end
305317
end
306318
end
307-
Tstatic = isstaticloop(ls, tiled)
308-
Ustatic = isstaticloop(ls, unrolled)
309-
# @show order, cost_vec, reg_pressure
310-
if Tstatic
311-
if Ustatic
312-
solve_tilesize(cost_vec, reg_pressure, looprangehint(ls, tiled), looprangehint(ls, unrolled))
313-
else
314-
solve_tilesize(cost_vec, reg_pressure, looprangehint(ls, tiled), typemax(Int))
315-
end
316-
else
317-
if Ustatic
318-
solve_tilesize(cost_vec, reg_pressure, typemax(Int), looprangehint(ls, unrolled))
319-
else
320-
solve_tilesize(cost_vec, reg_pressure)#, typemax(Int), typemax(Int))
321-
end
322-
end
319+
solve_tilesize(ls, unrolled, tiled, cost_vec, reg_pressure)
323320
end
324321

325322

@@ -385,7 +382,7 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
385382
end
386383
function choose_tile(ls::LoopSet)
387384
lo = LoopOrders(ls)
388-
best_order = copy(lo.syms)
385+
best_order = copyto!(ls.loop_order.bestorder, lo.syms)
389386
new_order, state = iterate(lo) # right now, new_order === best_order
390387
U, T, lowest_cost = 0, 0, Inf
391388
while true
@@ -394,6 +391,7 @@ function choose_tile(ls::LoopSet)
394391
lowest_cost = cost_temp
395392
U, T = U_temp, T_temp
396393
copyto!(best_order, new_order)
394+
save_tilecost!(ls)
397395
end
398396
iter = iterate(lo, state)
399397
iter === nothing && return best_order, U, T, lowest_cost

src/graphs.jl

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,12 @@ end
5353
struct LoopOrder <: AbstractArray{Vector{Operation},5}
5454
oporder::Vector{Vector{Operation}}
5555
loopnames::Vector{Symbol}
56+
bestorder::Vector{Symbol}
5657
end
5758
function LoopOrder(N::Int)
58-
LoopOrder( [ Operation[] for i ∈ 1:24N ], Vector{Symbol}(undef, N) )
59+
LoopOrder( [ Operation[] for i ∈ 1:24N ], Vector{Symbol}(undef, N), Vector{Symbol}(undef, N) )
5960
end
60-
LoopOrder() = LoopOrder(Vector{Operation}[],Symbol[])
61+
LoopOrder() = LoopOrder(Vector{Operation}[],Symbol[],Symbol[])
6162
Base.empty!(lo::LoopOrder) = foreach(empty!, lo.oporder)
6263
function Base.resize!(lo::LoopOrder, N::Int)
6364
Nold = length(lo.loopnames)
@@ -66,6 +67,7 @@ function Base.resize!(lo::LoopOrder, N::Int)
6667
lo.oporder[n] = Operation[]
6768
end
6869
resize!(lo.loopnames, N)
70+
resize!(lo.bestorder, N)
6971
lo
7072
end
7173
Base.size(lo::LoopOrder) = (4,2,2,2,length(lo.loopnames))
@@ -85,10 +87,34 @@ struct LoopSet
8587
includedarrays::Vector{Tuple{Symbol,Int}}
8688
syms_aliasing_refs::Vector{Symbol} # O(N) search is faster at small sizes
8789
refs_aliasing_syms::Vector{ArrayReference}
90+
cost_vec::Matrix{Float64}
91+
reg_pres::Matrix{Int}
8892
# sym_to_ref_aliases::Dict{Symbol,ArrayReference}
8993
# ref_to_sym_aliases::Dict{ArrayReference,Symbol}
9094
end
9195

96+
function cost_vec_buf(ls::LoopSet)
97+
cv = @view(ls.cost_vec[:,2])
98+
@inbounds for i ∈ 1:4
99+
cv[i] = 0.0
100+
end
101+
cv
102+
end
103+
function reg_pres_buf(ls::LoopSet)
104+
ps = @view(ls.reg_pres[:,2])
105+
@inbounds for i ∈ 1:4
106+
ps[i] = 0
107+
end
108+
ps
109+
end
110+
function save_tilecost!(ls::LoopSet)
111+
@inbounds for i ∈ 1:4
112+
ls.cost_vec[i,1] = ls.cost_vec[i,2]
113+
ls.reg_pres[i,1] = ls.reg_pres[i,2]
114+
end
115+
end
116+
117+
92118
# function op_to_ref(ls::LoopSet, op::Operation)
93119
# s = op.variable
94120
# id = findfirst(ls.syms_aliasing_regs)
@@ -113,9 +139,9 @@ function LoopSet()
113139
Expr(:block,),
114140
Tuple{Symbol,Int}[],
115141
Symbol[],
116-
ArrayReference[]
117-
# Dict{Symbol,ArrayReference}()
118-
# Dict{ArrayReference,Symbol}()
142+
ArrayReference[],
143+
Matrix{Float64}(undef, 4, 2),
144+
Matrix{Int}(undef, 4, 2)
119145
)
120146
end
121147
num_loops(ls::LoopSet) = length(ls.loops)
@@ -131,8 +157,7 @@ looprangesym(ls::LoopSet, s::Symbol) = ls.loops[s].rangesym
131157
getop(ls::LoopSet, s::Symbol) = ls.opdict[s]
132158
getop(ls::LoopSet, i::Int) = ls.operations[i + 1]
133159

134-
function looprange(ls::LoopSet, s::Symbol, incr::Int = 1, mangledname::Symbol = s)
135-
loop = ls.loops[s]
160+
function looprange(ls::LoopSet, s::Symbol, incr::Int = 1, mangledname::Symbol = s, loop = ls.loops[s])
136161
incr -= 1
137162
if iszero(incr)
138163
Expr(:call, :<, mangledname, loop.hintexact ? loop.rangehint : loop.rangesym)

src/lowering.jl

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ function lower_load_scalar!(
4949
q::Expr, op::Operation, W::Int, unrolled::Symbol, U::Int,
5050
suffix::Union{Nothing,Int}, mask::Union{Nothing,Symbol,Unsigned} = nothing
5151
)
52-
5352
loopdeps = loopdependencies(op)
5453
@assert unrolled loopdeps
5554
var = op.variable
@@ -559,7 +558,6 @@ function lower_unrolled!(
559558
end
560559
Wt = W
561560
Ut = U
562-
Urem = 0
563561
Urepeat = true
564562
while Urepeat
565563
if Uexprtype !== :skip
@@ -639,7 +637,8 @@ function lower_tiled(ls::LoopSet, U::Int, T::Int)
639637
unrolled = order[end-1]
640638
mangledtiled = tiledsym(tiled)
641639
W = VectorizationBase.pick_vector_width(ls, unrolled)
642-
static_tile = isstaticloop(ls, tiled)
640+
tiledloop = ls.loops[tiled]
641+
static_tile = tiledloop.hintexact
643642
static_unroll = isstaticloop(ls, unrolled)
644643
unrolled_iter = looprangehint(ls, unrolled)
645644
unrolled_itersym = looprangesym(ls, unrolled)
@@ -649,32 +648,51 @@ function lower_tiled(ls::LoopSet, U::Int, T::Int)
649648
Trem = Tt = T
650649
nloops = num_loops(ls);
651650
# addtileonly = sum(length, @view(oporder(ls)[:,:,:,:,end])) > 0
652-
Texprtype = (static_tile && tiled_iter < 2T) ? :block : :while
651+
# Texprtype = (static_tile && tiled_iter < 2T) ? :block : :while
652+
firstiter = true
653+
mangledtiled = tiledsym(tiled)
654+
local qifelse::Expr
653655
while Tt > 0
654-
#
655656
tiledloopbody = Expr(:block, )
656-
# else
657-
# Expr(:block, Expr(:(=), unrolled, 0))
658-
# end
659657
lower_unrolled!(tiledloopbody, ls, U, Tt, W, static_unroll, unrolled_iter, unrolled_itersym)
660658
tiledloopbody = lower_nest(ls, nloops, U, Tt, tiledloopbody, 0, W, nothing, :block)
661-
push!(q.args, Texprtype === :block ? tiledloopbody : Expr(Texprtype, looprange(ls, tiled, Tt, tiledsym(tiled)), tiledloopbody))
659+
if firstiter
660+
push!(q.args, (static_tile && tiled_iter < 2T) ? tiledloopbody : Expr(:while, looprange(ls, tiled, Tt, mangledtiled, tiledloop), tiledloopbody))
661+
elseif static_tile
662+
push!(q.args, tiledloopbody)
663+
else # not static, not firstiter
664+
comparison = Expr(:call, :(==), mangledtiled, Expr(:call, :-, tiledloop.rangesym, Tt))
665+
qifelsenew = Expr(:elseif, comparison, tiledloopbody)
666+
push!(qifelse.args, qifelsenew)
667+
qifelse = qifelsenew
668+
end
662669
if static_tile
663-
Tt = if Tt == T
670+
if Tt == T
664671
# push!(tiledloopbody.args, Expr(:+=, mangledtiled, Tt))
665672
Texprtype = :block
666-
looprangehint(ls, tiled) % T
673+
Tt = looprangehint(ls, tiled) % T
674+
# Recalculate U
675+
U = solve_tilesize_constT(ls, Tt)
667676
else
668-
0 # terminate
677+
Tt = 0 # terminate
669678
end
670679
nothing
671680
else
672-
Ttold = Tt
673-
Tt >>>= 1
674-
# Tt == 0 || push!(tiledloopbody.args, Expr(:+=, mangledtiled, Ttold))
675-
Texprtype = 2Tt == Ttold ? :if : :while
681+
if firstiter
682+
comparison = Expr(:call, :(==), mangledtiled, tiledloop.rangesym)
683+
qifelse = Expr(:if, comparison, Expr(:block)) # do nothing
684+
push!(q.args, qifelse)
685+
Tt = 0
686+
end
687+
Tt += 1 # we start counting up by 1
688+
if Tt == T # terminate on Tt = T
689+
Tt = 0
690+
else
691+
U = solve_tilesize_constT(ls, Tt)
692+
end
676693
nothing
677694
end
695+
firstiter = false
678696
end
679697
q = gc_preserve(ls, q)
680698
reduce_expr!(q, ls, U)

0 commit comments

Comments
 (0)