Skip to content

Commit fe94ecf

Browse files
committed
Search for parent refs as well as parent ops, also improve muladd-finder.
1 parent ddca686 commit fe94ecf

File tree

3 files changed

+139
-33
lines changed

3 files changed

+139
-33
lines changed

src/add_compute.jl

Lines changed: 67 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,52 @@ function search_tree(opv::Vector{Operation}, var::Symbol) # relies on cycles bei
8686
end
8787
false
8888
end
89+
90+
# NOTE(review): this method takes no arguments, yet its body reads `mpref`, `var`, `ls`,
# `ind`, `deps`, `argref`, `elementbytes`, `vparents` and `reduceddeps`, none of which are
# defined inside it. It looks like an unfinished extraction; confirm the intended
# signature before calling it.
function update_for_ref_reduction!()
    if varname(mpref) !== var
        # The reference is stored under another name: hand its load to `pushparent!`.
        pushparent!(vparents, deps, reduceddeps, add_load!(ls, argref, elementbytes))
    else
        id = findfirst(r -> r == mpref.mref, ls.refs_aliasing_syms)
        # Use the symbol recorded as aliasing this reference when one exists.
        mpref.varname = var = isnothing(id) ? var : ls.syms_aliasing_refs[id]
        reduction_ind = ind
        mergesetv!(deps, loopdependencies(add_load!(ls, argref, elementbytes)))
    end
end
100+
# Without an array reference there is nothing to look for: report "not found" and hand
# `var` back unchanged.
search_tree_for_ref(ls::LoopSet, opv::Vector{Operation}, ::Nothing, var::Symbol) = var, false

# Depth-first search of `opv` and, recursively, of each operation's `parents` for an
# operation whose `ref` equals `mpref.mref`.
#
# A hit additionally requires `varname(mpref) === var`. On a hit, `mpref.mref` is looked
# up in `ls.refs_aliasing_syms`; when it is present, the entry at the same index of
# `ls.syms_aliasing_refs` replaces `var`. The chosen name is written to `mpref.varname`
# (so `mpref` is mutated) and `(name, true)` is returned. When nothing matches the result
# is `(var, false)`.
#
# The recursion has no visited set, so it terminates only because the operation graph is
# acyclic.
function search_tree_for_ref(ls::LoopSet, opv::Vector{Operation}, mpref::ArrayReferenceMetaPosition, var::Symbol) # relies on cycles being forbidden
    for opp in opv
        if opp.ref == mpref.mref && varname(mpref) === var
            id = findfirst(==(mpref.mref), ls.refs_aliasing_syms)
            mpref.varname = var = isnothing(id) ? var : ls.syms_aliasing_refs[id]
            return var, true
        end
        # No direct hit on `opp`; continue the search through its parents.
        var, found = search_tree_for_ref(ls, parents(opp), mpref, var)
        found && return (var, found)
    end
    return var, false
