@@ -60,14 +60,14 @@ LoopOrder() = LoopOrder(Vector{Operation}[],Symbol[])
60
60
# Empty every per-slot operation vector stored in `lo.oporder`. The slots
# themselves and `lo.loopnames` are left in place. Returns `lo`, following the
# `Base.empty!` convention of returning the emptied collection (the previous
# one-liner returned `nothing` because `foreach` does).
function Base.empty!(lo::LoopOrder)
    foreach(empty!, lo.oporder)
    return lo
end
61
61
# Resize `lo` to hold `N` loops: `lo.oporder` gets 32 slots per loop and
# `lo.loopnames` gets `N` entries. Returns `lo`.
function Base.resize!(lo::LoopOrder, N::Int)
    # Must be read before `loopnames` is resized below.
    previous = length(lo.loopnames)
    slots = lo.oporder
    resize!(slots, 32 * N)
    # Slots past the previous extent are given their own empty vector;
    # the range is empty when shrinking.
    for n in (32 * previous + 1):(32 * N)
        slots[n] = Operation[]
    end
    resize!(lo.loopnames, N)
    return lo
end
70
# Shape used to index `lo.oporder` as a 5-dimensional array; the trailing
# dimension is the number of loop names.
function Base.size(lo::LoopOrder)
    nloops = length(lo.loopnames)
    return (4, 2, 2, 2, nloops)
end
71
71
# Linear indexing straight into the flat `oporder` storage.
Base.@propagate_inbounds function Base.getindex(lo::LoopOrder, i::Int)
    return lo.oporder[i]
end
72
72
# Multi-dimensional indexing: map `i...` to a linear index using `size(lo)`,
# then read from the flat `oporder` storage.
Base.@propagate_inbounds function Base.getindex(lo::LoopOrder, i...)
    linear = LinearIndices(size(lo))
    return lo.oporder[linear[i...]]
end
73
73
@@ -80,6 +80,7 @@ struct LoopSet
80
80
loop_order:: LoopOrder
81
81
# stridesets::Dict{ShortVector{Symbol},ShortVector{Symbol}}
82
82
preamble:: Expr # TODO : add preamble to lowering
83
+ includedarrays:: Vector{Symbol}
83
84
end
84
85
function LoopSet ()
85
86
LoopSet (
@@ -88,13 +89,14 @@ function LoopSet()
88
89
Operation[],
89
90
Int[],
90
91
LoopOrder (),
91
- Expr (:block ,)
92
+ Expr (:block ,),
93
+ Symbol[]
92
94
)
93
95
end
94
96
# Number of entries in `ls.loops`.
function num_loops(ls::LoopSet)
    return length(ls.loops)
end
95
97
# View the flat `oporder` storage of `ls.loop_order` as a 5-dimensional array
# of shape `(4, 2, 2, 2, N)`, where `N` is the number of loop names. `reshape`
# shares the underlying data rather than copying it.
function oporder(ls::LoopSet)
    order = ls.loop_order
    # `size(::LoopOrder)` yields exactly (4, 2, 2, 2, length(order.loopnames)).
    return reshape(order.oporder, size(order))
end
99
101
# Loop names recorded in the loop order of `ls` (the vector itself, not a copy).
function names(ls::LoopSet)
    return ls.loop_order.loopnames
end
100
102
# Look up loop `s` in `ls.loops` and report its `hintexact` field.
function isstaticloop(ls::LoopSet, s::Symbol)
    loop = ls.loops[s]
    return loop.hintexact
end
@@ -163,7 +165,7 @@ function register_single_loop!(ls::LoopSet, looprange::Expr)
163
165
Loop (itersym, N)
164
166
end
165
167
elseif f === :eachindex
166
- N = gensym (:loop , itersym)
168
+ N = gensym (Symbol ( :loop , itersym) )
167
169
push! (ls. preamble. args, Expr (:(= ), N, Expr (:call , :length , r. args[2 ])))
168
170
Loop (itersym, N)
169
171
else
@@ -191,11 +193,19 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
191
193
Base. push! (ls, q, elementbytes)
192
194
end
193
195
end
196
# Ensure the preamble contains `vptr_<indexed> = VectorizationBase.stridedpointer(<indexed>)`.
# The assignment is emitted only the first time `indexed` is seen; arrays that
# were already handled are tracked in `ls.includedarrays`. Always returns `nothing`.
function add_vptr!(ls::LoopSet, indexed::Symbol)
    indexed ∈ ls.includedarrays && return nothing
    push!(ls.includedarrays, indexed)
    stridedptr = Expr(:(.), :VectorizationBase, QuoteNode(:stridedpointer))
    assignment = Expr(:(=), Symbol(:vptr_, indexed), Expr(:call, stridedptr, indexed))
    push!(ls.preamble.args, assignment)
    return nothing
end
194
203
195
204
# Register a `:getindex` memory-load operation that reads `indexed[indices...]`
# into `var`. Also makes sure the strided pointer for `indexed` is set up in
# the preamble. Returns whatever `pushop!` returns.
function add_load!(
    ls::LoopSet, var::Symbol, indexed::Symbol, indices::AbstractVector, elementbytes::Int = 8
)
    # The new operation's identifier is the current operation count.
    id = length(operations(ls))
    load = Operation(id, var, elementbytes, :getindex, memload, indices, [indexed], NOPARENTS)
    add_vptr!(ls, indexed)
    return pushop!(ls, load, var)
