115
115
116
116
function add_reduction_update_parent! (
117
117
vparents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet ,
118
- parent:: Operation , instr:: Symbol , directdependency :: Bool , elementbytes:: Int
118
+ parent:: Operation , instr:: Symbol , reduction_ind :: Int , elementbytes:: Int
119
119
)
120
120
var = name (parent)
121
121
isouterreduction = parent. instruction === LOOPCONSTANT
@@ -153,12 +153,10 @@ function add_reduction_update_parent!(
153
153
end
154
154
combineddeps = copy (deps); mergesetv! (combineddeps, reduceddeps)
155
155
# directdependency && pushparent!(vparents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
156
- if directdependency
156
+ if reduction_ind > 0 # if is directdependency
157
+ insert! (vparents, reduction_ind, reductinit)
157
158
if instr ∈ (:- , :vsub! , :vsub , :/ , :vfdiv! , :vfidiv! )
158
- pushfirst! (vparents, reductinit)
159
159
update_deps! (deps, reduceddeps, reductinit)# parent) # deps and reduced deps will not be disjoint
160
- else
161
- push! (vparents, reductinit)
162
160
end
163
161
# elseif !isouterreduction
164
162
# substitute_op_in_parents!(vparents, reductinit, parent)
@@ -192,16 +190,16 @@ function add_compute!(
192
190
parents = Operation[]
193
191
deps = Symbol[]
194
192
reduceddeps = Symbol[]
195
- reduction = false
196
- for arg ∈ args
193
+ reduction_ind = 0
194
+ for (ind, arg) ∈ enumerate ( args)
197
195
if var === arg
198
- reduction = true
196
+ reduction_ind = ind
199
197
add_reduction! (parents, deps, reduceddeps, ls, arg, elementbytes)
200
198
elseif arg isa Expr
201
199
isref, argref = tryrefconvert (ls, arg, elementbytes)
202
200
if isref
203
201
if mpref == argref
204
- reduction = true
202
+ reduction_ind = ind
205
203
add_load! (ls, var, argref, elementbytes)
206
204
else
207
205
pushparent! (parents, deps, reduceddeps, add_load! (ls, gensym (:tempload ), argref, elementbytes))
@@ -216,6 +214,7 @@ function add_compute!(
216
214
add_parent! (parents, deps, reduceddeps, ls, arg, elementbytes, position)
217
215
end
218
216
end
217
+ reduction = reduction_ind > 0
219
218
if iszero (length (deps)) && reduction
220
219
loopnestview = view (ls. loopsymbols, 1 : position)
221
220
append! (deps, loopnestview)
@@ -231,11 +230,11 @@ function add_compute!(
231
230
parent = getop (ls, var, elementbytes)
232
231
setdiffv! (reduceddeps, deps, loopdependencies (parent))
233
232
if length (reduceddeps) == 0
234
- push ! (parents, parent)
233
+ insert ! (parents, reduction_ind , parent)
235
234
op = Operation (length (operations (ls)), var, elementbytes, instruction (ls,instr), compute, deps, reduceddeps, parents)
236
235
pushop! (ls, op, var)
237
236
else
238
- add_reduction_update_parent! (parents, deps, reduceddeps, ls, parent, instr, reduction , elementbytes)
237
+ add_reduction_update_parent! (parents, deps, reduceddeps, ls, parent, instr, reduction_ind , elementbytes)
239
238
end
240
239
else
241
240
op = Operation (length (operations (ls)), var, elementbytes, instruction (ls,instr), compute, deps, reduceddeps, parents)
0 commit comments