Skip to content

Commit 832e346

Browse files
committed
Some changes, outlining a couple functions.
1 parent d5fa600 commit 832e346

File tree

3 files changed

+130
-52
lines changed

3 files changed

+130
-52
lines changed

src/LoopVectorization.jl

Lines changed: 43 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,11 @@ end
158158
end
159159
WT = W * T_size
160160
V = VType{W,T}
161-
161+
vectorize_body(N, Nsym, VType{W,T}, unroll_factor, n, body, vecdict, gcpreserve, Wshift, log2unroll, mod)
162+
end
163+
@noinline function vectorize_body(
164+
N, Nsym, ::Type{V}, unroll_factor, n, body, vecdict, gcpreserve, Wshift, log2unroll, mod
165+
) where {W,T,V <: Union{SVec{W,T},Vec{W,T}}}
162166
indexed_expressions = Dict{Symbol,Symbol}() # Symbol, gensymbol
163167

164168
itersym = gensym(:i)
@@ -276,34 +280,53 @@ end
276280
end
277281
end
278282
### now we walk the body to look for reductions
283+
add_reductions!(q, V, reduction_symbols, unroll_factor, mod)
284+
# display(q)
285+
# We are using pointers, so better add a GC.@preserve.
286+
# gcpreserve = true
287+
# gcpreserve = false
288+
if gcpreserve
289+
return quote
290+
$(Expr(:macrocall,
291+
Expr(:., :GC, QuoteNode(Symbol("@preserve"))),
292+
LineNumberNode(@__LINE__), (keys(indexed_expressions))..., q
293+
))
294+
nothing
295+
end
296+
else
297+
return q
298+
end
299+
end
300+
301+
function add_reductions!(q, ::Type{V}, reduction_symbols, unroll_factor, mod) where {W,T,V <: Union{SVec{W,T},Vec{W,T}}}
279302
if unroll_factor == 1
280303
for ((sym,op),gsym) reduction_symbols
281-
if op == :+ || op == :-
304+
if op === :+ || op === :-
282305
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,zero($T))))
283-
elseif op == :* || op == :/
306+
elseif op === :* || op === :/
284307
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,one($T))))
285308
end
286-
if op == :+
309+
if op === :+
287310
push!(q.args, :($sym = Base.FastMath.add_fast($sym, $mod.vsum($gsym))))
288-
elseif op == :-
311+
elseif op === :-
289312
push!(q.args, :($sym = Base.FastMath.sub_fast($sym, $mod.vsum($gsym))))
290-
elseif op == :*
313+
elseif op === :*
291314
push!(q.args, :($sym = Base.FastMath.mul_fast($sym, $mod.SIMDPirates.vprod($gsym))))
292-
elseif op == :/
315+
elseif op === :/
293316
push!(q.args, :($sym = Base.FastMath.div_fast($sym, $mod.SIMDPirates.vprod($gsym))))
294317
end
295318
end
296319
else
297320
for ((sym,op),gsym_base) reduction_symbols
298321
for uf 0:unroll_factor-1
299322
gsym = Symbol(gsym_base, :_, uf)
300-
if op == :+ || op == :-
323+
if op === :+ || op === :-
301324
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,zero($T))))
302-
elseif op == :* || op == :/
325+
elseif op === :* || op === :/
303326
pushfirst!(q.args, :($gsym = $mod.vbroadcast($V,one($T))))
304327
end
305328
end
306-
func = ((op == :*) | (op == :/)) ? :($mod.evmul) : :($mod.evadd)
329+
func = ((op === :*) | (op === :/)) ? :($mod.evmul) : :($mod.evadd)
307330
uf_new = unroll_factor
308331
while uf_new > 1
309332
uf_new, uf_prev = uf_new >> 1, uf_new
@@ -316,33 +339,19 @@ end
316339
end
317340
end
318341
gsym = Symbol(gsym_base, :_, 0)
319-
if op == :+
342+
if op === :+
320343
push!(q.args, :($sym = Base.FastMath.add_fast($sym, $mod.vsum($gsym))))
321-
elseif op == :-
344+
elseif op === :-
322345
push!(q.args, :($sym = Base.FastMath.sub_fast($sym, $mod.vsum($gsym))))
323-
elseif op == :*
346+
elseif op === :*
324347
push!(q.args, :($sym = Base.FastMath.mul_fast($sym, $mod.SIMDPirates.vprod($gsym))))
325-
elseif op == :/
348+
elseif op === :/
326349
push!(q.args, :($sym = Base.FastMath.div_fast($sym, $mod.SIMDPirates.vprod($gsym))))
327350
end
328351
end
329352
end
330353
push!(q.args, nothing)
331-
# display(q)
332-
# We are using pointers, so better add a GC.@preserve.
333-
# gcpreserve = true
334-
# gcpreserve = false
335-
if gcpreserve
336-
return quote
337-
$(Expr(:macrocall,
338-
Expr(:., :GC, QuoteNode(Symbol("@preserve"))),
339-
LineNumberNode(@__LINE__), (keys(indexed_expressions))..., q
340-
))
341-
nothing
342-
end
343-
else
344-
return q
345-
end
354+
nothing
346355
end
347356

