@@ -88,6 +88,13 @@ struct LoopSet
88
88
# ref_to_sym_aliases::Dict{ArrayReference,Symbol}
89
89
end
90
90
91
+ # function op_to_ref(ls::LoopSet, op::Operation)
92
+ # s = op.variable
93
+ # id = findfirst(ls.syms_aliasing_regs)
94
+ # @assert id !== nothing
95
+ # ls.refs_aliasing_syms[id]
96
+ # end
97
+
91
98
function includesarray (ls:: LoopSet , array:: Symbol )
92
99
for (a,i) ∈ ls. includedarrays
93
100
a === array && return i
@@ -235,7 +242,7 @@ function add_load!(
235
242
:getindex , memload, loopdependencies (ref),
236
243
NODEPENDENCY, NOPARENTS, ref
237
244
)
238
- add_vptr! (ls, indexed , identifier (op))
245
+ add_vptr! (ls, ref . array , identifier (op))
239
246
pushop! (ls, op, var)
240
247
end
241
248
function add_load_ref! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
@@ -311,12 +318,16 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
311
318
else
312
319
return add_operation! (ls, gensym (:temporary ), expr, elementbytes)
313
320
end
314
- ref = ArrayReference ( ex. args[1 + offset], @view (ex. args[2 + offset: end ]) ):: ArrayReference
321
+ ref = ArrayReference (
322
+ expr. args[1 + offset],
323
+ @view (expr. args[2 + offset: end ]),
324
+ Ref (false )
325
+ ):: ArrayReference
315
326
id = findfirst (r -> r == ref, ls. refs_aliasing_syms)
316
327
if id === nothing
317
- add_load! ( ls, gensym (:temporary ), array, args , elementbytes )
328
+ add_load! ( ls, gensym (:temporary ), ref , elementbytes )
318
329
else
319
- ls . syms_aliasing_refs[id]
330
+ getop (ls, ls . syms_aliasing_refs[id])
320
331
end
321
332
# id = includesarray(ls, array)
322
333
# if id > 0
@@ -371,7 +382,7 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
371
382
# op = Operation( length(operations(ls)), var, elementbytes, instr, compute )
372
383
reduction = false
373
384
for arg ∈ args
374
- if arg === var
385
+ if var === arg
375
386
reduction = true
376
387
add_reduction! (parents, deps, reduceddeps, ls, arg, elementbytes)
377
388
elseif ref == arg
@@ -389,18 +400,20 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
389
400
end
390
401
end
391
402
# Add a store of `var` to the location described by `ref`.
#
# The operation currently registered for `var` becomes the sole parent of a new
# `:setindex!` memstore operation named after `ref.array`. The array is passed
# to `add_vptr!` together with the new operation's identifier, and the
# operation is then pushed onto `ls`. The value of `pushop!` is returned.
function add_store!(
    ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
)
    producer = getop(ls, var)
    # The new operation's id is the number of operations already in `ls`.
    storeid = length(operations(ls))
    store = Operation(
        storeid, ref.array, elementbytes, :setindex!, memstore,
        loopdependencies(ref), reduceddependencies(producer), [producer], ref
    )
    add_vptr!(ls, ref.array, identifier(store))
    return pushop!(ls, store, var)
end
399
410
# Add a store of `var` to the location given by the expression `ex`.
# `ex` is converted to an array reference by `ref_from_ref` and forwarded to
# `add_store!`, whose result is returned.
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
    return add_store!(ls, var, ref_from_ref(ex), elementbytes)
end
402
414
# Add a store described by a `setindex!`-style call expression `ex`.
#
# The array reference is built from `ex` by `ref_from_setindex`. The stored
# variable is taken from `ex.args[3]`, the slot the previous implementation
# passed as the stored variable (`ex.args[2]` being the array and
# `ex.args[4:end]` the indices). The earlier body referred to an undefined
# name `var`, which raised `UndefVarError` on every call.
# NOTE(review): `add_store!` requires a `Symbol` here, so `ex.args[3]` is
# presumably always a symbol at this point -- confirm against callers.
function add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
    ref = ref_from_setindex(ex)
    return add_store!(ls, ex.args[3], ref, elementbytes)
end
405
418
# add operation assigns X to var
406
419
function add_operation! (
@@ -418,6 +431,21 @@ function add_operation!(
418
431
throw (" Expression not recognized:\n $x " )
419
432
end
420
433
end
434
# Add the operation that assigns `RHS` to `LHS_sym`, where the left-hand side
# is the array reference `LHS_ref`.
#
# - `RHS` is an indexing expression (`:ref`) or a `getindex` call: add a load.
# - `RHS` is any other call: add a compute operation, passing `LHS_ref` along.
# - anything else: throw a message string naming the unrecognized expression.
#
# The fallback message previously interpolated `$x`, a name not defined in this
# method, so the failure surfaced as an `UndefVarError`; it now reports `RHS`.
# NOTE(review): the load branches pass `LHS_ref` rather than a reference built
# from `RHS` -- kept exactly as before; confirm this is intended.
function add_operation!(
    ls::LoopSet, LHS_sym::Symbol, RHS::Expr, LHS_ref::ArrayReference,
    elementbytes::Int = 8
)
    if RHS.head === :ref || (RHS.head === :call && first(RHS.args) === :getindex)
        return add_load!(ls, LHS_sym, LHS_ref, elementbytes)
    elseif RHS.head === :call
        return add_compute!(ls, LHS_sym, RHS, elementbytes, LHS_ref)
    else
        throw("Expression not recognized:\n $RHS")
    end
end
421
449
function Base. push! (ls:: LoopSet , ex:: Expr , elementbytes:: Int = 8 )
422
450
if ex. head === :call
423
451
finex = first (ex. args):: Symbol
@@ -446,7 +474,9 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
446
474
ref = ArrayReference (LHS)
447
475
id = findfirst (r -> r == ref, ls. refs_aliasing_syms)
448
476
lrhs = id === nothing ? gensym (:RHS ) : ls. syms_aliasing_refs[id]
449
- add_operation! (ls, lrhs, RHS, elementbytes, ref)
477
+ # we pass ref, so it can compare references within RHS, and realize
478
+ # they equal lrhs
479
+ add_operation! (ls, lrhs, RHS, ref, elementbytes)
450
480
end
451
481
add_store_ref! (ls, lrhs, LHS, elementbytes)
452
482
else
0 commit comments