Skip to content

Commit 1433577

Browse files
committed
Working on adding basic if/else support, as well as LoopSet -> Type -> LoopSet conversion.
1 parent 4f061f8 commit 1433577

File tree

5 files changed

+111
-26
lines changed

5 files changed

+111
-26
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.3.8"
4+
version = "0.3.9"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/condense_loopset.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
2+
## turn a LoopSet into a type object which can be used to reconstruct the LoopSet.
3+
4+
5+
# Try to condense in type stable manner
6+
function condense_operations(ls::LoopSet)
7+
8+
end
9+

src/costs.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ const COST = Dict{Instruction,InstructionCost}(
103103
Instruction(:(<)) => InstructionCost(1, 0.5),
104104
Instruction(:(>=)) => InstructionCost(1, 0.5),
105105
Instruction(:(<=)) => InstructionCost(1, 0.5),
106+
Instruction(:ifelse) => InstructionCost(1, 0.5),
107+
Instruction(:vifelse) => InstructionCost(1, 0.5),
106108
Instruction(:inv) => InstructionCost(13,4.0,-2.0,1),
107109
Instruction(:vinv) => InstructionCost(13,4.0,-2.0,1),
108110
Instruction(:muladd) => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
@@ -131,7 +133,7 @@ const COST = Dict{Instruction,InstructionCost}(
131133
Instruction(:sincos_fast) => InstructionCost(25,22.0,70.0,26),
132134
Instruction(:identity) => InstructionCost(0,0.0,0.0,0),
133135
Instruction(:adjoint) => InstructionCost(0,0.0,0.0,0),
134-
Instruction(:transpose) => InstructionCost(0,0.0,0.0,0)
136+
Instruction(:transpose) => InstructionCost(0,0.0,0.0,0),
135137
# Symbol("##CONSTANT##") => InstructionCost(0,0.0)
136138
)
137139
# for (k, v) ∈ COST # so we can look up Symbol(typeof(function))

src/determinestrategy.jl

Lines changed: 69 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
# wrong for transposed matrices, and certain views/SubArrays.
44
unitstride(op::Operation, s) = first(getindices(op)) === s
55

6+
function register_pressure(op::Operation)
7+
if isconstant(op)
8+
0
9+
else
10+
instruction_cost(instruction(op)).register_pressure
11+
end
12+
end
613
function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int = op.elementbytes)
714
isconstant(op) && return 0.0, 0, 1
815
# Wshift == dependson(op, unrolled) ? Wshift : 0
@@ -45,6 +52,10 @@ function hasintersection(a, b)
4552
end
4653
false
4754
end
55+
function num_iterations(N, step)
56+
iter, rem = divrem(N, step)
57+
iter + (rem != 0)
58+
end
4859