348357
function insert_mask(x, masksym, reduction_symbols, default_module = :LoopVectorization)
@@ -617,8 +626,10 @@ function vectorload!(
617626
else
618627
throw("Currently only supports up to 2 indices for some reason.")
619628
end
620-
elseif f === :zero || f === :one
621-
return Expr(:call, :vbroadcast, V, x)
629+
elseif f === :zero
630+
return Expr(:call, Expr(:(.), mod, QuoteNode(:vbroadcast)), V, zero(T))
631+
elseif f === :one
632+
return Expr(:call, Expr(:(.), mod, QuoteNode(:vbroadcast)), V, one(T))
622633
else
623634
return x
624635
end
@@ -635,6 +646,7 @@ end
635646
itersym = :iter, declared_iter_sym = nothing, VectorizationDict = SLEEFPiratesDict, mod = :LoopVectorization
636647
) where {W,T,V <: Union{Vec{W,T},SVec{W,T}}}
637648
q = prewalk(expr) do x
649+
# @show x
638650
if x isa Symbol
639651
if x === declared_iter_sym
640652
isymvec = gensym(itersym)

src/constructors.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11

22
### This file contains convenience functions for constructing LoopSets.
33

4-
function loopset_from_expr(qe::Expr)
5-
q = contract_pass(qe)
4+
function loopset_from_expr(q::Expr)
5+
q = contract_pass(q)
6+
67
postwalk(q) do ex
78

89
end

src/contract_pass.jl

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,71 @@
11

2+
function check_negative(x)
3+
x isa Expr || return false
4+
x.head === :call || return false
5+
length(x.args) == 2 || return false
6+
a = first(x.args)
7+
return (a === :(-) || a == :(Base.FastMath.sub_fast))
8+
end
9+
10+
function capture_muladd(ex::Expr)
11+
# These are guaranteed by calling contract_pass
12+
# ex isa Expr || return ex
13+
# ex.head === :call || return ex
14+
args = ex.args
15+
f = first(args)::Union{Symbol,Expr}
16+
fplus = (f === :+)::Bool | (f == :(Base.FastMath.add_fast))
17+
fminus = (f === :-)::Bool | (f == :(Base.FastMath.sub_fast))
18+
(fplus | fminus) || return ex
19+
Nargs = length(args)
20+
Nargs > 2 || return ex
21+
j = 2
22+
while j Nargs
23+
argsⱼ = args[j]
24+
if argsⱼ isa Expr && (first(argsⱼ.args) === :* || first(argsⱼ.args) == :(Base.FastMath.mul_fast))
25+
break
26+
end
27+
j += 1
28+
end
29+
j > Nargs && return ex
30+
mulexpr::Expr = args[j]
31+
if Nargs == 3
32+
c = args[j == 2 ? 3 : 2]
33+
else
34+
c = Expr(:call, :vadd)
35+
for i 2:Nargs
36+
i == j || push!(c.args, args[i])
37+
end
38+
end
39+
isnmul = any(check_negative, @view(mulexpr.args[2:end]))
40+
a = mulexpr.args[2]
41+
b = if length(mulexpr.args) == 3 # two arg mul
42+
mulexpr.args[3]
43+
else
44+
Expr(:call, :vmul, @view(mulexpr.args[3:end])...)
45+
end
46+
cf = if fplus
47+
if isnmul
48+
:vfnmadd
49+
else
50+
:vmuladd
51+
end
52+
else
53+
if isnmul
54+
:vfnmsub
55+
else
56+
:vfnmadd
57+
end
58+
end
59+
Expr(:call, cf, a, b, c)
60+
end
261

