Skip to content

Commit 58db05b

Browse files
committed
Updating loopsets to better represent aliasing of certain symbols and memory locations, and prep to allow indices to have constant integers.
1 parent 2307d26 commit 58db05b

File tree

6 files changed

+281
-23
lines changed

6 files changed

+281
-23
lines changed

src/constructors.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,13 @@ macro avx(q)
1818
esc(lower(LoopSet(q)))
1919
end
2020

21+
#=
22+
@generated function vmaterialize(
23+
dest::AbstractArray{T,N}, bc::BC
24+
) where {T,N,BC <: Base.Broadcast.Broadcasted}
25+
# we have an N dimensional loop.
26+
# need to construct the LoopSet
27+
28+
end
29+
=#
30+

src/determinestrategy.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

22
# TODO: FIXME for general case
3+
# wrong for transposed matrices, and certain views/SubArrays.
34
unitstride(op, s) = first(loopdependencies(op)) === s
45

56
function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int = op.elementbytes)

src/graphs.jl

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,12 @@ struct LoopSet
8282
# stridesets::Dict{ShortVector{Symbol},ShortVector{Symbol}}
8383
preamble::Expr # TODO: add preamble to lowering
8484
includedarrays::Vector{Tuple{Symbol,Int}}
85+
syms_aliasing_refs::Vector{Symbol} # O(N) search is faster at small sizes
86+
refs_aliasing_syms::Vector{ArrayReference}
87+
# sym_to_ref_aliases::Dict{Symbol,ArrayReference}
88+
# ref_to_sym_aliases::Dict{ArrayReference,Symbol}
8589
end
90+
8691
function includesarray(ls::LoopSet, array::Symbol)
8792
for (a,i) ls.includedarrays
8893
a === array && return i
@@ -97,7 +102,11 @@ function LoopSet()
97102
Int[],
98103
LoopOrder(),
99104
Expr(:block,),
100-
Tuple{Symbol,Int}[]
105+
Tuple{Symbol,Int}[],
106+
Symbol[],
107+
ArrayReference[]
108+
# Dict{Symbol,ArrayReference}()
109+
# Dict{ArrayReference,Symbol}()
101110
)
102111
end
103112
num_loops(ls::LoopSet) = length(ls.loops)
@@ -209,17 +218,33 @@ function add_vptr!(ls::LoopSet, indexed::Symbol, id::Int)
209218
end
210219

