Skip to content

Commit ecc39aa

Browse files
committed
In presence of mixed inner and outer reductions, place all inner reductions in inner most loop.
1 parent d4e1873 commit ecc39aa

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

src/modeling/determinestrategy.jl

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,28 @@ struct LoopOrders
11821182
buff::Vector{Symbol}
11831183
end
11841184

1185-
function LoopOrders(ls::LoopSet)
1185+
function outer_reduct_loopordersplit(ls::LoopSet)
1186+
ops = operations(ls)
1187+
nonouterreducts = Int[]
1188+
for i eachindex(ops)
1189+
i ls.outer_reductions || push!(nonouterreducts, i)
1190+
end
1191+
reductsyms = Symbol[]
1192+
nonreductsyms = Symbol[]
1193+
for l ls.loopsymbols
1194+
isreduct = false
1195+
for opid nonouterreducts
1196+
if l reduceddependencies(ops[opid])
1197+
isreduct = true
1198+
push!(reductsyms, l)
1199+
break
1200+
end
1201+
end
1202+
isreduct || push!(nonreductsyms, l)
1203+
end
1204+
reductsyms, nonreductsyms
1205+
end
1206+
function loopordersplit(ls::LoopSet)
11861207
reductsyms = Symbol[]
11871208
nonreductsyms = Symbol[]
11881209
for l ls.loopsymbols
@@ -1196,6 +1217,14 @@ function LoopOrders(ls::LoopSet)
11961217
end
11971218
isreduct || push!(nonreductsyms, l)
11981219
end
1220+
reductsyms, nonreductsyms
1221+
end
1222+
function LoopOrders(ls::LoopSet)
1223+
if length(ls.outer_reductions) == 0
1224+
reductsyms, nonreductsyms = loopordersplit(ls)
1225+
else
1226+
reductsyms, nonreductsyms = outer_reduct_loopordersplit(ls)
1227+
end
11991228
LoopOrders(nonreductsyms, reductsyms, Vector{Symbol}(undef, length(ls.loopsymbols)))
12001229
end
12011230

0 commit comments

Comments
 (0)