end
89135
function search_tree(opv::Vector{Operation}, var::Operation) # relies on cycles being forbidden
90136
for opp in opv
91137
opp === var && return true
@@ -281,26 +327,35 @@ function add_compute!(
281327
mergesetv!(newreduceddeps, reduceddeps)
282328
deps = newloopdeps; reduceddeps = newreduceddeps
283329
end
330+
# @show reduction, search_tree(vparents, var) ex var vparents mpref get(ls.opdict, var, nothing) search_tree_for_ref(ls, vparents, mpref, var) # relies on cycles being forbidden
284331
op = if reduction || search_tree(vparents, var)
285-
parent = ls.opdict[var]
286-
setdiffv!(reduceddeps, deps, loopdependencies(parent))
287-
# parent = getop(ls, var, elementbytes)
288-
# if length(reduceddeps) == 0
289-
if all(!in(deps), reduceddeps)
290-
insert!(vparents, reduction_ind, parent)
291-
mergesetv!(deps, loopdependencies(parent))
332+
add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
333+
else
334+
var, found = search_tree_for_ref(ls, vparents, mpref, var)
335+
if found
336+
add_reduction!(ls, var, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
337+
else
292338
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
293339
pushop!(ls, op, var)
294-
else
295-
add_reduction_update_parent!(vparents, deps, reduceddeps, ls, parent, instr, reduction_ind, elementbytes)
296340
end
297-
else
298-
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
299-
pushop!(ls, op, var)
300341
end
301342
# maybe_const_compute!(ls, op, elementbytes, position)
302343
op
303344
end
345+
# Create the compute operation for `var` when an operation is already registered under
# that name in `ls.opdict` (looked up here as `parent`).
#
# `reduceddeps` is first updated in place by `setdiffv!` from `deps` and the parent's loop
# dependencies (presumably the deps the parent does not share; verify against `setdiffv!`).
# If any of the resulting `reduceddeps` is still a member of `deps`, the work is delegated
# to `add_reduction_update_parent!`. Otherwise `parent` is inserted into `vparents` at
# `reduction_ind`, its loop dependencies are merged into `deps`, and a new `Operation` is
# built and registered with `pushop!`.
#
# Returns whatever `pushop!` or `add_reduction_update_parent!` returns.
function add_reduction!(ls::LoopSet, var::Symbol, reduceddeps, deps, vparents, reduction_ind, elementbytes, instr)
    parent = ls.opdict[var]
    setdiffv!(reduceddeps, deps, loopdependencies(parent))
    if any(in(deps), reduceddeps)
        return add_reduction_update_parent!(vparents, deps, reduceddeps, ls, parent, instr, reduction_ind, elementbytes)
    end
    insert!(vparents, reduction_ind, parent)
    mergesetv!(deps, loopdependencies(parent))
    newop = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, vparents)
    return pushop!(ls, newop, var)
end
304359

305360
function add_compute!(
306361
ls::LoopSet, LHS::Symbol, instr, vparents::Vector{Operation}, elementbytes::Int

src/vectorizationbase_compat/contract_pass.jl

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ function mulexprcost(ex::Expr)
55
base = ex.head === :call ? 10 : 1
66
base + length(ex.args)
77
end
8+
# Build the call expression `mul_fast(args[2], args[3], ...)`: every element of `args`
# except the first becomes an argument of the new call, in order.
#
# Returns a freshly allocated `Expr`; `args` itself is not modified. The arguments are
# pushed one at a time rather than splatted into the `Expr` constructor.
function mul_fast_expr(args)
    call = Expr(:call, :mul_fast)
    for i in 2:length(args)
        push!(call.args, args[i])
    end
    return call
end
815
function mulexpr(mulexargs)
916
a = (mulexargs[1])::Union{Symbol,Expr,Number}
1017
if length(mulexargs) == 2
@@ -25,17 +32,17 @@ function mulexpr(mulexargs)
2532
return (c, Expr(:call, :mul_fast, a, b))
2633
end
2734
else
28-
return (a, Expr(:call, :mul_fast, @view(mulexargs[2:end])...)::Expr)
35+
return (a, mul_fast_expr(mulexargs))
2936
end
3037
a = (mulexargs[1])::Union{Symbol,Expr,Number}
3138
b = if length(mulexargs) == 2 # two arg mul
3239
(mulexargs[2])::Union{Symbol,Expr,Number}
3340
else
34-
Expr(:call, :mul_fast, @view(mulexargs[2:end])...)::Expr
41+
mul_fast_expr(mulexargs)
3542
end
3643
a, b
3744
end
38-
function append_args_skip!(call, args, i)
45+
function append_args_skip!(call, args, i, mod)
3946
for j in eachindex(args)
4047
j == i && continue
4148
push!(call.args, args[j])
@@ -44,18 +51,29 @@ function append_args_skip!(call, args, i)
4451
end
4552

4653
# Return the entry stored for `f` in `VectorizationBase.FASTDICT`, or `f` itself when the
# dictionary has no entry for it.
fastfunc(f) = get(VectorizationBase.FASTDICT, f, f)
47-
function make_fast!(call::Expr)
48-
call.args[1] = fastfunc(first(call.args))
49-
nothing
54+
# Rewrite, in place, the argument vector `argv` of a call expression.
#
# - `argv[1]` (the callee) becomes `:mul_fast` when `f` is `:*`, and `fastfunc(f)` for any
#   other `f`. `f` defaults to the current callee, `first(argv)`.
# - Every remaining element that is an `Expr` is replaced by `capture_muladd(elem, mod)`;
#   elements that are not `Expr`s are left untouched.
#
# Returns `nothing`; callers use this function only for its effect on `argv`.
function muladd_arguments!(argv, mod, f = first(argv))
    argv[1] = f === :* ? :mul_fast : fastfunc(f)
    for i in 2:length(argv)
        a = argv[i]
        a isa Expr || continue
        argv[i] = capture_muladd(a::Expr, mod)
    end
    return nothing
end
5166

52-
function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool = false)
53-
length(argv) < 3 && (make_fast!(call); return length(call.args) == 4, cnmul, csub)
67+
function recursive_muladd_search!(call, argv, mod, cnmul::Bool = false, csub::Bool = false)
68+
if length(argv) < 3
69+
muladd_arguments!(argv, mod)
70+
return length(call.args) == 4, cnmul, csub
71+
end
5472
fun = first(argv)
5573
isadd = fun === :+ || fun === :add_fast || fun === :vadd || (fun == :(Base.FastMath.add_fast))::Bool
5674
issub = fun === :- || fun === :sub_fast || fun === :vsub || (fun == :(Base.FastMath.sub_fast))::Bool
5775
if !(isadd | issub)
58-
argv[1] = fastfunc(fun)
76+
muladd_arguments!(argv, mod, fun)
5977
return length(call.args) == 4, cnmul, csub
6078
end
6179
exargs = @view(argv[2:end])
@@ -72,29 +90,29 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
7290
if length(exargs) == 2
7391
push!(call.args, exargs[3 - i])
7492
else
75-
push!(call.args, append_args_skip!(Expr(:call, :add_fast), exargs, i))
93+
push!(call.args, append_args_skip!(Expr(:call, :add_fast), exargs, i, mod))
7694
end
7795
if issub
7896
csub = i == 1
7997
cnmul = !csub
8098
end
8199
return true, cnmul, csub
82100
elseif isadd
83-
found, cnmul, csub = recursive_muladd_search!(call, exa)
101+
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
84102
if found
85103
if csub
86104
call.args[4] = if length(exargs) == 2
87105
Expr(:call, :sub_fast, exargs[3 - i], call.args[4])
88106
else
89-
Expr(:call, :sub_fast, append_args_skip!(Expr(:call, :add_fast), exargs, i), call.args[4])
107+
Expr(:call, :sub_fast, append_args_skip!(Expr(:call, :add_fast), exargs, i, mod), call.args[4])
90108
end
91109
else
92-
call.args[4] = append_args_skip!(Expr(:call, :add_fast, call.args[4]), exargs, i)
110+
call.args[4] = append_args_skip!(Expr(:call, :add_fast, call.args[4]), exargs, i, mod)
93111
end
94112
return true, cnmul, false
95113
end
96114
elseif issub
97-
found, cnmul, csub = recursive_muladd_search!(call, exa)
115+
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
98116
if found
99117
if i == 1
100118
if csub
@@ -119,10 +137,11 @@ function recursive_muladd_search!(call, argv, cnmul::Bool = false, csub::Bool =
119137
length(call.args) == 4, cnmul, csub
120138
end
121139

122-
function capture_muladd(ex::Expr, mod)
140+
function capture_a_muladd(ex::Expr, mod)
123141
call = Expr(:call, Symbol(""), Symbol(""), Symbol(""))
124-
found, nmul, sub = recursive_muladd_search!(call, ex.args)
125-
found || return ex
142+
found, nmul, sub = recursive_muladd_search!(call, ex.args, mod)
143+
found || return false, ex
144+
# found || return ex
126145
# a, b, c = call.args[2], call.args[3], call.args[4]
127146
# call.args[2], call.args[3], call.args[4] = c, a, b
128147
f = if nmul && sub
@@ -139,7 +158,13 @@ function capture_muladd(ex::Expr, mod)
139158
else
140159
call.args[1] = Expr(:(.), mod, QuoteNode(f))#_fast))
141160
end
142-
call
161+
true, call
162+
end
163+
# Apply `capture_a_muladd` repeatedly: its output expression is fed back in for as long as
# it reports a successful rewrite, and the first expression for which it reports none is
# returned.
function capture_muladd(ex::Expr, mod)
    found, ex = capture_a_muladd(ex, mod)
    while found
        found, ex = capture_a_muladd(ex, mod)
    end
    return ex
end
144169

145170
# Fallback method: for any pair of arguments not claimed by a more specific method, do
# nothing and return `nothing`.
contract_pass!(::Any, ::Any) = nothing

test/gemm.jl

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,7 +572,26 @@
572572
# C[m,n] = Cmn_hi
573573
# end
574574
# end
575-
575+
# Plain-loop reference kernel. For every `(i1, i2)` it accumulates into `du`
#
#     du[i1, i2] += sum over k of mat[i1, k] * u[k, i2] + mat[i2, k] * u[i1, k]
#
# `u` and `mat` must be square and of the same size (checked with `@assert`). `du` is
# updated in place and is NOT zeroed first, so existing contents are added to.
# Returns `nothing`.
function doublegemm!(du, u, mat)
    @assert size(u, 1) == size(u, 2) == size(mat, 1) == size(mat, 2)
    for i2 in axes(u, 2), i1 in axes(u, 1)
        for sum_idx in axes(u, 1)
            du[i1, i2] += mat[i1, sum_idx] * u[sum_idx, i2] + mat[i2, sum_idx] * u[i1, sum_idx]
        end
    end
    return nothing
end
584+
585+
# `@avx` counterpart of the plain-loop kernel: the loop nest is the same statement-for-
# statement, only wrapped in the `@avx` macro, so both versions can be compared on
# identical inputs.
#
# `u` and `mat` must be square and of the same size (checked with `@assert`). `du` is
# accumulated into in place and is not zeroed first. Returns `nothing`.
function doublegemmavx!(du, u, mat)
    @assert size(u, 1) == size(u, 2) == size(mat, 1) == size(mat, 2)
    @avx for i2 in 1:size(u, 2), i1 in 1:size(u, 1)
        for sum_idx in 1:size(u, 1)
            du[i1, i2] += mat[i1, sum_idx] * u[sum_idx, i2] + mat[i2, sum_idx] * u[i1, sum_idx]
        end
    end
    return nothing
end
594+
576595
function threegemms!(Ab, Bb, Cb, A, B, C)
577596
M, N = size(Cb); K = size(B,1)
578597
@avx for m in 1:M, k in 1:K, n in 1:N
@@ -647,15 +666,22 @@
647666
# end
648667

649668
for T in (Float32, Float64, Int32, Int64)
669+
TC = sizeof(T) == 4 ? Float32 : Float64
670+
R = T <: Integer ? (T(-1000):T(1000)) : T
671+
for M in 48:54
672+
C0 = zeros(TC, M, M); C1 = zeros(TC, M, M);
673+
A = rand(R, M, M); B = rand(R, M, M);
674+
doublegemm!(C0, A, B)
675+
doublegemmavx!(C1, A, B)
676+
@test C0 ≈ C1
677+
end
650678
# let T = Int32
651679
# exceeds_time_limit() && break
652680
@show T, @__LINE__
653681
# M, K, N = 128, 128, 128;
654682
N = 69;
655683
for M in 72:80, K in 72:80
656684
# M, K, N = 73, 75, 69;
657-
TC = sizeof(T) == 4 ? Float32 : Float64
658-
R = T <: Integer ? (T(-1000):T(1000)) : T
659685
C = Matrix{TC}(undef, M, N);
660686
A = rand(R, M, K); B = rand(R, K, N);
661687
At = copy(A');

0 commit comments

Comments
 (0)