Skip to content

Commit b7b508d

Browse files
authored
Opt out of reduction search in some cases, fixes #288 (#289)
1 parent eb109f8 commit b7b508d

File tree

3 files changed

+61
-9
lines changed

3 files changed

+61
-9
lines changed

src/parse/add_compute.jl

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -307,19 +307,17 @@ function add_compute!(
307307
deps = newloopdeps; reduceddeps = newreduceddeps
308308
end
309309
# @show reduction, search_tree(vparents, var) ex var vparents mpref get(ls.opdict, var, nothing) search_tree_for_ref(ls, vparents, mpref, var) # relies on cycles being forbidden
310-
op = if reduction || search_tree(vparents, var)
311-
add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
310+
if reduction || search_tree(vparents, var)
311+
return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
312312
else
313-
var, found = search_tree_for_ref(ls, vparents, mpref, var)
314-
if found
315-
add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
316-
else
317-
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
318-
pushop!(ls, op, var)
313+
if mpref !== nothing && ((length(loopdependencies(mpref)) < position) | (length(reduceddependencies(mpref)) > 0))
314+
var, found = search_tree_for_ref(ls, vparents, mpref, var)
315+
found && return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
319316
end
317+
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
318+
return pushop!(ls, op, var)
320319
end
321320
# maybe_const_compute!(ls, op, elementbytes, position)
322-
op
323321
end
324322
function add_reduction!(ls::LoopSet, var::Symbol, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
325323
parent = ls.opdict[var]

test/reduction_untangling.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
2+
# issue 288
3+
"""
    not_a_reduction!(A, B)

Regression kernel for issue 288, run through `@turbo`.

`A` and `B` hold complex data in split form via `.re` and `.im` fields. Every
entry `A[i,j]` is overwritten in place with `A[i,j] * (B[i] * conj(B[j]))`,
written out in real arithmetic. Each `A[i,j]` is loaded and then stored at the
same index within one iteration, so nothing is accumulated across iterations.

`B.re` and `B.im` are indexed by both the row and the column index of `A.re`,
so they need at least `max(size(A.re)...)` elements. Returns `nothing`.
"""
function not_a_reduction!(A, B)
    @turbo for j ∈ 1:size(A.re, 2)
        jre = B.re[j]
        jim = B.im[j]
        for i ∈ 1:size(A.re, 1)
            ire = B.re[i]
            iim = B.im[i]
            # cis = B[i] * conj(B[j]), split into imaginary and real parts.
            cisim = iim * jre - ire * jim
            cisre = ire * jre + iim * jim
            ρre_i = A.re[i,j]
            ρim_i = A.im[i,j]
            # Complex product ρ * cis.
            re_out = ρre_i * cisre - ρim_i * cisim
            im_out = ρre_i * cisim + ρim_i * cisre
            A.re[i,j] = re_out
            A.im[i,j] = im_out
        end
    end
    return nothing
end
22+
"""
    not_a_reduction_noturbo!(A, B)

Plain-Julia reference for `not_a_reduction!`: the same loop nest and arithmetic,
but without `@turbo`, so its result does not depend on the macro under test.

`A` and `B` hold complex data in split form via `.re` and `.im` fields. Every
entry `A[i,j]` is overwritten in place with `A[i,j] * (B[i] * conj(B[j]))`.
`B.re` and `B.im` are indexed by both the row and the column index of `A.re`.
Returns `nothing`.
"""
function not_a_reduction_noturbo!(A, B)
    # Deliberately a bare loop: wrapping the reference in `@turbo` as well would
    # make the comparison against `not_a_reduction!` vacuous.
    for j ∈ 1:size(A.re, 2)
        jre = B.re[j]
        jim = B.im[j]
        for i ∈ 1:size(A.re, 1)
            ire = B.re[i]
            iim = B.im[i]
            # cis = B[i] * conj(B[j]), split into imaginary and real parts.
            cisim = iim * jre - ire * jim
            cisre = ire * jre + iim * jim
            ρre_i = A.re[i,j]
            ρim_i = A.im[i,j]
            # Complex product ρ * cis.
            re_out = ρre_i * cisre - ρim_i * cisim
            im_out = ρre_i * cisim + ρim_i * cisre
            A.re[i,j] = re_out
            A.im[i,j] = im_out
        end
    end
    return nothing
end
41+
42+
# Runs the `@turbo` kernel and the reference kernel on identical random inputs
# and checks that both the real and imaginary outputs agree.
@testset "Untangle reductions" begin
    N = 11
    A1 = (re = rand(N, N), im = rand(N, N))
    # Both kernels mutate their first argument, so each gets its own copy.
    A2 = deepcopy(A1)
    B = (re = rand(N), im = rand(N))
    not_a_reduction!(A1, B)
    not_a_reduction_noturbo!(A2, B)
    # Approximate comparison, since these are floating-point results.
    @test A1.re ≈ A2.re
    @test A1.im ≈ A2.im
end
52+

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ const START_TIME = time()
5656
@time include("special.jl")
5757

5858
@time include("multiassignments.jl")
59+
60+
@time include("reduction_untangling.jl")
5961
end
6062

6163
@time if LOOPVECTORIZATION_TEST == "all" || LOOPVECTORIZATION_TEST == "part2"

0 commit comments

Comments
 (0)