211220
function add_load!(
212-
ls::LoopSet, var::Symbol, indexed::Symbol, indices::AbstractVector, elementbytes::Int = 8
221+
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
213222
)
214-
op = Operation( length(operations(ls)), var, elementbytes, :getindex, memload, indices, [indexed], NOPARENTS )
223+
if ref.loaded[] == true
224+
op = getop(ls, var)
225+
@assert var === op.variable
226+
return op
227+
end
228+
push!(ls.syms_aliasing_refs, var)
229+
push!(ls.refs_aliasing_syms, ref)
230+
ref.loaded[] = true
231+
# ls.sym_to_ref_aliases[ var ] = ref
232+
# ls.ref_to_sym_aliases[ ref ] = var
233+
op = Operation(
234+
length(operations(ls)), var, elementbytes,
235+
:getindex, memload, loopdependencies(ref),
236+
NODEPENDENCY, NOPARENTS, ref
237+
)
215238
add_vptr!(ls, indexed, identifier(op))
216239
pushop!(ls, op, var)
217240
end
218241
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
219-
add_load!(ls, var, ex.args[1], @view(ex.args[2:end]), elementbytes)
242+
ref = ref_from_ref(ex)
243+
add_load!(ls, var, ref, elementbytes)
220244
end
221245
function add_load_getindex!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
222-
add_load!(ls, var, ex.args[2], @view(ex.args[3:end]), elementbytes)
246+
ref = ref_from_getindex(ex)
247+
add_load!(ls, var, ref, elementbytes)
223248
end
224249
function instruction(x)
225250
x isa Symbol ? x : last(x.args).value
@@ -274,20 +299,31 @@ function pushparent!(parents::Vector{Operation}, deps::Vector{Symbol}, reducedde
274299
end
275300
function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
276301
if expr.head === :ref
277-
array = first(expr.args)::Symbol
278-
args = @view expr.args[2:end]
302+
offset = 0
303+
# array = first(expr.args)::Symbol
304+
# args = @view expr.args[2:end]
305+
# ref = ref_from_ref(expr)
279306
elseif expr.head === :call && first(expr.args) === :getindex
280-
array = (expr.args[2])::Symbol
281-
args = @view expr.args[3:end]
307+
offset = 1
308+
# array = (expr.args[2])::Symbol
309+
# args = @view expr.args[3:end]
310+
# ref = ref_from_getindex(expr)
282311
else
283312
return add_operation!(ls, gensym(:temporary), expr, elementbytes)
284313
end
285-
id = includesarray(ls, array)
286-
if id > 0
287-
ls.operations[id]
288-
else
314+
ref = ArrayReference( ex.args[1+offset], @view(ex.args[2+offset:end]) )::ArrayReference
315+
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
316+
if id === nothing
289317
add_load!( ls, gensym(:temporary), array, args, elementbytes )
318+
else
319+
ls.syms_aliasing_refs[id]
290320
end
321+
# id = includesarray(ls, array)
322+
# if id > 0
323+
# ls.operations[id]
324+
# else
325+
# add_load!( ls, gensym(:temporary), array, args, elementbytes )
326+
# end
291327
end
292328
function add_parent!(
293329
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var, elementbytes::Int = 8
@@ -325,7 +361,7 @@ function add_reduction_update_parent!(
325361
parent.instruction === Symbol("##CONSTANT##") && push!(ls.outer_reductions, identifier(op))
326362
pushop!(ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
327363
end
328-
function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
364+
function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8, ref = nothing)
329365
@assert ex.head === :call
330366
instr = instruction(first(ex.args))::Symbol
331367
args = @view(ex.args[2:end])
@@ -338,6 +374,9 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
338374
if arg === var
339375
reduction = true
340376
add_reduction!(parents, deps, reduceddeps, ls, arg, elementbytes)
377+
elseif ref == arg
378+
reduction = true
379+
add_load!(ls, var, ref, elementbytes)
341380
else
342381
add_parent!(parents, deps, reduceddeps, ls, arg, elementbytes)
343382
end
@@ -402,9 +441,12 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
402441
if RHS isa Symbol
403442
lrhs = RHS
404443
elseif RHS isa Expr
444+
# need to check of LHS appears in RHS
405445
# assign RHS to lrhs
406-
lrhs = gensym(:RHS)
407-
add_operation!(ls, lrhs, RHS, elementbytes)
446+
ref = ArrayReference(LHS)
447+
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
448+
lrhs = id === nothing ? gensym(:RHS) : ls.syms_aliasing_refs[id]
449+
add_operation!(ls, lrhs, RHS, elementbytes, ref)
408450
end
409451
add_store_ref!(ls, lrhs, LHS, elementbytes)
410452
else

src/lowering.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,7 @@ end
522522
function gc_preserve(ls::LoopSet, q::Expr)
523523
length(ls.includedarrays) == 0 && return q # is this even possible?
524524
gcp = Expr(:macrocall, Expr(:(.), :GC, QuoteNode(Symbol("@preserve"))), LineNumberNode(@__LINE__, @__FILE__))
525-
for (array,i) ls.includedarrays
525+
for (array,_) ls.includedarrays
526526
push!(gcp.args, array)
527527
end
528528
push!(q.args, nothing)
@@ -697,4 +697,5 @@ function lower(ls::LoopSet)
697697
end
698698

699699
Base.convert(::Type{Expr}, ls::LoopSet) = lower(ls)
700+
Base.show(io::IO, ls::LoopSet) = println(io, lower(ls))
700701

src/operations.jl

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,52 @@
1+
struct ArrayReference
2+
array::Symbol
3+
ref::Vector{Union{Symbol,Int}}
4+
loaded::Base.RefValue{Bool}
5+
end
6+
function Base.hash(x::ArrayReference, h::UInt)
7+
@inbounds for n eachindex(x)
8+
h = hash(x.ref[n], h)
9+
end
10+
hash(x.array, h)
11+
end
12+
loopdependencies(ref::ArrayReference) = filter(i -> i isa Symbol, ref.ref)
13+
function Base.isequal(x::ArrayReference, y::ArrayReference)
14+
x.array === y.array || return false
15+
nrefs = length(x.ref)
16+
nrefs == length(y.ref) || return false
17+
all(n -> x.ref[n] === y.ref[n], 1:nrefs)
18+
# for n ∈ 1:nrefs
19+
# x.ref[n] === y.ref[n] || return false
20+
# end
21+
# true
22+
end
23+
24+
Base.:(==)(x::ArrayReference, y::ArrayReference) = isequal(x, y)
25+
26+
function ref_from_ref(ex::Expr)
27+
ArrayReference( ex.args[1], @view(ex.args[2:end]), Ref(false) )
28+
end
29+
function ref_from_getindex(ex::Expr)
30+
ArrayReference( ex.args[2], @view(ex.args[3:end]), Ref(false) )
31+
end
32+
function ArrayReference(ex::Expr)
33+
ex.head === :ref ? ref_from_ref(ex) : ref_from_getindex(ex)
34+
end
35+
function Base.:(==)(x::ArrayReference, y::Expr)
36+
if y.head === :ref
37+
isequal(x, ref_from_ref(y))
38+
elseif y.head === :call && first(y.args) === :getindex
39+
isequal(x, ref_from_getindex(y))
40+
else
41+
false
42+
end
43+
end
44+
Base.:(==)(x::ArrayReference, y) = false
45+
46+
47+
48+
# Avoid memory allocations by accessing this
49+
const NOTAREFERENCE = ArrayReference(Symbol(""), Union{Symbol,Int}[])
150

251
@enum OperationType begin
352
constant
@@ -17,8 +66,7 @@ symbolic metadata contains info on direct dependencies / placement within loop.
1766
if isload(op) -> Symbol(:vptr_, first(op.reduced_deps))
1867
if istore(op) -> Symbol(:vptr_, op.variable)
1968
is how we access the memory.
20-
If numerical_metadata[i] == -1
21-
Symbol(:stride_, op.variable, :_, op.symbolic_metadata[i])
69+
2270
is the stride for loop index
2371
symbolic_metadata[i]
2472
"""
@@ -28,9 +76,10 @@ struct Operation
2876
elementbytes::Int
2977
instruction::Symbol
3078
node_type::OperationType
31-
dependencies::Vector{Symbol}
79+
dependencies::Vector{Symbol}#::Vector{Symbol}
3280
reduced_deps::Vector{Symbol}
3381
parents::Vector{Operation}
82+
ref::ArrayReference
3483
# children::Vector{Operation}
3584
# numerical_metadata::Vector{Int} # stride of -1 indicates dynamic
3685
# symbolic_metadata::Vector{Symbol}
@@ -42,19 +91,21 @@ struct Operation
4291
node_type,
4392
dependencies = Symbol[],
4493
reduced_deps = Symbol[],
45-
parents = Operation[]
94+
parents = Operation[],
95+
ref::ArrayReference = NOTAREFERENCE
4696
)
4797
new(
4898
identifier, variable, elementbytes, instruction, node_type,
4999
convert(Vector{Symbol},dependencies),
50100
convert(Vector{Symbol},reduced_deps),
51-
convert(Vector{Operation},parents)
101+
convert(Vector{Operation},parents),
102+
ref
52103
)
53104
end
54105
end
55106

56107
# negligible save on allocations for operations that don't need these (eg, constants).
57-
const NODEPENDENCY = Symbol[]
108+
const NODEPENDENCY = Union{Symbol,Int}[]
58109
const NOPARENTS = Operation[]
59110

60111

0 commit comments

Comments
 (0)