Skip to content

Commit 088795c

Browse files
committed
Infs to literals
1 parent 673b7e3 commit 088795c

File tree

1 file changed

+84
-70
lines changed

1 file changed

+84
-70
lines changed

src/vectorizationbase_compat/contract_pass.jl

Lines changed: 84 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -76,80 +76,85 @@ function muladd_arguments!(argv, mod, f = first(argv))
7676
end
7777

7878
function recursive_muladd_search!(call, argv, mod, cnmul::Bool = false, csub::Bool = false)
79-
if length(argv) < 3
80-
muladd_arguments!(argv, mod)
81-
return length(call.args) == 4, cnmul, csub
82-
end
83-
fun = first(argv)
84-
isadd = fun === :+ || fun === :add_fast || fun === :vadd || (fun == :(Base.FastMath.add_fast))::Bool
85-
issub = fun === :- || fun === :sub_fast || fun === :vsub || (fun == :(Base.FastMath.sub_fast))::Bool
86-
if isadd
87-
argv[1] = :add_fast
88-
elseif issub
89-
argv[1] = :sub_fast
90-
else
91-
muladd_arguments!(argv, mod, fun)
92-
return length(call.args) == 4, cnmul, csub
79+
if length(argv) < 3
80+
muladd_arguments!(argv, mod)
81+
return length(call.args) == 4, cnmul, csub
82+
end
83+
fun = first(argv)
84+
isadd = fun === :+ || fun === :add_fast || fun === :vadd || (fun == :(Base.FastMath.add_fast))::Bool
85+
issub = fun === :- || fun === :sub_fast || fun === :vsub || (fun == :(Base.FastMath.sub_fast))::Bool
86+
if isadd
87+
argv[1] = :add_fast
88+
elseif issub
89+
argv[1] = :sub_fast
90+
else
91+
muladd_arguments!(argv, mod, fun)
92+
return length(call.args) == 4, cnmul, csub
93+
end
94+
exargs = @view(argv[2:end])
95+
for i eachindex(exargs)
96+
if exargs[i] === :Inf
97+
exargs[i] === Inf
9398
end
94-
exargs = @view(argv[2:end])
95-
issub && @assert length(exargs) == 2
96-
for (i,ex) enumerate(exargs)
97-
if ex isa Expr && ex.head === :call
98-
exa = ex.args
99-
f = first(exa)
100-
exav = @view(exa[2:end])
101-
if f === :* || f === :mul_fast || f === :vmul || (f == :(Base.FastMath.mul_fast))::Bool
102-
a, b = mulexpr(exav)
103-
call.args[2] = a
104-
call.args[3] = b
105-
if length(exargs) == 2
106-
push!(call.args, exargs[3 - i])
107-
else
108-
push!(call.args, append_args_skip!(Expr(:call, :add_fast), exargs, i, mod))
109-
end
110-
if issub
111-
csub = i == 1
112-
cnmul = !csub
113-
end
114-
return true, cnmul, csub
115-
elseif isadd
116-
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
117-
if found
118-
if csub
119-
call.args[4] = if length(exargs) == 2
120-
Expr(:call, :sub_fast, exargs[3 - i], call.args[4])
121-
else
122-
Expr(:call, :sub_fast, append_args_skip!(Expr(:call, :add_fast), exargs, i, mod), call.args[4])
123-
end
124-
else
125-
call.args[4] = append_args_skip!(Expr(:call, :add_fast, call.args[4]), exargs, i, mod)
126-
end
127-
return true, cnmul, false
128-
end
129-
elseif issub
130-
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
131-
if found
132-
if i == 1
133-
if csub
134-
call.args[4] = Expr(:call, :add_fast, call.args[4], exargs[3 - i])
135-
else
136-
call.args[4] = Expr(:call, :sub_fast, call.args[4], exargs[3 - i])
137-
end
138-
else
139-
cnmul = !cnmul
140-
if csub
141-
call.args[4] = Expr(:call, :add_fast, exargs[3 - i], call.args[4])
142-
else
143-
call.args[4] = Expr(:call, :sub_fast, exargs[3 - i], call.args[4])
144-
end
145-
csub = false
146-
end
147-
return true, cnmul, csub
148-
end
99+
end
100+
issub && @assert length(exargs) == 2
101+
for (i,ex) enumerate(exargs)
102+
if ex isa Expr && ex.head === :call
103+
exa = ex.args
104+
f = first(exa)
105+
exav = @view(exa[2:end])
106+
if f === :* || f === :mul_fast || f === :vmul || (f == :(Base.FastMath.mul_fast))::Bool
107+
a, b = mulexpr(exav)
108+
call.args[2] = a
109+
call.args[3] = b
110+
if length(exargs) == 2
111+
push!(call.args, exargs[3 - i])
112+
else
113+
push!(call.args, append_args_skip!(Expr(:call, :add_fast), exargs, i, mod))
114+
end
115+
if issub
116+
csub = i == 1
117+
cnmul = !csub
118+
end
119+
return true, cnmul, csub
120+
elseif isadd
121+
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
122+
if found
123+
if csub
124+
call.args[4] = if length(exargs) == 2
125+
Expr(:call, :sub_fast, exargs[3 - i], call.args[4])
126+
else
127+
Expr(:call, :sub_fast, append_args_skip!(Expr(:call, :add_fast), exargs, i, mod), call.args[4])
149128
end
129+
else
130+
call.args[4] = append_args_skip!(Expr(:call, :add_fast, call.args[4]), exargs, i, mod)
131+
end
132+
return true, cnmul, false
150133
end
134+
elseif issub
135+
found, cnmul, csub = recursive_muladd_search!(call, exa, mod)
136+
if found
137+
if i == 1
138+
if csub
139+
call.args[4] = Expr(:call, :add_fast, call.args[4], exargs[3 - i])
140+
else
141+
call.args[4] = Expr(:call, :sub_fast, call.args[4], exargs[3 - i])
142+
end
143+
else
144+
cnmul = !cnmul
145+
if csub
146+
call.args[4] = Expr(:call, :add_fast, exargs[3 - i], call.args[4])
147+
else
148+
call.args[4] = Expr(:call, :sub_fast, exargs[3 - i], call.args[4])
149+
end
150+
csub = false
151+
end
152+
return true, cnmul, csub
153+
end
154+
end
151155
end
152-
length(call.args) == 4, cnmul, csub
156+
end
157+
length(call.args) == 4, cnmul, csub
153158
end
154159

155160
function capture_a_muladd(ex::Expr, mod)
@@ -190,6 +195,15 @@ end
190195
function capture_muladd(ex::Expr, mod)
191196
while true
192197
ex.head === :ref && return ex
198+
if Meta.isexpr(ex, :call, 2)
199+
if (ex.args[1] === :(-))
200+
if (ex.args[2] isa Number)
201+
return -ex.args[2]
202+
elseif ex.args[2] === :Inf
203+
return -Inf
204+
end
205+
end
206+
end
193207
found, ex = capture_a_muladd(ex, mod)
194208
found || return ex
195209
end

0 commit comments

Comments
 (0)