Skip to content

Commit ce37b6b

Browse files
committed
Add hacky fix for a case of mistaken reduction, fixes #347.
1 parent 7a4875a commit ce37b6b

File tree

4 files changed

+173
-23
lines changed

4 files changed

+173
-23
lines changed

src/parse/add_compute.jl

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,47 @@ function add_anon_func!(ls::LoopSet, LHS::Symbol, f::Expr, ex::Expr, position::I
309309
end
310310
return retop
311311
end
312+
# TODO: DRY, this is similar to `find_samename_constparent` in `condense_loopset.jl`
313+
function find_inner_reduct_parent(op::Operation, opname::Symbol)
314+
for opp ∈ parents(op)
315+
(((isconstant(opp)) && (name(opp) === opname))) && return opp
316+
opptemp = find_samename_constparent(opp, opname)
317+
opptemp === opp || return opptemp
318+
end
319+
op
320+
end
321+
322+
function maybe_fix_reduced_deps!(ls::LoopSet, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, parent::Operation, mpref::ArrayReferenceMetaPosition, position::Int)
323+
loopdeps_parent = loopdependencies(parent)
324+
reduceddeps_parent = reduceddependencies(parent)
325+
loopdeps_mpref = loopdependencies(mpref)
326+
loopdeps_new = Symbol[]
327+
# pushv = Vector{Symbol}[loopdeps_new, reduceddeps_parent]
328+
instr = instruction(parent).instr
329+
pparent_id = findfirst(Base.Fix2(===,name(parent)) ∘ name, parents(parent))
330+
pparent_id === nothing && return deps, reduceddeps
331+
pparent = parents(parent)[pparent_id]
332+
@assert length(loopdependencies(pparent)) == length(loopdeps_parent) + length(reduceddeps_parent)
333+
reduceddeps_pparent = reduceddependencies(pparent)
334+
# if instr === :identity
335+
# push!(pushv,
336+
# if Base.sym_in(instr, :ident
337+
for ld ∈ loopdeps_parent
338+
if ld ∈ loopdeps_mpref
339+
push!(loopdeps_new, ld)
340+
else
341+
push!(reduceddeps_parent, ld)
342+
push!(reduceddeps_pparent, ld)
343+
# foreach(Base.Fix2(push!, ld), pushv)
344+
end
345+
end
346+
parent.dependencies = loopdeps_new
347+
reduct_init = find_inner_reduct_parent(pparent, name(pparent))
348+
reduct_init.dependencies = loopdeps_new
349+
reduct_init.reduced_children = reduceddeps_pparent
350+
# ld = loopdependencies(mpref)
351+
return loopdeps_new, copy(reduceddeps_parent)
352+
end
312353
function add_compute!(
313354
ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int, position::Int,
314355
mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing
@@ -350,8 +391,8 @@ function add_compute!(
350391
pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
351392
end
352393
else
353-
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
354-
end
394+
add_parent!(vparents, deps, reduceddeps, ls, arg, elementbytes, position)
395+
end
355396
elseif arg ∈ ls.loopsymbols
356397
loopsymop = add_loopvalue!(ls, arg, elementbytes)
357398
pushparent!(vparents, deps, reduceddeps, loopsymop)
@@ -370,6 +411,10 @@ function add_compute!(
370411
mergesetv!(newreduceddeps, reduceddeps)
371412
deps = newloopdeps; reduceddeps = newreduceddeps
372413
end
414+
# fix for #347
415+
if (mpref ≢ nothing) && ((reduction_ind ≠ 0) & (mpref.varname === var) & (position > length(loopdependencies(mpref)))) && isone(length(vparents)) && (position == length(loopdependencies(only(vparents))))
416+
deps, reduceddeps = maybe_fix_reduced_deps!(ls, deps,reduceddeps, only(vparents), mpref, position)
417+
end
373418
# @show reduction, search_tree(vparents, var) ex var vparents mpref get(ls.opdict, var, nothing) search_tree_for_ref(ls, vparents, mpref, var) # relies on cycles being forbidden
374419
if reduction || search_tree(vparents, var)
375420
return add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)

src/parse/add_stores.jl

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ function cse_store!(ls::LoopSet, op::Operation)
77
ls.operations[id] = op
88
ls.opdict[op.variable] = op
99
end
10+
1011
function add_store!(ls::LoopSet, op::Operation, add_pvar::Bool = !any(r -> r == op.ref, ls.refs_aliasing_syms))
1112
@assert isstore(op)
1213
if add_pvar
@@ -26,29 +27,29 @@ end
2627

2728

2829
function add_store!(
29-
ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int, parent = getop(ls, varname(mpref), mpref.loopdependencies, elementbytes)
30+
ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int, parent = getop(ls, varname(mpref), mpref.loopdependencies, elementbytes)
3031
)
31-
isload(parent) && return add_copystore!(ls, parent, mpref, elementbytes)
32-
vparents = mpref.parents
33-
ldref = mpref.loopdependencies
34-
reduceddeps = mpref.reduceddeps
35-
pvar = name(parent)
36-
id = length(ls.operations)
37-
# try to cse store, by replacing the previous one
38-
mref = mpref.mref
39-
add_pvar = true
40-
for opp ∈ operations(ls)
41-
if mref == opp.ref
42-
isstore(opp) && (id = opp.identifier)
43-
add_pvar = false
44-
break
45-
end
46-
# add_pvar &= (name(first(parents(opp))) != pvar)
32+
isload(parent) && return add_copystore!(ls, parent, mpref, elementbytes)
33+
vparents = mpref.parents
34+
ldref = mpref.loopdependencies
35+
reduceddeps = mpref.reduceddeps
36+
pvar = name(parent)
37+
id = length(ls.operations)
38+
# try to cse store, by replacing the previous one
39+
mref = mpref.mref
40+
add_pvar = true
41+
for opp ∈ operations(ls)
42+
if mref == opp.ref
43+
isstore(opp) && (id = opp.identifier)
44+
add_pvar = false
45+
break
4746
end
48-
pushfirst!(vparents, parent)
49-
update_deps!(ldref, reduceddeps, parent)
50-
op = Operation( id, name(mpref), elementbytes, :setindex!, memstore, mpref )
51-
add_store!(ls, op, add_pvar)
47+
# add_pvar &= (name(first(parents(opp))) != pvar)
48+
end
49+
pushfirst!(vparents, parent)
50+
update_deps!(ldref, reduceddeps, parent)
51+
op = Operation( id, name(mpref), elementbytes, :setindex!, memstore, mpref )
52+
add_store!(ls, op, add_pvar)
5253
end
5354

5455
function add_store!(

test/grouptests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ end
9797
@time include("steprange.jl")
9898

9999
@time include("gemm.jl")
100+
101+
@time include("inner_reductions.jl")
100102
end
101103

102104
end

test/inner_reductions.jl

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
2+
using LoopVectorization, Test
3+
4+
function reference_mul4!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
5+
@inbounds @fastmath for a1i ∈ eachindex(range_a),
6+
a2i ∈ eachindex(range_a),
7+
b1i ∈ eachindex(range_b),
8+
b2i ∈ eachindex(range_b)
9+
a1 = range_a[a1i];
10+
a2 = range_a[a2i]
11+
b1 = range_b[b1i]; b2 = range_b[b2i]
12+
contribution = zero(eltype(target_arr))
13+
for i_a ∈ padded_axis_a, i_b ∈ padded_axis_b
14+
contribution += src[i_a, i_b] * src[i_a+a1,i_b+b1] * src[i_a+a2,i_b+b2]
15+
end
16+
target_arr[b1i, b2i] += contribution
17+
end
18+
end
19+
function mul4_turbo_v1!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
20+
@turbo for a1i ∈ eachindex(range_a),
21+
a2i ∈ eachindex(range_a),
22+
b1i ∈ eachindex(range_b),
23+
b2i ∈ eachindex(range_b)
24+
a1 = range_a[a1i];
25+
a2 = range_a[a2i]
26+
b1 = range_b[b1i]; b2 = range_b[b2i]
27+
contribution = zero(eltype(target_arr))
28+
for i_a ∈ padded_axis_a, i_b ∈ padded_axis_b
29+
contribution += src[i_a, i_b] * src[i_a+a1,i_b+b1] * src[i_a+a2,i_b+b2]
30+
end
31+
target_arr[b1i, b2i] += contribution
32+
end
33+
target_arr
34+
end
35+
function mul4_turbo_v2!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
36+
@turbo for a1i ∈ eachindex(range_a),
37+
a2i ∈ eachindex(range_a),
38+
b1i ∈ eachindex(range_b),
39+
b2i ∈ eachindex(range_b)
40+
a1 = range_a[a1i];
41+
a2 = range_a[a2i]
42+
b1 = range_b[b1i]; b2 = range_b[b2i]
43+
contribution = target_arr[b1i, b2i]
44+
for i_a ∈ padded_axis_a, i_b ∈ padded_axis_b
45+
contribution += src[i_a, i_b] * src[i_a+a1,i_b+b1] * src[i_a+a2,i_b+b2]
46+
end
47+
target_arr[b1i, b2i] = contribution
48+
end
49+
end
50+
function mul4_turbo_v3!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
51+
@turbo for a1i ∈ eachindex(range_a),
52+
a2i ∈ eachindex(range_a),
53+
b1i ∈ eachindex(range_b),
54+
b2i ∈ eachindex(range_b),
55+
i_a ∈ padded_axis_a,
56+
i_b ∈ padded_axis_b
57+
a1 = range_a[a1i];
58+
a2 = range_a[a2i]
59+
b1 = range_b[b1i]; b2 = range_b[b2i]
60+
target_arr[b1i, b2i] += src[i_a, i_b] * src[i_a+a1,i_b+b1] * src[i_a+a2,i_b+b2]
61+
end
62+
end
63+
function mul4_turbo_v4!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
64+
@turbo for b1i ∈ eachindex(range_b), b2i ∈ eachindex(range_b)
65+
b1 = range_b[b1i]; b2 = range_b[b2i]
66+
contribution = zero(eltype(target_arr))
67+
for i_a ∈ padded_axis_a, i_b ∈ padded_axis_b, a1i ∈ eachindex(range_a), a2i ∈ eachindex(range_a)
68+
a1 = range_a[a1i]; a2 = range_a[a2i]
69+
contribution += src[i_a, i_b] * src[i_a+a1,i_b+b1] * src[i_a+a2,i_b+b2]
70+
end
71+
target_arr[b1i, b2i] += contribution
72+
end
73+
end
74+
75+
@testset "Inner reductions" begin
76+
src = ones(19, 101)
77+
max_a = 7; max_b = 9
78+
range_a = -max_a:max_a
79+
range_b = -max_b:max_b
80+
81+
target_arr = zeros(Float64, length(range_b), length(range_b))
82+
83+
padded_axis_b = (first(axes(src,2)) .+ max_b):(last(axes(src,2)) - max_b)
84+
padded_axis_a = (first(axes(src,1)) .+ max_a):(last(axes(src,1)) - max_a)
85+
86+
target_ref = zero(target_arr);
87+
reference_mul4!(target_ref, src, range_a, range_b, padded_axis_a, padded_axis_b)
88+
89+
mul4_turbo_v1!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
90+
@test target_arr ≈ target_ref
91+
target_arr .= 0;
92+
mul4_turbo_v2!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
93+
@test target_arr ≈ target_ref
94+
target_arr .= 0;
95+
mul4_turbo_v3!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
96+
@test target_arr ≈ target_ref
97+
target_arr .= 0;
98+
mul4_turbo_v4!(target_arr, src, range_a, range_b, padded_axis_a, padded_axis_b)
99+
@test target_arr ≈ target_ref
100+
101+
end
102+

0 commit comments

Comments
 (0)