362

463
contract_pass(x) = x # x will probably be a symbol
564
function contract_pass(expr::Expr)::Expr
665
prewalk(expr) do ex
766
if !(ex isa Expr)
867
return ex
9-
elseif ex.head != :call
68+
elseif ex.head !== :call
1069
if ex.head === :(+=)
1170
call = Expr(:call, :(+))
1271
append!(call.args, ex.args)
@@ -23,30 +82,36 @@ function contract_pass(expr::Expr)::Expr
2382
call = Expr(:call, :(/))
2483
append!(call.args, ex.args)
2584
Expr(:(=), first(ex.args), call)
26-
elseif ex.head != :call
27-
ex
28-
end
29-
elseif @capture(ex, f_(c_, g_(a_, b_))) || @capture(ex, f_(g_(a_,b_), c_))
30-
if (f === :(+) || f == :(Base.FastMath.add_fast)) && (g === :(*) || g == :(Base.FastMath.mul_fast))
31-
if a isa Expr && a.head === :call && (first(a.args) === :(-) || first(a.args) == :(Base.FastMath.sub_fast))
32-
Expr(:call, :vnfmadd, a, b, c)
33-
else
34-
Expr(:call, :vmuladd, a, b, c) #Expr(:call, :vfmadd, a, b, c)
35-
end
36-
elseif (f === :(-) || f == :(Base.FastMath.sub_fast)) && (g === :(*) || g == :(Base.FastMath.mul_fast))
37-
if a isa Expr && a.head === :call && (first(a.args) === :(-) || first(a.args) == :(Base.FastMath.sub_fast))
38-
Expr(:call, :vnfmsub, a, b, c)
39-
else
40-
Expr(:call, :vfmsub, a, b, c)
41-
end
4285
else
4386
ex
4487
end
45-
else
46-
ex
88+
else # ex.head === :call
89+
return capture_muladd(ex)
4790
end
4891
end
4992
end
5093

94+
# elseif @capture(ex, f_(c_, g_(a_, b_))) || @capture(ex, f_(g_(a_,b_), c_))
95+
# if (f === :(+) || f == :(Base.FastMath.add_fast)) && (g === :(*) || g == :(Base.FastMath.mul_fast))
96+
# if a isa Expr && a.head === :call && (first(a.args) === :(-) || first(a.args) == :(Base.FastMath.sub_fast))
97+
# Expr(:call, :vfnmadd, a, b, c)
98+
# else
99+
# Expr(:call, :vmuladd, a, b, c) #Expr(:call, :vfmadd, a, b, c)
100+
# end
101+
# elseif (f === :(-) || f == :(Base.FastMath.sub_fast)) && (g === :(*) || g == :(Base.FastMath.mul_fast))
102+
# if a isa Expr && a.head === :call && (first(a.args) === :(-) || first(a.args) == :(Base.FastMath.sub_fast))
103+
# Expr(:call, :vfnmsub, a, b, c)
104+
# else
105+
# Expr(:call, :vfmsub, a, b, c)
106+
# end
107+
# else
108+
# ex
109+
# end
110+
# else
111+
# ex
112+
# end
113+
# end
114+
# end
115+
51116

52117

0 commit comments

Comments
 (0)