3
3
# wrong for transposed matrices, and certain views/SubArrays.
4
4
"""
    unitstride(op::Operation, s)

Return `true` when the first entry of `getindices(op)` is identical (`===`) to `s`.
"""
function unitstride(op::Operation, s)
    leading_index = first(getindices(op))
    return leading_index === s
end
5
5
6
"""
    register_pressure(op::Operation)

Return the register pressure attributed to `op`: `0` when `isconstant(op)` holds,
otherwise the `register_pressure` field of `instruction_cost(instruction(op))`.
"""
function register_pressure(op::Operation)
    # Constant operations contribute nothing.
    isconstant(op) && return 0
    return instruction_cost(instruction(op)).register_pressure
end
6
13
function cost (op:: Operation , unrolled:: Symbol , Wshift:: Int , size_T:: Int = op. elementbytes)
7
14
isconstant (op) && return 0.0 , 0 , 1
8
15
# Wshift == dependson(op, unrolled) ? Wshift : 0
@@ -45,6 +52,10 @@ function hasintersection(a, b)
45
52
end
46
53
false
47
54
end
55
"""
    num_iterations(N, step)

Return the truncated quotient of `N` by `step`, plus one when the division leaves a
nonzero remainder. For nonnegative `N` and positive `step` this is the number of
`step`-sized chunks needed to cover `N` items.
"""
function num_iterations(N, step)
    quotient, remainder = divrem(N, step)
    # A nonzero remainder means one extra (partial) iteration.
    return iszero(remainder) ? quotient : quotient + one(quotient)
