@@ -21,16 +21,23 @@ function mergesetdiffv!(
21
21
end
22
22
nothing
23
23
end
24
+ # Everything in arg2 (s1) that isn't in arg3 (s2) is added to arg1 (s3)
24
25
function setdiffv! (s3:: AbstractVector{T} , s1:: AbstractVector{T} , s2:: AbstractVector{T} ) where {T}
25
26
for s ∈ s1
26
27
(s ∈ s2) || (s ∉ s3 && push! (s3, s))
27
28
end
28
29
end
30
+ function setdiffv! (s4:: AbstractVector{T} , s3:: AbstractVector{T} , s1:: AbstractVector{T} , s2:: AbstractVector{T} ) where {T}
31
+ for s ∈ s1
32
+ (s ∈ s2) ? (s ∉ s4 && push! (s4, s)) : (s ∉ s3 && push! (s3, s))
33
+ end
34
+ end
29
35
function update_deps! (deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , parent:: Operation )
30
- mergesetdiffv ! (deps, loopdependencies (parent), reduceddependencies (parent))
36
+ mergesetv ! (deps, loopdependencies (parent)) # , reduceddependencies(parent))
31
37
if ! (isload (parent) || isconstant (parent)) && parent. instruction. instr ∉ (:reduced_add , :reduced_prod , :reduce_to_add , :reduce_to_prod )
32
38
mergesetv! (reduceddeps, reduceddependencies (parent))
33
39
end
40
+ #
34
41
nothing
35
42
end
36
43
@@ -42,19 +49,19 @@ function pushparent!(mpref::ArrayReferenceMetaPosition, parent::Operation)
42
49
pushparent! (mpref. parents, mpref. loopdependencies, mpref. reduceddeps, parent)
43
50
end
44
51
function add_parent! (
45
- parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var, elementbytes:: Int = 8
52
+ parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var, elementbytes:: Int , position :: Int
46
53
)
47
54
parent = if var isa Symbol
48
55
getop (ls, var, elementbytes)
49
56
elseif var isa Expr # CSE candidate
50
- add_operation! (ls, gensym (:temporary ), var, elementbytes)
57
+ add_operation! (ls, gensym (:temporary ), var, elementbytes, position )
51
58
else # assumed constant
52
59
add_constant! (ls, var, elementbytes)
53
60
end
54
61
pushparent! (parents, deps, reduceddeps, parent)
55
62
end
56
63
function add_reduction! (
57
- parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var:: Symbol , elementbytes:: Int = 8
64
+ parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var:: Symbol , elementbytes:: Int
58
65
)
59
66
get! (ls. opdict, var) do
60
67
add_constant! (ls, var, elementbytes)
@@ -80,10 +87,10 @@ function update_reduction_status!(parentvec::Vector{Operation}, deps::Vector{Sym
80
87
end
81
88
end
82
89
function add_reduction_update_parent! (
83
- parents :: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet ,
84
- var :: Symbol , instr:: Symbol , directdependency:: Bool , elementbytes:: Int = 8
90
+ vparents :: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet ,
91
+ parent :: Operation , instr:: Symbol , directdependency:: Bool , elementbytes:: Int
85
92
)
86
- parent = getop (ls, var, elementbytes )
93
+ var = name (parent )
87
94
isouterreduction = parent. instruction === LOOPCONSTANT
88
95
Instr = instruction (ls, instr)
89
96
instrclass = reduction_instruction_class (Instr) # key allows for faster lookups
@@ -110,27 +117,27 @@ function add_reduction_update_parent!(
110
117
reductsym = var
111
118
reductcombine = Symbol (" " )
112
119
end
113
- setdiffv! (reduceddeps, deps, loopdependencies (reductinit))
114
120
combineddeps = copy (deps); mergesetv! (combineddeps, reduceddeps)
115
- directdependency && pushparent! (parents , deps, reduceddeps, reductinit)# parent) # deps and reduced deps will not be disjoint
116
- update_reduction_status! (parents , combineddeps, name (reductinit))
121
+ directdependency && pushparent! (vparents , deps, reduceddeps, reductinit)# parent) # deps and reduced deps will not be disjoint
122
+ update_reduction_status! (vparents , combineddeps, name (reductinit))
117
123
# this is the op added by add_compute
118
- op = Operation (length (operations (ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, parents )
124
+ op = Operation (length (operations (ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, vparents )
119
125
parent. instruction === LOOPCONSTANT && push! (ls. outer_reductions, identifier (op))
120
126
opout = pushop! (ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
127
+ # isouterreduction || iszero(length(reduceddeps)) && return opout
121
128
isouterreduction && return opout
122
129
# create child op, which is the reduction combination
123
- childdeps = Symbol[]; childrdeps = Symbol[]; childparents = Operation[]
124
- pushparent! (childparents, childdeps, childrdeps, op) # reduce op
125
- pushparent! (childparents, childdeps, childrdeps, parent) # to
130
+ childrdeps = Symbol[]; childparents = Operation[ op, parent ]
131
+ childdeps = loopdependencies (reductinit)
132
+ setdiffv! ( childrdeps, loopdependencies (op), childdeps)
126
133
child = Operation (
127
134
length (operations (ls)), name (parent), elementbytes, reductcombine, compute, childdeps, childrdeps, childparents
128
135
)
129
136
pushop! (ls, child, name (parent))
130
137
opout
131
138
end
132
139
function add_compute! (
133
- ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 ,
140
+ ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int , position :: Int ,
134
141
mpref:: Union{Nothing,ArrayReferenceMetaPosition} = nothing
135
142
)
136
143
@assert ex. head === :call
@@ -149,12 +156,12 @@ function add_compute!(
149
156
if isref
150
157
if mpref == argref
151
158
reduction = true
152
- add_load! (ls, var, mpref , elementbytes)
159
+ add_load! (ls, var, argref , elementbytes)
153
160
else
154
161
pushparent! (parents, deps, reduceddeps, add_load! (ls, gensym (:tempload ), argref, elementbytes))
155
162
end
156
163
else
157
- add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes)
164
+ add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes, position )
158
165
end
159
166
elseif arg ∈ ls. loopsymbols
160
167
loopsym = gensym (arg)
@@ -164,11 +171,30 @@ function add_compute!(
164
171
push! (ls. refs_aliasing_syms, loopsymop. ref)
165
172
pushparent! (parents, deps, reduceddeps, loopsymop)
166
173
else
167
- add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes)
174
+ add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes, position )
168
175
end
169
176
end
177
+ if iszero (length (deps)) && reduction
178
+ loopnestview = view (ls. loopsymbols, 1 : position)
179
+ append! (deps, loopnestview)
180
+ append! (reduceddeps, loopnestview)
181
+ else
182
+ loopnestview = view (ls. loopsymbols, 1 : position)
183
+ newloopdeps = Symbol[]; newreduceddeps = Symbol[];
184
+ setdiffv! (newloopdeps, newreduceddeps, deps, loopnestview)
185
+ mergesetv! (newreduceddeps, reduceddeps)
186
+ deps = newloopdeps; reduceddeps = newreduceddeps
187
+ end
170
188
if reduction || search_tree (parents, var)
171
- add_reduction_update_parent! (parents, deps, reduceddeps, ls, var, instr, reduction, elementbytes)
189
+ parent = getop (ls, var, elementbytes)
190
+ setdiffv! (reduceddeps, deps, loopdependencies (parent))
191
+ if length (reduceddeps) == 0
192
+ push! (parents, parent)
193
+ op = Operation (length (operations (ls)), var, elementbytes, instruction (ls,instr), compute, deps, reduceddeps, parents)
194
+ pushop! (ls, op, var)
195
+ else
196
+ add_reduction_update_parent! (parents, deps, reduceddeps, ls, parent, instr, reduction, elementbytes)
197
+ end
172
198
else
173
199
op = Operation (length (operations (ls)), var, elementbytes, instruction (ls,instr), compute, deps, reduceddeps, parents)
174
200
pushop! (ls, op, var)
0 commit comments