Skip to content

Commit 92a71a5

Browse files
committed
More mul and add nesting
1 parent 0f9ab73 commit 92a71a5

File tree

1 file changed

+67
-58
lines changed

1 file changed

+67
-58
lines changed

src/vectorizationbase_compat/contract_pass.jl

Lines changed: 67 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ function recursive_muladd_search!(call, argv, mod, cnmul::Bool = false, csub::Bo
113113
end
114114
return true, cnmul, csub
115115
elseif isadd
116-
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
116+
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
117117
if found
118118
if csub
119119
call.args[4] = if length(exargs) == 2
@@ -153,72 +153,81 @@ function recursive_muladd_search!(call, argv, mod, cnmul::Bool = false, csub::Bo
153153
end
154154

155155
function capture_a_muladd(ex::Expr, mod)
156-
call = Expr(:call, Symbol(""), Symbol(""), Symbol(""))
157-
found, nmul, sub = recursive_muladd_search!(call, ex.args, mod)
158-
found || return false, ex
159-
# found || return ex
160-
# a, b, c = call.args[2], call.args[3], call.args[4]
161-
# call.args[2], call.args[3], call.args[4] = c, a, b
162-
f = if nmul && sub
163-
:vfnmsub_fast
164-
elseif nmul
165-
:vfnmadd_fast
166-
elseif sub
167-
:vfmsub_fast
168-
else
169-
:vfmadd_fast
170-
end
171-
if mod === nothing
172-
call.args[1] = f
173-
else
174-
call.args[1] = Expr(:(.), mod, QuoteNote(f))#_fast))
156+
call = Expr(:call, Symbol(""), Symbol(""), Symbol(""))
157+
found, nmul, sub = recursive_muladd_search!(call, ex.args, mod)
158+
if !found
159+
if length(ex.args) > 3
160+
f = ex.args[1]
161+
if (f === :add_fast) | (f === :mul_fast)
162+
newex = Expr(:call, f, ex.args[2], ex.args[3])
163+
for i 4:length(ex.args)
164+
newex = Expr(:call, f, newex, ex.args[i])
165+
end
166+
ex = newex
167+
end
175168
end
176-
true, call
169+
return false, ex
170+
end
171+
# found || return ex
172+
# a, b, c = call.args[2], call.args[3], call.args[4]
173+
# call.args[2], call.args[3], call.args[4] = c, a, b
174+
f = if nmul && sub
175+
:vfnmsub_fast
176+
elseif nmul
177+
:vfnmadd_fast
178+
elseif sub
179+
:vfmsub_fast
180+
else
181+
:vfmadd_fast
182+
end
183+
if mod === nothing
184+
call.args[1] = f
185+
else
186+
call.args[1] = Expr(:(.), mod, QuoteNote(f))#_fast))
187+
end
188+
true, call
177189
end
178190
function capture_muladd(ex::Expr, mod)
179191
while true
180-
ex.head === :ref && return ex
181-
found, ex = capture_a_muladd(ex, mod)
182-
found || return ex
183-
end
192+
ex.head === :ref && return ex
193+
found, ex = capture_a_muladd(ex, mod)
194+
found || return ex
195+
end
184196
end
185-
186-
function append_update_args!(call::Expr, ex::Expr)
187-
for i 2:length(ex.args)
188-
push!(call.args, ex.args[i])
189-
end
190-
push!(call.args, ex.args[1])
191-
nothing
197+
function append_update_args(f, ex::Expr)
198+
call = Expr(:call, f)
199+
for i 2:length(ex.args)
200+
push!(call.args, ex.args[i])
201+
end
202+
push!(call.args, ex.args[1])
203+
nothing
192204
end
193205
contract_pass!(::Any, ::Any) = nothing
194206
function contract!(expr::Expr, ex::Expr, i::Int, mod)
195-
# if ex.head === :call
196-
# expr.args[i] = capture_muladd(ex, mod)
197-
if ex.head === :(+=)
198-
call = Expr(:call, :add_fast)
199-
append_update_args!(call, ex)
200-
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
201-
elseif ex.head === :(-=)
202-
call = Expr(:call, :sub_fast)
203-
append!(call.args, ex.args)
204-
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
205-
elseif ex.head === :(*=)
206-
call = Expr(:call, :mul_fast)
207-
append_update_args!(call, ex)
208-
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
209-
elseif ex.head === :(/=)
210-
call = Expr(:call, :div_fast)
211-
append!(call.args, ex.args)
212-
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
213-
end
214-
if ex.head === :(=)
215-
RHS = ex.args[2]
216-
# @show ex
217-
if RHS isa Expr && Base.sym_in(RHS.head, (:call,:if))
218-
ex.args[2] = capture_muladd(RHS, mod)
219-
end
207+
# if ex.head === :call
208+
# expr.args[i] = capture_muladd(ex, mod)
209+
if ex.head === :(+=)
210+
call = append_update_args(:add_fast, ex)
211+
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
212+
elseif ex.head === :(-=)
213+
call = Expr(:call, :sub_fast)
214+
append!(call.args, ex.args)
215+
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
216+
elseif ex.head === :(*=)
217+
call = append_update_args(:mul_fast, ex)
218+
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
219+
elseif ex.head === :(/=)
220+
call = Expr(:call, :div_fast)
221+
append!(call.args, ex.args)
222+
expr.args[i] = ex = Expr(:(=), first(ex.args), call)
223+
end
224+
if ex.head === :(=)
225+
RHS = ex.args[2]
226+
if RHS isa Expr && Base.sym_in(RHS.head, (:call,:if))
227+
ex.args[2] = capture_muladd(RHS, mod)
220228
end
221-
contract_pass!(expr.args[i], mod)
229+
end
230+
contract_pass!(expr.args[i], mod)
222231
end
223232
# contract_pass(x) = x # x will probably be a symbol
224233
function contract_pass!(expr::Expr, mod = nothing)

0 commit comments

Comments
 (0)