@@ -82,7 +82,12 @@ struct LoopSet
82
82
# stridesets::Dict{ShortVector{Symbol},ShortVector{Symbol}}
83
83
preamble:: Expr # TODO : add preamble to lowering
84
84
includedarrays:: Vector{Tuple{Symbol,Int}}
85
+ syms_aliasing_refs:: Vector{Symbol} # O(N) search is faster at small sizes
86
+ refs_aliasing_syms:: Vector{ArrayReference}
87
+ # sym_to_ref_aliases::Dict{Symbol,ArrayReference}
88
+ # ref_to_sym_aliases::Dict{ArrayReference,Symbol}
85
89
end
90
+
86
91
function includesarray (ls:: LoopSet , array:: Symbol )
87
92
for (a,i) ∈ ls. includedarrays
88
93
a === array && return i
@@ -97,7 +102,11 @@ function LoopSet()
97
102
Int[],
98
103
LoopOrder (),
99
104
Expr (:block ,),
100
- Tuple{Symbol,Int}[]
105
+ Tuple{Symbol,Int}[],
106
+ Symbol[],
107
+ ArrayReference[]
108
+ # Dict{Symbol,ArrayReference}()
109
+ # Dict{ArrayReference,Symbol}()
101
110
)
102
111
end
103
112
"""
    num_loops(ls::LoopSet)

Return the number of entries in `ls.loops`.
"""
function num_loops(ls::LoopSet)
    return length(ls.loops)
end
@@ -209,17 +218,33 @@ function add_vptr!(ls::LoopSet, indexed::Symbol, id::Int)
209
218
end
210
219
211
220
"""
    add_load!(ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8)

Register a `:getindex` memory-load operation that binds `var` to the array
reference `ref` inside `ls`.

If `ref.loaded[]` is already set, no new operation is created: the operation
found by `getop(ls, var)` is returned instead. Otherwise `var` and `ref` are
appended to the parallel alias vectors `ls.syms_aliasing_refs` /
`ls.refs_aliasing_syms`, `ref` is flagged as loaded, and a new `Operation` is
built, registered with `add_vptr!`, and pushed with `pushop!`.

Returns the existing operation on the early-exit path, otherwise whatever
`pushop!` returns.
"""
function add_load!(
    ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
)
    if ref.loaded[]
        op = getop(ls, var)
        # Internal invariant: the operation registered under `var` must carry that name.
        @assert var === op.variable
        return op
    end
    # The two vectors are kept index-aligned: entry `i` of one aliases entry `i` of the other.
    push!(ls.syms_aliasing_refs, var)
    push!(ls.refs_aliasing_syms, ref)
    ref.loaded[] = true
    op = Operation(
        length(operations(ls)), var, elementbytes,
        :getindex, memload, loopdependencies(ref),
        NODEPENDENCY, NOPARENTS, ref
    )
    # The former `indexed` argument no longer exists; the array symbol now comes from `ref`.
    # NOTE(review): assumes `ArrayReference` stores the array symbol in a field named
    # `array` (it is the first constructor argument) -- confirm against the struct definition.
    add_vptr!(ls, ref.array, identifier(op))
    return pushop!(ls, op, var)
end
218
241
"""
    add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)

Convert `ex` to an array reference with `ref_from_ref` and forward it to
`add_load!`, returning that call's result.
"""
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
    return add_load!(ls, var, ref_from_ref(ex), elementbytes)
end
221
245
"""
    add_load_getindex!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)

Convert `ex` to an array reference with `ref_from_getindex` and forward it to
`add_load!`, returning that call's result.
"""
function add_load_getindex!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
    return add_load!(ls, var, ref_from_getindex(ex), elementbytes)
end
224
249
function instruction (x)
225
250
x isa Symbol ? x : last (x. args). value
@@ -274,20 +299,31 @@ function pushparent!(parents::Vector{Operation}, deps::Vector{Symbol}, reducedde
274
299
end
275
300
"""
    maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)

Add `expr` to `ls`, reusing an existing load when the same array reference has
already been recorded.

- If `expr` is neither a `:ref` expression nor a `getindex(...)` call, it is
  handed to `add_operation!` under a fresh temporary name.
- Otherwise an `ArrayReference` is built from the array and index arguments.
  When an equal reference is found in `ls.refs_aliasing_syms`, the operation
  registered under the aliasing symbol is returned via `getop`; when none is
  found, a new load is created with `add_load!`.
"""
function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
    # `offset` skips the leading `:getindex` function symbol in the call form,
    # so that `expr.args[1 + offset]` is the array in both forms.
    if expr.head === :ref
        offset = 0
    elseif expr.head === :call && first(expr.args) === :getindex
        offset = 1
    else
        return add_operation!(ls, gensym(:temporary), expr, elementbytes)
    end
    ref = ArrayReference(
        expr.args[1 + offset], @view(expr.args[(2 + offset):end])
    )::ArrayReference
    id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
    if id === nothing
        return add_load!(ls, gensym(:temporary), ref, elementbytes)
    end
    # Return the operation itself (not just its symbol) so that every branch
    # yields the same kind of value.
    return getop(ls, ls.syms_aliasing_refs[id])
end
292
328
function add_parent! (
293
329
parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var, elementbytes:: Int = 8
@@ -325,7 +361,7 @@ function add_reduction_update_parent!(
325
361
parent. instruction === Symbol (" ##CONSTANT##" ) && push! (ls. outer_reductions, identifier (op))
326
362
pushop! (ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
327
363
end
328
- function add_compute! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
364
+ function add_compute! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 , ref = nothing )
329
365
@assert ex. head === :call
330
366
instr = instruction (first (ex. args)):: Symbol
331
367
args = @view (ex. args[2 : end ])
@@ -338,6 +374,9 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
338
374
if arg === var
339
375
reduction = true
340
376
add_reduction! (parents, deps, reduceddeps, ls, arg, elementbytes)
377
+ elseif ref == arg
378
+ reduction = true
379
+ add_load! (ls, var, ref, elementbytes)
341
380
else
342
381
add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes)
343
382
end
@@ -402,9 +441,12 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
402
441
if RHS isa Symbol
403
442
lrhs = RHS
404
443
elseif RHS isa Expr
444
+ # need to check if LHS appears in RHS
405
445
# assign RHS to lrhs
406
- lrhs = gensym (:RHS )
407
- add_operation! (ls, lrhs, RHS, elementbytes)
446
+ ref = ArrayReference (LHS)
447
+ id = findfirst (r -> r == ref, ls. refs_aliasing_syms)
448
+ lrhs = id === nothing ? gensym (:RHS ) : ls. syms_aliasing_refs[id]
449
+ add_operation! (ls, lrhs, RHS, elementbytes, ref)
408
450
end
409
451
add_store_ref! (ls, lrhs, LHS, elementbytes)
410
452
else
0 commit comments