Skip to content

Commit 8f58ef4

Browse files
committed
Retain arg order for reductions.
1 parent bd794dd commit 8f58ef4

File tree

1 file changed

+10
-11
lines changed

1 file changed

+10
-11
lines changed

src/add_compute.jl

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ end
115115

116116
function add_reduction_update_parent!(
117117
vparents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet,
118-
parent::Operation, instr::Symbol, directdependency::Bool, elementbytes::Int
118+
parent::Operation, instr::Symbol, reduction_ind::Int, elementbytes::Int
119119
)
120120
var = name(parent)
121121
isouterreduction = parent.instruction === LOOPCONSTANT
@@ -153,12 +153,10 @@ function add_reduction_update_parent!(
153153
end
154154
combineddeps = copy(deps); mergesetv!(combineddeps, reduceddeps)
155155
# directdependency && pushparent!(vparents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
156-
if directdependency
156+
if reduction_ind > 0 # if is directdependency
157+
insert!(vparents, reduction_ind, reductinit)
157158
if instr ∈ (:-, :vsub!, :vsub, :/, :vfdiv!, :vfidiv!)
158-
pushfirst!(vparents, reductinit)
159159
update_deps!(deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
160-
else
161-
push!(vparents, reductinit)
162160
end
163161
# elseif !isouterreduction
164162
# substitute_op_in_parents!(vparents, reductinit, parent)
@@ -192,16 +190,16 @@ function add_compute!(
192190
parents = Operation[]
193191
deps = Symbol[]
194192
reduceddeps = Symbol[]
195-
reduction = false
196-
for arg ∈ args
193+
reduction_ind = 0
194+
for (ind,arg) ∈ enumerate(args)
197195
if var === arg
198-
reduction = true
196+
reduction_ind = ind
199197
add_reduction!(parents, deps, reduceddeps, ls, arg, elementbytes)
200198
elseif arg isa Expr
201199
isref, argref = tryrefconvert(ls, arg, elementbytes)
202200
if isref
203201
if mpref == argref
204-
reduction = true
202+
reduction_ind = ind
205203
add_load!(ls, var, argref, elementbytes)
206204
else
207205
pushparent!(parents, deps, reduceddeps, add_load!(ls, gensym(:tempload), argref, elementbytes))
@@ -216,6 +214,7 @@ function add_compute!(
216214
add_parent!(parents, deps, reduceddeps, ls, arg, elementbytes, position)
217215
end
218216
end
217+
reduction = reduction_ind > 0
219218
if iszero(length(deps)) && reduction
220219
loopnestview = view(ls.loopsymbols, 1:position)
221220
append!(deps, loopnestview)
@@ -231,11 +230,11 @@ function add_compute!(
231230
parent = getop(ls, var, elementbytes)
232231
setdiffv!(reduceddeps, deps, loopdependencies(parent))
233232
if length(reduceddeps) == 0
234-
push!(parents, parent)
233+
insert!(parents, reduction_ind, parent)
235234
op = Operation(length(operations(ls)), var, elementbytes, instruction(ls,instr), compute, deps, reduceddeps, parents)
236235
pushop!(ls, op, var)
237236
else
238-
add_reduction_update_parent!(parents, deps, reduceddeps, ls, parent, instr, reduction, elementbytes)
237+
add_reduction_update_parent!(parents, deps, reduceddeps, ls, parent, instr, reduction_ind, elementbytes)
239238
end
240239
else
241240
op = Operation(length(operations(ls)), var, elementbytes, instruction(ls,instr), compute, deps, reduceddeps, parents)

0 commit comments

Comments
 (0)