end
201
211
function add_load_ref! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
@@ -226,12 +236,26 @@ function setdiffv!(s3::AbstractVector{T}, s1::AbstractVector{T}, s2::AbstractVec
226
236
(s ∈ s2) || (s ∉ s3 && push! (s3, s))
227
237
end
228
238
end
229
# Constant without loop dependencies, already bound to the symbol `var`.
# Per the original author's note this form is not lowered, so if `var` stands
# for a literal, that literal must already have been assigned to `var` in the
# preamble.
function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
    id = length(operations(ls))
    op = Operation(
        id, var, elementbytes, Symbol("##CONSTANT##"), constant,
        NODEPENDENCY, NODEPENDENCY, NOPARENTS
    )
    return pushop!(ls, op, var)
end

# Constant without loop dependencies given as an arbitrary value: bind it to a
# fresh temporary in the preamble, then register that temporary.
function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
    sym = gensym(:temp)
    push!(ls.preamble.args, Expr(:(=), sym, var))
    return add_constant!(ls, sym, elementbytes)
end

# Constant with loop dependencies `deps`. The operation is registered under
# `sym` and carries `var` as its instruction; per the original author's note,
# `var` gets assigned to `sym` when lowering.
function add_constant!(
    ls::LoopSet, var::Symbol, deps::Vector{Symbol},
    sym::Symbol = gensym(:constant), elementbytes::Int = 8
)
    id = length(operations(ls))
    op = Operation(id, sym, elementbytes, var, constant, deps, NODEPENDENCY, NOPARENTS)
    return pushop!(ls, op, sym)
end

# Constant with loop dependencies given as an arbitrary value: bind it to a
# fresh temporary in the preamble, then register it through the symbol method.
function add_constant!(
    ls::LoopSet, var, deps::Vector{Symbol},
    sym::Symbol = gensym(:constant), elementbytes::Int = 8
)
    sym2 = gensym(:temp)
    push!(ls.preamble.args, Expr(:(=), sym2, var))
    return add_constant!(ls, sym2, deps, sym, elementbytes)
end
236
260
function pushparent! (parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , parent:: Operation )
237
261
push! (parents, parent)
258
282
function add_reduction! (
259
283
parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var:: Symbol , elementbytes:: Int = 8
260
284
)
261
- parent = get! (ls. opdict, var) do
285
+ get! (ls. opdict, var) do
262
286
p = add_constant! (ls, var, elementbytes)
263
287
push! (ls. outer_reductions, identifier (p))
264
288
p
@@ -287,6 +311,7 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
287
311
parent = getop (ls, var)
288
312
setdiffv! (reduceddeps, deps, loopdependencies (parent))
289
313
pushparent! (parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
314
+ # append!(reduceddependencies(parent), reduceddeps)
290
315
end
291
316
op = Operation (length (operations (ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
292
317
pushop! (ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
@@ -296,6 +321,7 @@ function add_store!(
296
321
)
297
322
parent = getop (ls, var)
298
323
op = Operation ( length (operations (ls)), indexed, elementbytes, :setindex! , memstore, indices, reduceddependencies (parent), [parent] )
324
+ add_vptr! (ls, indexed)
299
325
pushop! (ls, op, var)
300
326
end
301
327
function add_store_ref! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
@@ -335,7 +361,7 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
335
361
if RHS isa Expr
336
362
add_operation! (ls, LHS, RHS, elementbytes)
337
363
else
338
- add_constant! (ls, RHS, elementbytes, LHS, [keys (ls. loops)... ])
364
+ add_constant! (ls, RHS, [keys (ls. loops)... ], LHS, elementbytes )
339
365
end
340
366
elseif LHS isa Expr
341
367
@assert LHS. head === :ref
363
389
function fillorder! (ls:: LoopSet , order:: Vector{Symbol} , loopistiled:: Bool )
364
390
lo = ls. loop_order
365
391
ro = lo. loopnames # reverse order; will have same order as lo
366
- copyto! (lo. names, order)
392
+ # @show 1, ro, order
393
+ # copyto!(ro, order)
394
+ # @show 2, ro, order
367
395
empty! (lo)
368
396
nloops = length (order)
369
397
if loopistiled
@@ -378,17 +406,19 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
378
406
for _n ∈ 1 : nloops
379
407
n = 1 + nloops - _n
380
408
ro[_n] = loopsym = order[n]
409
+ # loopsym = order[n]
381
410
for (id,op) ∈ enumerate (operations (ls))
382
411
included_vars[id] && continue
383
- loopsym ∈ dependencies (op) || continue
412
+ loopsym ∈ loopdependencies (op) || continue
384
413
included_vars[id] = true
385
414
isunrolled = (unrolled ∈ loopdependencies (op)) + 1
386
- istiled = (loopistiled ? false : (tiled ∈ loopdependencies (op))) + 1
387
- optype = Int (op. node_type)
415
+ istiled = (loopistiled ? (tiled ∈ loopdependencies (op)) : false ) + 1
416
+ optype = Int (op. node_type) + 1
388
417
after_loop = (length (reduceddependencies (op)) > 0 ) + 1
389
418
push! (lo[optype,isunrolled,istiled,after_loop,_n], op)
390
419
end
391
420
end
421
+ @show 3 , ro, order
392
422
end
393
423
394
424
0 commit comments