end
48
59
49
60
# evaluates cost of evaluating loop in given order
50
61
# heuristically, could simplify analysis by just unrolling outer loop?
@@ -62,10 +73,8 @@ function evaluate_cost_unroll(
62
73
for itersym ∈ order
63
74
# Add to set of defined symbols
64
75
push! (nested_loop_syms, itersym)
65
- liter = Float64 (length (ls, itersym))
66
- if itersym === vectorized
67
- liter /= W
68
- end
76
+ looplength = length (ls, itersym)
77
+ liter = itersym === vectorized ? num_iterations (looplength, W) : looplength
69
78
iter *= liter
70
79
# check which vars we can define at this level of loop nest
71
80
for (id,op) ∈ enumerate (operations (ls))
@@ -183,16 +192,16 @@ function determine_unroll_factor(
183
192
roundpow2 (max (1 , round (Int, latency / (recip_throughput * num_reductions) ) ))
184
193
end
185
194
186
- function tile_cost (X, U, T)
187
- X[1 ] + X[4 ] + X[2 ] / T + X[3 ] / U
195
"""
    tile_cost(X, U, T, UL, TL)

Evaluate the tiling cost for factors `U` and `T` given the cost vector `X` and the
loop lengths `UL` and `TL`.

`X[1]` and `X[4]` are counted in full. `X[2]` is scaled by the fraction
`num_iterations(TL, T) / TL`, and `X[3]` by `num_iterations(UL, U) / UL`, so a factor
that does not evenly divide its loop length is charged for the extra partial iteration.
"""
function tile_cost(X, U, T, UL, TL)
    tiled_fraction = num_iterations(TL, T) / TL
    unrolled_fraction = num_iterations(UL, U) / UL
    return X[1] + X[4] + X[2] * tiled_fraction + X[3] * unrolled_fraction
end
189
- function solve_tilesize (X, R)
198
+ function solve_tilesize (X, R, UL, TL )
190
199
@inbounds any (iszero, (R[1 ],R[2 ],R[3 ])) && return - 1 ,- 1 ,Inf # solve_smalltilesize(X, R, Umax, Tmax)
191
200
# @inbounds any(iszero, (R[1],R[2],R[3])) && return -1,-1,Inf #solve_smalltilesize(X, R, Umax, Tmax)
192
201
# We use a lagrange multiplier to find floating point values for U and T
193
202
# first solving for U via quadratic formula
194
203
# X is vector of costs, and R is of register pressures
195
- RR = REGISTER_COUNT - R[3 ] - R[4 ]
204
+ RR = REGISTER_COUNT - R[3 ] - R[4 ] # RR ≡ RemainingRegisters
196
205
a = (R[1 ])^ 2 * X[2 ] - (R[2 ])^ 2 * R[1 ]* X[3 ]/ RR
197
206
b = 2 * R[1 ]* R[2 ]* X[3 ]
198
207
c = - RR* R[1 ]* X[3 ]
@@ -205,12 +214,12 @@ function solve_tilesize(X, R)
205
214
Uhigh = Ulow + 1 # ceil(Int, Ufloat)
206
215
Thigh = Tlow + 1 # ceil(Int, Tfloat)
207
216
208
- RR = REGISTER_COUNT - R[3 ] - R[4 ]
217
+ # RR = REGISTER_COUNT - R[3] - R[4]
209
218
U, T = Ulow, Tlow
210
- tcost = tile_cost (X, Ulow, Tlow)
219
+ tcost = tile_cost (X, Ulow, Tlow, UL, TL )
211
220
# @show Ulow*Thigh*R[1] + Ulow*R[2]
212
221
if RR ≥ Ulow* Thigh* R[1 ] + Ulow* R[2 ]
213
- tcost_temp = tile_cost (X, Ulow, Thigh)
222
+ tcost_temp = tile_cost (X, Ulow, Thigh, UL, TL )
214
223
# @show tcost_temp, tcost
215
224
if tcost_temp < tcost
216
225
tcost = tcost_temp
@@ -222,7 +231,7 @@ function solve_tilesize(X, R)
222
231
while RR < Uhigh* Tl* R[1 ] + Uhigh* R[2 ]
223
232
Tl -= 1
224
233
end
225
- tcost_temp = tile_cost (X, Uhigh, Tl)
234
+ tcost_temp = tile_cost (X, Uhigh, Tl, UL, TL )
226
235
if tcost_temp < tcost
227
236
tcost = tcost_temp
228
237
U, T = Uhigh, Tl
@@ -243,9 +252,9 @@ function solve_tilesize_constT(ls, T)
243
252
floor (Int, (REGISTER_COUNT - R[3 ] - R[4 ]) / (T * R[1 ] + R[2 ]))
244
253
end
245
254
# Tiling here is about alleviating register pressure for the UxT
246
- function solve_tilesize (X, R, Umax, Tmax)
255
+ function solve_tilesize (X, R, Umax, Tmax, UL, TL )
247
256
first (R) == 0 && return - 1 ,- 1 ,Inf # solve_smalltilesize(X, R, Umax, Tmax)
248
- U, T, cost = solve_tilesize (X, R)
257
+ U, T, cost = solve_tilesize (X, R, UL, TL )
249
258
# T -= T & 1
250
259
# U = min(U, T)
251
260
U_too_large = U > Umax
@@ -264,20 +273,37 @@ function solve_tilesize(X, R, Umax, Tmax)
264
273
end
265
274
U, T, cost
266
275
end
276
"""
    maybedemotesize(U::Int, N::Int)

Return `U - 1` when stepping through `N` items by `U - 1` needs no more iterations than
stepping by `U`; otherwise return `U`. Any `U` not greater than 1 yields `1`.
"""
function maybedemotesize(U::Int, N::Int)
    U > 1 || return 1
    smaller = U - 1
    # Keep U only if the smaller factor would strictly increase the iteration count.
    if num_iterations(N, smaller) > num_iterations(N, U)
        return U
    else
        return smaller
    end
end
267
283
"""
    solve_tilesize(ls::LoopSet, unrolled::Symbol, tiled::Symbol,
                   cost_vec = @view(ls.cost_vec[:, 1]),
                   reg_pressure = @view(ls.reg_pres[:, 1]))

Choose factors `U` (for the `unrolled` loop) and `T` (for the `tiled` loop) and return
`(U, T, cost)`.

Each factor is capped at 4 by default. When the corresponding loop is static
(`isstaticloop`), the cap becomes `min(16, length(loop))`. The caps and loop lengths are
forwarded to the `(X, R, Umax, Tmax, UL, TL)` method of `solve_tilesize`. For static
loops the resulting factors are then passed through `maybedemotesize`.
"""
function solve_tilesize(
    ls::LoopSet, unrolled::Symbol, tiled::Symbol,
    cost_vec::AbstractVector{Float64} = @view(ls.cost_vec[:, 1]),
    reg_pressure::AbstractVector{Int} = @view(ls.reg_pres[:, 1])
)
    maxT = 4 # 8
    maxU = 4 # 8
    tiledloop = ls.loops[tiled]
    unrolledloop = ls.loops[unrolled]
    if isstaticloop(tiledloop)
        maxT = min(4maxT, length(tiledloop))
    end
    if isstaticloop(unrolledloop)
        maxU = min(4maxU, length(unrolledloop))
    end
    U, T, cost = solve_tilesize(
        cost_vec, reg_pressure, maxU, maxT, length(unrolledloop), length(tiledloop)
    )
    # Heuristic to more evenly divide small numbers of iterations.
    # `&&` is required here: `a & T > 1` parses as `(a & T) > 1`, which is never true
    # for a Bool `a`, so the tiled factor would never be demoted.
    if isstaticloop(tiledloop) && T > 1
        T = maybedemotesize(T, length(tiledloop))
    end
    # NOTE(review): maybedemotesize maps any U ≤ 1 (including a -1 failure value) to 1;
    # confirm callers rely only on `cost` to detect failure.
    if isstaticloop(unrolledloop)
        U = maybedemotesize(U, length(unrolledloop))
    end
    return U, T, cost
end
282
308
283
309
function set_upstream_family! (adal:: Vector{T} , op:: Operation , val:: T ) where {T}
@@ -306,7 +332,6 @@ function evaluate_cost_tile(
306
332
innerloop = last (order)
307
333
iters = fill (- 99.9 , nops)
308
334
nested_loop_syms = Symbol[]# Set{Symbol}()
309
- iter = 1.0
310
335
# Need to check if fusion is possible
311
336
size_T = biggest_type_size (ls)
312
337
W, Wshift = VectorizationBase. pick_vector_width_shift (length (ls, vectorized), size_T):: Tuple{Int,Int}
@@ -320,14 +345,18 @@ function evaluate_cost_tile(
320
345
reg_pressure = reg_pres_buf (ls)
321
346
# @inbounds reg_pressure[2] = 1
322
347
# @inbounds reg_pressure[3] = 1
348
+ unrollediter = length (ls, unrolled)
349
+ tilediter = length (ls, tiled)
350
+ unrollediter = unrolled === vectorized ? num_iterations (unrollediter, W) : unrollediter # tiled cannot be vectorized, so do not check
351
+ iter:: Int = tilediter * unrollediter
323
352
for n ∈ 1 : N
324
353
itersym = order[n]
325
354
# Add to set of defined symbols
326
355
push! (nested_loop_syms, itersym)
327
- if n = = 1
328
- iter = length (ls, itersym) * length (ls, order[ 2 ]) / W
329
- elseif n > 2
330
- iter *= Float64 ( length (ls, itersym))
356
+ stepsize = 1
357
+ if n > 2
358
+ itersymlooplen = length (ls, itersym)
359
+ iter *= itersym === vectorized ? num_iterations (itersymlooplen, W) : itersymlooplen
331
360
end
332
361
# check which vars we can define at this level of loop nest
333
362
for (id, op) ∈ enumerate (ops)
@@ -480,3 +509,19 @@ function choose_order(ls::LoopSet)
480
509
end
481
510
end
482
511
512
"""
    register_pressure(ls::LoopSet)

Estimate the register pressure of `ls` under the strategy the cost model selects.

With more than one loop, `choose_tile` is consulted; its cost is passed to
`choose_unroll_order`. If tiling is at least as cheap as unrolling, the estimate is
`tU * tT * rp[1] + tU * rp[2] + rp[3] + rp[4]`, where `rp` is the first column of
`ls.reg_pres`. Otherwise the estimate is the sum of `register_pressure(op)` over
`operations(ls)`.
"""
function register_pressure(ls::LoopSet)
    # uses unroll of 1 if not tiling
    if num_loops(ls) > 1
        torder, tvec, tU, tT, tc = choose_tile(ls)
    else
        tc = Inf
    end
    uorder, uvec, uc = choose_unroll_order(ls, tc)
    if num_loops(ls) > 1 && tc ≤ uc # tile
        # The register-pressure buffer is the `reg_pres` field, the same one used as
        # the default `reg_pressure` argument of `solve_tilesize(ls, ...)`.
        rp = @view ls.reg_pres[:, 1]
        return tU * tT * rp[1] + tU * rp[2] + rp[3] + rp[4]
    else
        return sum(register_pressure, operations(ls))
    end
end
527
+
0 commit comments