Skip to content

Commit d21d018

Browse files
committed
Debugging progress.
1 parent 3555394 commit d21d018

File tree

6 files changed

+304
-117
lines changed

6 files changed

+304
-117
lines changed

src/costs.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ function vector_cost(instruction::InstructionCost, Wshift, sizeof_T)
3232
W = 1 << Wshift
3333
extra_latency = sl - srt
3434
srt *= W
35-
sl = srt + extra_latency
35+
sl = round(Int, srt + extra_latency)
3636
else # we assume custom cost, and that latency == recip_throughput
37-
sl, srt = scaling, scaling
37+
sl, srt = round(Int,scaling), scaling
3838
end
3939
srt, sl, srp
4040
end
@@ -63,8 +63,8 @@ const OPAQUE_INSTRUCTION = InstructionCost(50, 50.0, -1.0, VectorizationBase.REG
6363
# consolidated into a single register. The number of LICM-ed setindex!, on the other
6464
# hand, should indicate how many registers we're keeping live for the sake of eventually storing.
6565
const COST = Dict{Symbol,InstructionCost}(
66-
:getindex => InstructionCost(3,0.5,-3.0,0),
67-
:setindex! => InstructionCost(3,1.0,-3.0,1),
66+
:getindex => InstructionCost(-3.0,0.5,3,0),
67+
:setindex! => InstructionCost(-3.0,1.0,3,1),
6868
:zero => InstructionCost(1,0.5),
6969
:one => InstructionCost(3,0.5),
7070
:(+) => InstructionCost(4,0.5),
@@ -93,7 +93,8 @@ const COST = Dict{Symbol,InstructionCost}(
9393
:exp => InstructionCost(20,20.0,20.0,18),
9494
:sin => InstructionCost(18,15.0,68.0,23),
9595
:cos => InstructionCost(18,15.0,68.0,26),
96-
:sincos => InstructionCost(25,22.0,70.0,26)
96+
:sincos => InstructionCost(25,22.0,70.0,26)#,
97+
# Symbol("##CONSTANT##") => InstructionCost(0,0.0)
9798
)
9899
for (k, v) ∈ COST # so we can look up Symbol(typeof(function))
99100
COST[Symbol("typeof(", k, ")")] = v

src/determinestrategy.jl

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
# TODO: FIXME for general case
33
unitstride(op, s) = first(loopdependencies(op)) === s
44

5-
function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int)
5+
function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int = op.elementbytes)
6+
isconstant(op) && return 0.0, 0, 0
67
# Wshift == dependson(op, unrolled) ? Wshift : 0
78
# c = first(cost(instruction(op), Wshift, size_T))::Int
89
instr = instruction(op)
910
opisunrolled = dependson(op, unrolled)
1011
srt, sl, srp = opisunrolled ? vector_cost(instr, Wshift, size_T) : scalar_cost(instr)
1112
if accesses_memory(op)
1213
# either vbroadcast/reductionstore, vmov(a/u)pd, or gather/scatter
14+
# @show instr, unrolled, loopdependencies(op), unitstride(op, unrolled)
1315
if opisunrolled
1416
if !unitstride(op, unrolled)# || !isdense(op) # need gather/scatter
1517
r = (1 << Wshift)
@@ -72,7 +74,9 @@ function evaluate_cost_unroll(
7274
included_vars[id] && continue
7375
# it must also be a subset of defined symbols
7476
loopdependencies(op) ⊆ nested_loop_syms || continue
75-
hasintersection(reduceddependencies(op), nested_loop_syms) && return Inf
77+
# hasintersection(reduceddependencies(op), nested_loop_syms) && return Inf
78+
rd = reduceddependencies(op)
79+
hasintersection(rd, nested_loop_syms[1:end-length(rd)]) && return Inf
7680
included_vars[id] = true
7781

7882
total_cost += iter * first(cost(op, unrolled, Wshift, size_T))
@@ -97,7 +101,8 @@ function depchain_cost!(
97101
if accesses_memory(op)
98102
return sl, rt
99103
end
100-
slᵢ, rtᵢ = cost(op, 1 << Wshift, Wshift, unrolled)
104+
# @show instruction(op)
105+
rtᵢ, slᵢ = cost(op, unrolled, Wshift, size_T)
101106
sl + slᵢ, rt + rtᵢ
102107
end
103108

@@ -110,6 +115,7 @@ function determine_unroll_factor(
110115
# The strategy is to use an unroll factor of 1, unless there appears to be loop carried dependencies (ie, num_reductions > 0)
111116
# The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
112117
num_reductions = sum(isreduction, operations(ls))
118+
# @show num_reductions
113119
iszero(num_reductions) && return 1
114120
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
115121
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
@@ -119,7 +125,7 @@ function determine_unroll_factor(
119125
visited_nodes = fill(false, length(operations(ls)))
120126
for op ∈ operations(ls)
121127
if isreduction(op) && dependson(op, unrolled)
122-
sl, rt = depchain_cost!(visited_nodes, instruction(op), unrolled, Wshift, size_T)
128+
sl, rt = depchain_cost!(visited_nodes, op, unrolled, Wshift, size_T)
123129
latency = max(sl, latency)
124130
recip_throughput += rt
125131
end
@@ -131,17 +137,19 @@ function tile_cost(X, U, T)
131137
X[1] + X[4] + X[2] / T + X[3] / U
132138
end
133139
function solve_tilesize(X, R)
140+
first(R) == 0 && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
134141
# We use a Lagrange multiplier to find floating point values for U and T
135142
# first solving for U via quadratic formula
136143
# X is vector of costs, and R is of register pressures
137-
@show X
138-
@show R
144+
# @show X
145+
# @show R
139146
RR = VectorizationBase.REGISTER_COUNT - R[3] - R[4]
140147
a = (R[1])^2*X[2] - (R[2])^2*R[1]*X[3]/RR
141148
b = 2*R[1]*R[2]*X[3]
142149
c = -RR*R[1]*X[3]
143150
Ufloat = (sqrt(b^2 - 4a*c) - b) / (2a)
144151
Tfloat = (RR - Ufloat*R[2])/(Ufloat*R[1])
152+
# @show Ufloat, Tfloat
145153
Ulow = max(1, floor(Int, Ufloat)) # must be at least 1
146154
Tlow = max(1, floor(Int, Tfloat)) # must be at least 1
147155
Uhigh = Ulow + 1 #ceil(Int, Ufloat)
@@ -150,14 +158,17 @@ function solve_tilesize(X, R)
150158
RR = VectorizationBase.REGISTER_COUNT - R[3] - R[4]
151159
U, T = Ulow, Tlow
152160
tcost = tile_cost(X, Ulow, Tlow)
153-
if RR > Ulow*Thigh*R[1] + Ulow*R[2]
161+
# @show Ulow*Thigh*R[1] + Ulow*R[2]
162+
if RR ≥ Ulow*Thigh*R[1] + Ulow*R[2]
154163
tcost_temp = tile_cost(X, Ulow, Thigh)
164+
# @show tcost_temp, tcost
155165
if tcost_temp < tcost
156166
tcost = tcost_temp
157167
U, T = Ulow, Thigh
158168
end
159169
end
160-
if RR > Uhigh*Tlow*R[1] + Uhigh*R[2]
170+
# @show Uhigh*Tlow*R[1] + Uhigh*R[2]
171+
if RR ≥ Uhigh*Tlow*R[1] + Uhigh*R[2]
161172
tcost_temp = tile_cost(X, Uhigh, Tlow)
162173
if tcost_temp < tcost
163174
tcost = tcost_temp
@@ -175,7 +186,9 @@ end
175186
function solve_tilesize_constT(X, R, T)
176187
floor(Int, (VectorizationBase.REGISTER_COUNT - R[3] - R[4]) / (T * R[1] + R[2]))
177188
end
189+
# Tiling here is about alleviating register pressure for the UxT
178190
function solve_tilesize(X, R, Umax, Tmax)
191+
first(R) == 0 && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
179192
U, T, cost = solve_tilesize(X, R)
180193
U_too_large = U > Umax
181194
T_too_large = T > Tmax
@@ -215,9 +228,12 @@ function evaluate_cost_tile(
215228
# cost_mat[1] / ( unrolled * tiled)
216229
# cost_mat[2] / ( tiled)
217230
# cost_mat[3] / ( unrolled)
218-
# cost_mat[4]
231+
# cost_mat[4]
232+
# @show order
219233
cost_vec = zeros(Float64, 4)
220234
reg_pressure = zeros(Int, 4)
235+
@inbounds reg_pressure[2] = 1
236+
@inbounds reg_pressure[3] = 1
221237
for n ∈ 1:N
222238
itersym = order[n]
223239
# Add to set of defined symbols
@@ -235,16 +251,17 @@ function evaluate_cost_tile(
235251
included_vars[id] && continue
236252
# it must also be a subset of defined symbols
237253
loopdependencies(op) ⊆ nested_loop_syms || continue
238-
# @show nested_loop_syms
239-
# @show reduceddependencies(op)
254+
# # @show nested_loop_syms
255+
# # @show reduceddependencies(op)
240256
rd = reduceddependencies(op)
241257
hasintersection(rd, nested_loop_syms[1:end-length(rd)]) && return 0,0,Inf
242258
included_vars[id] = true
243259
rt, lat, rp = cost(op, unrolled, Wshift, size_T)
244-
@show instruction(op), rt, lat, rp, iter
260+
# @show instruction(op), rt, lat, rp, iter
245261
rt *= iter
246262
isunrolled = unrolled ∈ loopdependencies(op)
247263
istiled = tiled ∈ loopdependencies(op)
264+
# @show isunrolled, istiled
248265
if isunrolled && istiled # no cost decrease; cost must be repeated
249266
cost_vec[1] += rt
250267
reg_pressure[1] += rp
@@ -315,8 +332,8 @@ end
315332
# that I could come up with.
316333
function Base.iterate(lo::LoopOrders, state)
317334
advance_state!(state) || return nothing
318-
# @show state
319-
syms = copy!(lo.buff, lo.syms)
335+
# # @show state
336+
syms = copyto!(lo.buff, lo.syms)
320337
for i ∈ eachindex(state)
321338
sᵢ = state[i]
322339
sᵢ == 0 || swap!(syms, i, i + sᵢ)
@@ -340,15 +357,15 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
340357
end
341358
function choose_tile(ls::LoopSet)
342359
lo = LoopOrders(ls)
343-
best_order = lo.syms
360+
best_order = copy(lo.syms)
344361
new_order, state = iterate(lo) # right now, new_order === best_order
345362
U, T, lowest_cost = 0, 0, Inf
346363
while true
347364
U_temp, T_temp, cost_temp = evaluate_cost_tile(ls, new_order)
348365
if cost_temp < lowest_cost
349366
lowest_cost = cost_temp
350367
U, T = U_temp, T_temp
351-
best_order = new_order
368+
copyto!(best_order, new_order)
352369
end
353370
iter = iterate(lo, state)
354371
iter === nothing && return best_order, U, T, lowest_cost
@@ -363,8 +380,10 @@ function choose_order(ls::LoopSet)
363380
end
364381
uorder, uc = choose_unroll_order(ls, tc)
365382
if num_loops(ls) <= 1 || tc > uc # if tc == uc, then that probably means we want tc, and no unrolled managed to beat the tiled cost
383+
# copyto!(ls.loop_order.loopnames, uorder)
366384
return uorder, determine_unroll_factor(ls, uorder), -1
367385
else
386+
# copyto!(ls.loop_order.loopnames, torder)
368387
return torder, tU, tT
369388
end
370389
end

src/graphs.jl

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ LoopOrder() = LoopOrder(Vector{Operation}[],Symbol[])
6060
Base.empty!(lo::LoopOrder) = foreach(empty!, lo.oporder)
6161
function Base.resize!(lo::LoopOrder, N::Int)
6262
Nold = length(lo.loopnames)
63-
resize!(lo.oporder, 24N)
64-
for n ∈ 24Nold+1:24N
63+
resize!(lo.oporder, 32N)
64+
for n ∈ 32Nold+1:32N
6565
lo.oporder[n] = Operation[]
6666
end
6767
resize!(lo.loopnames, N)
6868
lo
6969
end
70-
Base.size(lo::LoopOrder) = (3,2,2,2,length(lo.loopnames))
70+
Base.size(lo::LoopOrder) = (4,2,2,2,length(lo.loopnames))
7171
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i::Int) = lo.oporder[i]
7272
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i...) = lo.oporder[LinearIndices(size(lo))[i...]]
7373

@@ -80,6 +80,7 @@ struct LoopSet
8080
loop_order::LoopOrder
8181
# stridesets::Dict{ShortVector{Symbol},ShortVector{Symbol}}
8282
preamble::Expr # TODO: add preamble to lowering
83+
includedarrays::Vector{Symbol}
8384
end
8485
function LoopSet()
8586
LoopSet(
@@ -88,13 +89,14 @@ function LoopSet()
8889
Operation[],
8990
Int[],
9091
LoopOrder(),
91-
Expr(:block,)
92+
Expr(:block,),
93+
Symbol[]
9294
)
9395
end
9496
num_loops(ls::LoopSet) = length(ls.loops)
9597
function oporder(ls::LoopSet)
9698
N = length(ls.loop_order.loopnames)
97-
reshape(ls.loop_order.oporder, (3,2,2,2,N))
99+
reshape(ls.loop_order.oporder, (4,2,2,2,N))
98100
end
99101
names(ls::LoopSet) = ls.loop_order.loopnames
100102
isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
@@ -163,7 +165,7 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
163165
Loop(itersym, N)
164166
end
165167
elseif f === :eachindex
166-
N = gensym(:loop, itersym)
168+
N = gensym(Symbol(:loop, itersym))
167169
push!(ls.preamble.args, Expr(:(=), N, Expr(:call, :length, r.args[2])))
168170
Loop(itersym, N)
169171
else
@@ -191,11 +193,19 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
191193
Base.push!(ls, q, elementbytes)
192194
end
193195
end
196+
function add_vptr!(ls::LoopSet, indexed::Symbol)
197+
if indexed ∉ ls.includedarrays
198+
push!(ls.includedarrays, indexed)
199+
push!(ls.preamble.args, Expr(:(=), Symbol(:vptr_, indexed), Expr(:call, Expr(:(.), :VectorizationBase, QuoteNode(:stridedpointer)), indexed)))
200+
end
201+
nothing
202+
end
194203

195204
function add_load!(
196205
ls::LoopSet, var::Symbol, indexed::Symbol, indices::AbstractVector, elementbytes::Int = 8
197206
)
198207
op = Operation( length(operations(ls)), var, elementbytes, :getindex, memload, indices, [indexed], NOPARENTS )
208+
add_vptr!(ls, indexed)
199209
pushop!(ls, op, var)
200210
end
201211
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
@@ -226,12 +236,26 @@ function setdiffv!(s3::AbstractVector{T}, s1::AbstractVector{T}, s2::AbstractVec
226236
(s ∈ s2) || (s ∉ s3 && push!(s3, s))
227237
end
228238
end
229-
function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int = 8, deps = NODEPENDENCY)
230-
pushop!(ls, Operation(length(operations(ls)), var, elementbytes, :undef, constant, deps, NODEPENDENCY, NOPARENTS), var)
239+
# This version has no dependencies, and thus will not be lowered
240+
### if it is a literal, that literal is either var"##ZERO#Float##", var"##ONE#Float##", or has to have been assigned to var in the preamble.
241+
# if it is a literal, that literal has to have been assigned to var in the preamble.
242+
function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
243+
pushop!(ls, Operation(length(operations(ls)), var, elementbytes, Symbol("##CONSTANT##"), constant, NODEPENDENCY, NODEPENDENCY, NOPARENTS), var)
231244
end
232-
function add_constant!(ls, var, elementbytes::Int = 8, sym = gensym(:constant), deps = NODEPENDENCY)
245+
function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
246+
sym = gensym(:temp)
233247
push!(ls.preamble.args, Expr(:(=), sym, var))
234-
add_constant!(ls, sym, elementbytes, deps)
248+
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, Symbol("##CONSTANT##"), constant, NODEPENDENCY, NODEPENDENCY, NOPARENTS), sym)
249+
end
250+
# This version has loop dependencies. var gets assigned to sym when lowering.
251+
function add_constant!(ls::LoopSet, var::Symbol, deps::Vector{Symbol}, sym::Symbol = gensym(:constant), elementbytes::Int = 8)
252+
# length(deps) == 0 && push!(ls.preamble.args, Expr(:(=), sym, var))
253+
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, var, constant, deps, NODEPENDENCY, NOPARENTS), sym)
254+
end
255+
function add_constant!(ls::LoopSet, var, deps::Vector{Symbol}, sym::Symbol = gensym(:constant), elementbytes::Int = 8)
256+
sym2 = gensym(:temp)
257+
push!(ls.preamble.args, Expr(:(=), sym2, var))
258+
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, sym2, constant, deps, NODEPENDENCY, NOPARENTS), sym)
235259
end
236260
function pushparent!(parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, parent::Operation)
237261
push!(parents, parent)
@@ -258,7 +282,7 @@ end
258282
function add_reduction!(
259283
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var::Symbol, elementbytes::Int = 8
260284
)
261-
parent = get!(ls.opdict, var) do
285+
get!(ls.opdict, var) do
262286
p = add_constant!(ls, var, elementbytes)
263287
push!(ls.outer_reductions, identifier(p))
264288
p
@@ -287,6 +311,7 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
287311
parent = getop(ls, var)
288312
setdiffv!(reduceddeps, deps, loopdependencies(parent))
289313
pushparent!(parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
314+
# append!(reduceddependencies(parent), reduceddeps)
290315
end
291316
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
292317
pushop!(ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
@@ -296,6 +321,7 @@ function add_store!(
296321
)
297322
parent = getop(ls, var)
298323
op = Operation( length(operations(ls)), indexed, elementbytes, :setindex!, memstore, indices, reduceddependencies(parent), [parent] )
324+
add_vptr!(ls, indexed)
299325
pushop!(ls, op, var)
300326
end
301327
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
@@ -335,7 +361,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
335361
if RHS isa Expr
336362
add_operation!(ls, LHS, RHS, elementbytes)
337363
else
338-
add_constant!(ls, RHS, elementbytes, LHS, [keys(ls.loops)...])
364+
add_constant!(ls, RHS, [keys(ls.loops)...], LHS, elementbytes)
339365
end
340366
elseif LHS isa Expr
341367
@assert LHS.head === :ref
@@ -363,7 +389,9 @@ end
363389
function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
364390
lo = ls.loop_order
365391
ro = lo.loopnames # reverse order; will have same order as lo
366-
copyto!(lo.names, order)
392+
# @show 1, ro, order
393+
# copyto!(ro, order)
394+
# @show 2, ro, order
367395
empty!(lo)
368396
nloops = length(order)
369397
if loopistiled
@@ -378,17 +406,19 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
378406
for _n ∈ 1:nloops
379407
n = 1 + nloops - _n
380408
ro[_n] = loopsym = order[n]
409+
#loopsym = order[n]
381410
for (id,op) ∈ enumerate(operations(ls))
382411
included_vars[id] && continue
383-
loopsym ∈ dependencies(op) || continue
412+
loopsym ∈ loopdependencies(op) || continue
384413
included_vars[id] = true
385414
isunrolled = (unrolled ∈ loopdependencies(op)) + 1
386-
istiled = (loopistiled ? false : (tiled ∈ loopdependencies(op))) + 1
387-
optype = Int(op.node_type)
415+
istiled = (loopistiled ? (tiled ∈ loopdependencies(op)) : false) + 1
416+
optype = Int(op.node_type) + 1
388417
after_loop = (length(reduceddependencies(op)) > 0) + 1
389418
push!(lo[optype,isunrolled,istiled,after_loop,_n], op)
390419
end
391420
end
421+
@show 3, ro, order
392422
end
393423

394424

0 commit comments

Comments
 (0)