1
1
2
+ # TODO : FIXME for general case
3
+ unitstride (op, s) = first (loopdependencies (op)) === s
4
+
2
5
function cost (op:: Operation , unrolled:: Symbol , Wshift:: Int , size_T:: Int )
3
6
# Wshift == dependson(op, unrolled) ? Wshift : 0
4
7
# c = first(cost(instruction(op), Wshift, size_T))::Int
@@ -10,12 +13,12 @@ function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int)
10
13
if opisunrolled
11
14
if ! unitstride (op, unrolled)# || !isdense(op) # need gather/scatter
12
15
r = (1 << Wshift)
13
- c *= r
16
+ srt *= r
14
17
sl *= r
15
18
# else # vmov(a/u)pd
16
19
end
17
20
elseif instr === :setindex! # broadcast or reductionstore; if store we want to penalize reduction
18
- c *= 2
21
+ srt *= 2
19
22
sl *= 2
20
23
end
21
24
end
33
36
function VectorizationBase. pick_vector_width_shift (ls:: LoopSet , u:: Symbol )
34
37
VectorizationBase. pick_vector_width_shift (length (ls, u), biggest_type_size (ls))
35
38
end
36
-
39
+ function hasintersection (a, b)
40
+ for aᵢ ∈ a, bᵢ ∈ b
41
+ aᵢ === bᵢ && return true
42
+ end
43
+ false
44
+ end
37
45
38
46
# evaluates cost of evaluating loop in given order
39
47
# heuristically, could simplify analysis by just unrolling outer loop?
@@ -42,7 +50,7 @@ function evaluate_cost_unroll(
42
50
)
43
51
# included_vars = Set{UInt}()
44
52
included_vars = fill (false , length (operations (ls)))
45
- nested_loop_syms = Set {Symbol} ()
53
+ nested_loop_syms = Symbol[] # Set{Symbol}()
46
54
total_cost = 0.0
47
55
iter = 1.0
48
56
# Need to check if fusion is possible
@@ -122,10 +130,12 @@ end
122
130
function tile_cost (X, U, T)
123
131
X[1 ] + X[4 ] + X[2 ] / T + X[3 ] / U
124
132
end
125
- function solve_tilsize (X, R)
133
+ function solve_tilesize (X, R)
126
134
# We use lagrange multiplier to finding floating point values for U and T
127
135
# first solving for U via quadratic formula
128
136
# X is vector of costs, and R is of register pressures
137
+ @show X
138
+ @show R
129
139
RR = VectorizationBase. REGISTER_COUNT - R[3 ] - R[4 ]
130
140
a = (R[1 ])^ 2 * X[2 ] - (R[2 ])^ 2 * R[1 ]* X[3 ]/ RR
131
141
b = 2 * R[1 ]* R[2 ]* X[3 ]
@@ -196,7 +206,7 @@ function evaluate_cost_tile(
196
206
tiled = order[1 ]
197
207
unrolled = order[2 ]
198
208
included_vars = fill (false , length (operations (ls)))
199
- nested_loop_syms = Set {Symbol} ()
209
+ nested_loop_syms = Symbol[] # Set{Symbol}()
200
210
iter = 1.0
201
211
# Need to check if fusion is possible
202
212
size_T = biggest_type_size (ls)
@@ -225,24 +235,28 @@ function evaluate_cost_tile(
225
235
included_vars[id] && continue
226
236
# it must also be a subset of defined symbols
227
237
loopdependencies (op) ⊆ nested_loop_syms || continue
228
- hasintersection (reduceddependencies (op), nested_loop_syms) && return 0 ,0 ,Inf
238
+ # @show nested_loop_syms
239
+ # @show reduceddependencies(op)
240
+ rd = reduceddependencies (op)
241
+ hasintersection (rd, nested_loop_syms[1 : end - length (rd)]) && return 0 ,0 ,Inf
229
242
included_vars[id] = true
230
243
rt, lat, rp = cost (op, unrolled, Wshift, size_T)
244
+ @show instruction (op), rt, lat, rp, iter
231
245
rt *= iter
232
246
isunrolled = unrolled ∈ loopdependencies (op)
233
247
istiled = tiled ∈ loopdependencies (op)
234
248
if isunrolled && istiled # no cost decrease; cost must be repeated
235
- cost_vec[1 ] = rt
236
- reg_pressure[1 ] = rp
249
+ cost_vec[1 ] + = rt
250
+ reg_pressure[1 ] + = rp
237
251
elseif isunrolled # cost decreased by tiling
238
- cost_vec[2 ] = rt
239
- reg_pressure[2 ] = rp
252
+ cost_vec[2 ] + = rt
253
+ reg_pressure[2 ] + = rp
240
254
elseif istiled # cost decreased by unrolling
241
- cost_vec[3 ] = rt
242
- reg_pressure[3 ] = rp
255
+ cost_vec[3 ] + = rt
256
+ reg_pressure[3 ] + = rp
243
257
else # neither unrolled or tiled
244
- cost_vec[4 ] = rt
245
- reg_pressure[4 ] = rp
258
+ cost_vec[4 ] + = rt
259
+ reg_pressure[4 ] + = rp
246
260
end
247
261
end
248
262
end
@@ -252,13 +266,13 @@ function evaluate_cost_tile(
252
266
if Ustatic
253
267
solve_tilesize (cost_vec, reg_pressure, looprangehint (ls, tiled), looprangehint (ls, unrolled))
254
268
else
255
- solve_tilesize (cost_vec, reg_pressure, looprangehint (ls, tiled), nothing )
269
+ solve_tilesize (cost_vec, reg_pressure, looprangehint (ls, tiled), typemax (Int) )
256
270
end
257
271
else
258
272
if Ustatic
259
- solve_tilesize (cost_vec, reg_pressure, nothing , looprangehint (ls, unrolled))
273
+ solve_tilesize (cost_vec, reg_pressure, typemax (Int) , looprangehint (ls, unrolled))
260
274
else
261
- solve_tilesize (cost_vec, reg_pressure)
275
+ solve_tilesize (cost_vec, reg_pressure)# , typemax(Int), typemax(Int))
262
276
end
263
277
end
264
278
end
@@ -270,7 +284,7 @@ struct LoopOrders
270
284
end
271
285
function LoopOrders (ls:: LoopSet )
272
286
syms = [s for s ∈ keys (ls. loops)]
273
- LoopOrders (syms, similar (buff ))
287
+ LoopOrders (syms, similar (syms ))
274
288
end
275
289
function Base. iterate (lo:: LoopOrders )
276
290
lo. syms, zeros (Int, length (lo. syms))# - 1)
0 commit comments