4960
# evaluates cost of evaluating loop in given order
5061
# heuristically, could simplify analysis by just unrolling outer loop?
@@ -62,10 +73,8 @@ function evaluate_cost_unroll(
6273
for itersym order
6374
# Add to set of defined symbles
6475
push!(nested_loop_syms, itersym)
65-
liter = Float64(length(ls, itersym))
66-
if itersym === vectorized
67-
liter /= W
68-
end
76+
looplength = length(ls, itersym)
77+
liter = itersym === vectorized ? num_iterations(looplength, W) : looplength
6978
iter *= liter
7079
# check which vars we can define at this level of loop nest
7180
for (id,op) enumerate(operations(ls))
@@ -183,16 +192,16 @@ function determine_unroll_factor(
183192
roundpow2(max(1, round(Int, latency / (recip_throughput * num_reductions) ) ))
184193
end
185194

186-
function tile_cost(X, U, T)
187-
X[1] + X[4] + X[2] / T + X[3] / U
195+
function tile_cost(X, U, T, UL, TL)
196+
X[1] + X[4] + X[2] * (num_iterations(TL, T)/TL) + X[3] * (num_iterations(UL, U)/UL)
188197
end
189-
function solve_tilesize(X, R)
198+
function solve_tilesize(X, R, UL, TL)
190199
@inbounds any(iszero, (R[1],R[2],R[3])) && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
191200
# @inbounds any(iszero, (R[1],R[2],R[3])) && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
192201
# We use a lagrange multiplier to find floating point values for U and T
193202
# first solving for U via quadratic formula
194203
# X is vector of costs, and R is of register pressures
195-
RR = REGISTER_COUNT - R[3] - R[4]
204+
RR = REGISTER_COUNT - R[3] - R[4] # RR ≡ RemainingRegisters
196205
a = (R[1])^2*X[2] - (R[2])^2*R[1]*X[3]/RR
197206
b = 2*R[1]*R[2]*X[3]
198207
c = -RR*R[1]*X[3]
@@ -205,12 +214,12 @@ function solve_tilesize(X, R)
205214
Uhigh = Ulow + 1 #ceil(Int, Ufloat)
206215
Thigh = Tlow + 1 #ceil(Int, Tfloat)
207216

208-
RR = REGISTER_COUNT - R[3] - R[4]
217+
# RR = REGISTER_COUNT - R[3] - R[4]
209218
U, T = Ulow, Tlow
210-
tcost = tile_cost(X, Ulow, Tlow)
219+
tcost = tile_cost(X, Ulow, Tlow, UL, TL)
211220
# @show Ulow*Thigh*R[1] + Ulow*R[2]
212221
if RR Ulow*Thigh*R[1] + Ulow*R[2]
213-
tcost_temp = tile_cost(X, Ulow, Thigh)
222+
tcost_temp = tile_cost(X, Ulow, Thigh, UL, TL)
214223
# @show tcost_temp, tcost
215224
if tcost_temp < tcost
216225
tcost = tcost_temp
@@ -222,7 +231,7 @@ function solve_tilesize(X, R)
222231
while RR < Uhigh*Tl*R[1] + Uhigh*R[2]
223232
Tl -= 1
224233
end
225-
tcost_temp = tile_cost(X, Uhigh, Tl)
234+
tcost_temp = tile_cost(X, Uhigh, Tl, UL, TL)
226235
if tcost_temp < tcost
227236
tcost = tcost_temp
228237
U, T = Uhigh, Tl
@@ -243,9 +252,9 @@ function solve_tilesize_constT(ls, T)
243252
floor(Int, (REGISTER_COUNT - R[3] - R[4]) / (T * R[1] + R[2]))
244253
end
245254
# Tiling here is about alleviating register pressure for the UxT
246-
function solve_tilesize(X, R, Umax, Tmax)
255+
function solve_tilesize(X, R, Umax, Tmax, UL, TL)
247256
first(R) == 0 && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
248-
U, T, cost = solve_tilesize(X, R)
257+
U, T, cost = solve_tilesize(X, R, UL, TL)
249258
# T -= T & 1
250259
# U = min(U, T)
251260
U_too_large = U > Umax
@@ -264,20 +273,37 @@ function solve_tilesize(X, R, Umax, Tmax)
264273
end
265274
U, T, cost
266275
end
276+
function maybedemotesize(U::Int, N::Int)
277+
U > 1 || return 1
278+
Um1 = U - 1
279+
urep = num_iterations(N, U)
280+
um1rep = num_iterations(N, Um1)
281+
um1rep > urep ? U : Um1
282+
end
267283
function solve_tilesize(
268284
ls::LoopSet, unrolled::Symbol, tiled::Symbol,
269285
cost_vec::AbstractVector{Float64} = @view(ls.cost_vec[:,1]),
270286
reg_pressure::AbstractVector{Int} = @view(ls.reg_pres[:,1])
271287
)
272288
maxT = 4#8
273289
maxU = 4#8
274-
if isstaticloop(ls, tiled)
275-
maxT = min(2maxT, looprangehint(ls, tiled))
290+
tiledloop = ls.loops[tiled]
291+
unrolledloop = ls.loops[unrolled]
292+
if isstaticloop(tiledloop)
293+
maxT = min(4maxT, length(tiledloop))
294+
end
295+
if isstaticloop(unrolledloop)
296+
maxU = min(4maxU, length(unrolledloop))
297+
end
298+
U, T, cost = solve_tilesize(cost_vec, reg_pressure, maxU, maxT, length(unrolledloop), length(tiledloop))
299+
# heuristic to more evenly divide small numbers of iterations
300+
if isstaticloop(tiledloop) & T > 1
301+
T = maybedemotesize(T, length(tiledloop))
276302
end
277-
if isstaticloop(ls, unrolled)
278-
maxU = min(2maxU, looprangehint(ls, unrolled))
303+
if isstaticloop(unrolledloop)
304+
U = maybedemotesize(U, length(unrolledloop))
279305
end
280-
solve_tilesize(cost_vec, reg_pressure, maxU, maxT)
306+
U, T, cost
281307
end
282308

283309
function set_upstream_family!(adal::Vector{T}, op::Operation, val::T) where {T}
@@ -306,7 +332,6 @@ function evaluate_cost_tile(
306332
innerloop = last(order)
307333
iters = fill(-99.9, nops)
308334
nested_loop_syms = Symbol[]# Set{Symbol}()
309-
iter = 1.0
310335
# Need to check if fusion is possible
311336
size_T = biggest_type_size(ls)
312337
W, Wshift = VectorizationBase.pick_vector_width_shift(length(ls, vectorized), size_T)::Tuple{Int,Int}
@@ -320,14 +345,18 @@ function evaluate_cost_tile(
320345
reg_pressure = reg_pres_buf(ls)
321346
# @inbounds reg_pressure[2] = 1
322347
# @inbounds reg_pressure[3] = 1
348+
unrollediter = length(ls, unrolled)
349+
tilediter = length(ls, tiled)
350+
unrollediter = unrolled === vectorized ? num_iterations(unrollediter, W) : unrollediter # tiled cannot be vectorized, so do not check
351+
iter::Int = tilediter * unrollediter
323352
for n 1:N
324353
itersym = order[n]
325354
# Add to set of defined symbles
326355
push!(nested_loop_syms, itersym)
327-
if n == 1
328-
iter = length(ls, itersym) * length(ls, order[2]) / W
329-
elseif n > 2
330-
iter *= Float64(length(ls, itersym))
356+
stepsize = 1
357+
if n > 2
358+
itersymlooplen = length(ls, itersym)
359+
iter *= itersym === vectorized ? num_iterations(itersymlooplen, W) : itersymlooplen
331360
end
332361
# check which vars we can define at this level of loop nest
333362
for (id, op) enumerate(ops)
@@ -480,3 +509,19 @@ function choose_order(ls::LoopSet)
480509
end
481510
end
482511

512+
function register_pressure(ls::LoopSet)
513+
# uses unroll of 1 if not tiling
514+
if num_loops(ls) > 1
515+
torder, tvec, tU, tT, tc = choose_tile(ls)
516+
else
517+
tc = Inf
518+
end
519+
uorder, uvec, uc = choose_unroll_order(ls, tc)
520+
if num_loops(ls) > 1 && tc uc # tile
521+
rp = @view ls.reg_pressure[:,1]
522+
tU * tT * rp[1] + tU * rp[2] + rp[3] + rp[4]
523+
else
524+
sum(register_pressure, operations(ls))
525+
end
526+
end
527+

src/graphs.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,15 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
324324
N = gensym(Symbol(:loop, itersym))
325325
pushpreamble!(ls, Expr(:(=), N, Expr(:call, :length, r.args[2])))
326326
Loop(itersym, 0, N)
327+
elseif f === :OneTo || f === Expr(:(.), :Base, :OneTo)
328+
otN = r.args[2]
329+
if otN isa Integer
330+
Loop(itersym, 0, otN)
331+
else
332+
N = gensym(Symbol(:loop, itersym))
333+
pushpreamble!(ls, Expr(:(=), N, otN))
334+
Loop(itersym, 0, N)
335+
end
327336
else
328337
throw("Unrecognized loop range type: $r.")
329338
end
@@ -719,6 +728,18 @@ function add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
719728
array, raw_indices = ref_from_setindex(ex)
720729
add_store!(ls, (ex.args[2])::Symbol, array, rawindices, elementbytes)
721730
end
731+
function add_if!(ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int = 8, mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing)
732+
condition = first(RHS.args)
733+
m = gensym(:mask)
734+
condop = add_compute!(ls, m, condition, elementbytes, mpref)
735+
iftrue = RHS.args[2]
736+
iftrueisaexpr = iftrue isa Expr
737+
iffalse = RHS.args[3]
738+
iffalseisaexpr = iffalse isa Expr
739+
trueisablock = iftrueisaexpr && iftrue.head !== :call
740+
falseisablock = iffalseisaexpr && iffalse.head !== :call
741+
742+
end
722743
# add operation assigns X to var
723744
function add_operation!(
724745
ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int = 8
@@ -736,6 +757,8 @@ function add_operation!(
736757
else
737758
add_compute!(ls, LHS, RHS, elementbytes)
738759
end
760+
elseif RHS.head === :if
761+
add_if!(ls, LHS, RHS, elementbytes)
739762
else
740763
throw("Expression not recognized:\n$x")
741764
end
@@ -757,6 +780,8 @@ function add_operation!(
757780
else
758781
add_compute!(ls, LHS_sym, RHS, elementbytes, LHS_ref)
759782
end
783+
elseif RHS.head === :if
784+
add_if!(ls, LHS, RHS, elementbytes, LHS_ref)
760785
else
761786
throw("Expression not recognized:\n$x")
762787
end
@@ -817,6 +842,10 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
817842
add_block!(ls, ex)
818843
elseif ex.head === :for
819844
add_loop!(ls, ex)
845+
elseif ex.head === :&&
846+
add_andblock!(ls, ex)
847+
elseif ex.head === :||
848+
add_orblock!(ls, ex)
820849
else
821850
throw("Don't know how to handle expression:\n$ex")
822851
end

0 commit comments

Comments
 (0)