1
-
1
+ lv (x) = Expr (:(.), :LoopVectorization , QuoteNode (x))
2
2
3
3
isdense (:: Type{<:DenseArray} ) = true
4
4
@@ -72,6 +72,7 @@ Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i::Int) = lo.oporder[i]
72
72
Base. @propagate_inbounds Base. getindex (lo:: LoopOrder , i... ) = lo. oporder[LinearIndices (size (lo))[i... ]]
73
73
74
74
# Must make it easy to iterate
75
+ # outer_reductions is a vector of indixes (within operation vectors) of the reduction operation, eg the vmuladd op in a dot product
75
76
struct LoopSet
76
77
loops:: Dict{Symbol,Loop} # sym === loops[sym].itersymbol
77
78
opdict:: Dict{Symbol,Operation}
@@ -80,7 +81,13 @@ struct LoopSet
80
81
loop_order:: LoopOrder
81
82
# stridesets::Dict{ShortVector{Symbol},ShortVector{Symbol}}
82
83
preamble:: Expr # TODO : add preamble to lowering
83
- includedarrays:: Vector{Symbol}
84
+ includedarrays:: Vector{Tuple{Symbol,Int}}
85
+ end
86
+ function includesarray (ls:: LoopSet , array:: Symbol )
87
+ for (a,i) ∈ ls. includedarrays
88
+ a === array && return i
89
+ end
90
+ - 1
84
91
end
85
92
function LoopSet ()
86
93
LoopSet (
@@ -90,7 +97,7 @@ function LoopSet()
90
97
Int[],
91
98
LoopOrder (),
92
99
Expr (:block ,),
93
- Symbol[]
100
+ Tuple{ Symbol,Int} []
94
101
)
95
102
end
96
103
num_loops (ls:: LoopSet ) = length (ls. loops)
@@ -106,13 +113,13 @@ looprangesym(ls::LoopSet, s::Symbol) = ls.loops[s].rangesym
106
113
getop (ls:: LoopSet , s:: Symbol ) = ls. opdict[s]
107
114
getop (ls:: LoopSet , i:: Int ) = ls. operations[i + 1 ]
108
115
109
- function looprange (ls:: LoopSet , s:: Symbol , incr:: Int = 1 )
116
+ function looprange (ls:: LoopSet , s:: Symbol , incr:: Int = 1 , mangledname :: Symbol = s )
110
117
loop = ls. loops[s]
111
118
incr -= 1
112
119
if iszero (incr)
113
- Expr (:call , :< , s , loop. hintexact ? loop. rangehint : loop. rangesym)
120
+ Expr (:call , :< , mangledname , loop. hintexact ? loop. rangehint : loop. rangesym)
114
121
else
115
- Expr (:call , :< , s , loop. hintexact ? loop. rangehint - incr : Expr (:call , :- , loop. rangesym, incr))
122
+ Expr (:call , :< , mangledname , loop. hintexact ? loop. rangehint - incr : Expr (:call , :- , loop. rangesym, incr))
116
123
end
117
124
end
118
125
function Base. length (ls:: LoopSet , is:: Symbol )
@@ -193,10 +200,10 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
193
200
Base. push! (ls, q, elementbytes)
194
201
end
195
202
end
196
- function add_vptr! (ls:: LoopSet , indexed:: Symbol )
197
- if indexed ∉ ls . includedarrays
198
- push! (ls. includedarrays, indexed)
199
- push! (ls. preamble. args, Expr (:(= ), Symbol (:vptr_ , indexed), Expr (:call , Expr (:(.), :VectorizationBase , QuoteNode ( : stridedpointer) ), indexed)))
203
+ function add_vptr! (ls:: LoopSet , indexed:: Symbol , id :: Int )
204
+ if includesarray (ls, indexed) < 0
205
+ push! (ls. includedarrays, ( indexed, id) )
206
+ push! (ls. preamble. args, Expr (:(= ), Symbol (" ##vptr##_ " , indexed), Expr (:call , lv ( : stridedpointer ), indexed)))
200
207
end
201
208
nothing
202
209
end
@@ -205,7 +212,7 @@ function add_load!(
205
212
ls:: LoopSet , var:: Symbol , indexed:: Symbol , indices:: AbstractVector , elementbytes:: Int = 8
206
213
)
207
214
op = Operation ( length (operations (ls)), var, elementbytes, :getindex , memload, indices, [indexed], NOPARENTS )
208
- add_vptr! (ls, indexed)
215
+ add_vptr! (ls, indexed, identifier (op) )
209
216
pushop! (ls, op, var)
210
217
end
211
218
function add_load_ref! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
@@ -240,12 +247,12 @@ end
240
247
# ## if it is a literal, that literal is either var"##ZERO#Float##", var"##ONE#Float##", or has to have been assigned to var in the preamble.
241
248
# if it is a literal, that literal has to have been assigned to var in the preamble.
242
249
function add_constant! (ls:: LoopSet , var:: Symbol , elementbytes:: Int = 8 )
243
- pushop! (ls, Operation (length (operations (ls)), var, elementbytes, Symbol (" ##CONSTANT##" ), constant, NODEPENDENCY, NODEPENDENCY , NOPARENTS), var)
250
+ pushop! (ls, Operation (length (operations (ls)), var, elementbytes, Symbol (" ##CONSTANT##" ), constant, NODEPENDENCY, Symbol[] , NOPARENTS), var)
244
251
end
245
252
function add_constant! (ls:: LoopSet , var, elementbytes:: Int = 8 )
246
253
sym = gensym (:temp )
247
254
push! (ls. preamble. args, Expr (:(= ), sym, var))
248
- pushop! (ls, Operation (length (operations (ls)), sym, elementbytes, Symbol (" ##CONSTANT##" ), constant, NODEPENDENCY, NODEPENDENCY , NOPARENTS), sym)
255
+ pushop! (ls, Operation (length (operations (ls)), sym, elementbytes, Symbol (" ##CONSTANT##" ), constant, NODEPENDENCY, Symbol[] , NOPARENTS), sym)
249
256
end
250
257
# This version has loop dependencies. var gets assigned to sym when lowering.
251
258
function add_constant! (ls:: LoopSet , var:: Symbol , deps:: Vector{Symbol} , sym:: Symbol = gensym (:constant ), elementbytes:: Int = 8 )
260
267
function pushparent! (parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , parent:: Operation )
261
268
push! (parents, parent)
262
269
mergesetv! (deps, loopdependencies (parent))
263
- isload (parent) || mergesetv! (reduceddeps, reduceddependencies (parent))
270
+ if ! (isload (parent) || isconstant (parent))
271
+ mergesetv! (reduceddeps, reduceddependencies (parent))
272
+ end
264
273
nothing
265
274
end
275
+ function maybe_cse_load! (ls:: LoopSet , expr:: Expr , elementbytes:: Int = 8 )
276
+ if expr. head === :ref
277
+ array = first (expr. args):: Symbol
278
+ args = @view expr. args[2 : end ]
279
+ elseif expr. head === :call && first (expr. args) === :getindex
280
+ array = (expr. args[2 ]):: Symbol
281
+ args = @view expr. args[3 : end ]
282
+ else
283
+ return add_operation! (ls, gensym (:temporary ), expr, elementbytes)
284
+ end
285
+ id = includesarray (ls, array)
286
+ if id > 0
287
+ ls. operations[id]
288
+ else
289
+ add_load! ( ls, gensym (:temporary ), array, args, elementbytes )
290
+ end
291
+ end
266
292
function add_parent! (
267
293
parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var, elementbytes:: Int = 8
268
294
)
@@ -271,9 +297,8 @@ function add_parent!(
271
297
# might add constant
272
298
add_constant! (ls, var, elementbytes)
273
299
end
274
- elseif var isa Expr
275
- temp = gensym (:temporary )
276
- add_operation! (ls, temp, var, elementbytes)
300
+ elseif var isa Expr # CSE candidate
301
+ maybe_cse_load! (ls, var, elementbytes)
277
302
else # assumed constant
278
303
add_constant! (ls, var, elementbytes)
279
304
end
@@ -283,13 +308,23 @@ function add_reduction!(
283
308
parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var:: Symbol , elementbytes:: Int = 8
284
309
)
285
310
get! (ls. opdict, var) do
286
- p = add_constant! (ls, var, elementbytes)
287
- push! (ls. outer_reductions, identifier (p))
288
- p
311
+ add_constant! (ls, var, elementbytes)
312
+ # push!(ls.outer_reductions, identifier(p))
313
+ # p
289
314
end
290
315
# pushparent!(parents, deps, reduceddeps, parent)
291
316
end
292
-
317
+ function add_reduction_update_parent! (
318
+ parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet ,
319
+ var:: Symbol , instr:: Symbol , elementbytes:: Int = 8
320
+ )
321
+ parent = getop (ls, var)
322
+ setdiffv! (reduceddeps, deps, loopdependencies (parent))
323
+ pushparent! (parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
324
+ op = Operation (length (operations (ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
325
+ parent. instruction === Symbol (" ##CONSTANT##" ) && push! (ls. outer_reductions, identifier (op))
326
+ pushop! (ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
327
+ end
293
328
function add_compute! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
294
329
@assert ex. head === :call
295
330
instr = instruction (first (ex. args)):: Symbol
@@ -308,20 +343,18 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
308
343
end
309
344
end
310
345
if reduction # arg[reduction] is the reduction
311
- parent = getop ( ls, var)
312
- setdiffv! (reduceddeps, deps, loopdependencies (parent))
313
- pushparent! (parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
314
- # append!(reduceddependencies(parent), reduceddeps )
346
+ add_reduction_update_parent! (parents, deps, reduceddeps, ls, var, instr, elementbytes )
347
+ else
348
+ op = Operation ( length ( operations (ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
349
+ pushop! (ls, op, var )
315
350
end
316
- op = Operation (length (operations (ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
317
- pushop! (ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
318
351
end
319
352
function add_store! (
320
353
ls:: LoopSet , indexed:: Symbol , var:: Symbol , indices:: AbstractVector , elementbytes:: Int = 8
321
354
)
322
355
parent = getop (ls, var)
323
356
op = Operation ( length (operations (ls)), indexed, elementbytes, :setindex! , memstore, indices, reduceddependencies (parent), [parent] )
324
- add_vptr! (ls, indexed)
357
+ add_vptr! (ls, indexed, identifier (op) )
325
358
pushop! (ls, op, var)
326
359
end
327
360
function add_store_ref! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
@@ -414,11 +447,11 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
414
447
isunrolled = (unrolled ∈ loopdependencies (op)) + 1
415
448
istiled = (loopistiled ? (tiled ∈ loopdependencies (op)) : false ) + 1
416
449
optype = Int (op. node_type) + 1
417
- after_loop = (length (reduceddependencies (op)) > 0 ) + 1
450
+ after_loop = isload (op) ? 1 : (length (reduceddependencies (op)) > 0 ) + 1
418
451
push! (lo[optype,isunrolled,istiled,after_loop,_n], op)
419
452
end
420
453
end
421
- @show 3 , ro, order
454
+ # 3, ro, order
422
455
end
423
456
424
457
0 commit comments