Skip to content

Commit 76dfaa6

Browse files
committed
Improved compile times.
1 parent d5fa600 commit 76dfaa6

File tree

2 files changed

+169
-77
lines changed

2 files changed

+169
-77
lines changed

src/LoopVectorization.jl

Lines changed: 162 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module LoopVectorization
33
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
44
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr
55
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul
6-
using MacroTools: @capture, prewalk, postwalk
6+
using MacroTools: prewalk, postwalk
77

88
export vectorizable, @vectorize, @vvectorize
99

@@ -61,46 +61,82 @@ const SLEEFPiratesDict = Dict{Symbol,Tuple{Symbol,Symbol}}(
6161

6262

6363

64-
64+
# @noinline function _spirate(ex, dict, macro_escape = true, mod = :LoopVectorization)
65+
# ex = postwalk(ex) do x
66+
# if @capture(x, a_ += b_)
67+
# return :($a = $mod.vadd($a, $b))
68+
# elseif @capture(x, a_ -= b_)
69+
# return :($a = $mod.vsub($a, $b))
70+
# elseif @capture(x, a_ *= b_)
71+
# return :($a = $mod.vmul($a, $b))
72+
# elseif @capture(x, a_ /= b_)
73+
# return :($a = $mod.vdiv($a, $b))
74+
# elseif @capture(x, Base.FastMath.add_fast(a__))
75+
# return :($mod.vadd($(a...)))
76+
# elseif @capture(x, Base.FastMath.sub_fast(a__))
77+
# return :($mod.vsub($(a...)))
78+
# elseif @capture(x, Base.FastMath.mul_fast(a__))
79+
# return :($mod.vmul($(a...)))
80+
# elseif @capture(x, Base.FastMath.div_fast(a__))
81+
# return :($mod.vfdiv($(a...)))
82+
# elseif @capture(x, a_ / sqrt(b_))
83+
# return :($a * $mod.rsqrt($b))
84+
# elseif @capture(x, inv(sqrt(a_)))
85+
# return :($mod.rsqrt($a))
86+
# elseif @capture(x, @horner a__)
87+
# return SIMDPirates.horner(a...)
88+
# elseif @capture(x, Base.Math.muladd(a_, b_, c_))
89+
# return :( $mod.vmuladd($a, $b, $c) )
90+
# elseif isa(x, Symbol) && !occursin("@", string(x))
91+
# vec_mod, vec_sym = get(dict, x, (:not_found,:not_found))
92+
# if vec_sym != :not_found
93+
# return :($mod.$vec_mod.$vec_sym)
94+
# else
95+
# vec_sym = get(VECTOR_SYMBOLS, x, :not_found)
96+
# return vec_sym == :not_found ? x : :($mod.SIMDPirates.$(vec_sym))
97+
# end
98+
# else
99+
# return x
100+
# end
101+
# end
102+
# macro_escape ? esc(ex) : ex
103+
# end
65104

66105
@noinline function _spirate(ex, dict, macro_escape = true, mod = :LoopVectorization)
67106
ex = postwalk(ex) do x
68-
if @capture(x, a_ += b_)
69-
return :($a = $mod.vadd($a, $b))
70-
elseif @capture(x, a_ -= b_)
71-
return :($a = $mod.vsub($a, $b))
72-
elseif @capture(x, a_ *= b_)
73-
return :($a = $mod.vmul($a, $b))
74-
elseif @capture(x, a_ /= b_)
75-
return :($a = $mod.vdiv($a, $b))
76-
elseif @capture(x, Base.FastMath.add_fast(a__))
77-
return :($mod.vadd($(a...)))
78-
elseif @capture(x, Base.FastMath.sub_fast(a__))
79-
return :($mod.vsub($(a...)))
80-
elseif @capture(x, Base.FastMath.mul_fast(a__))
81-
return :($mod.vmul($(a...)))
82-
elseif @capture(x, Base.FastMath.div_fast(a__))
83-
return :($mod.vfdiv($(a...)))
84-
elseif @capture(x, a_ / sqrt(b_))
85-
return :($a * $mod.rsqrt($b))
86-
elseif @capture(x, inv(sqrt(a_)))
87-
return :($mod.rsqrt($a))
88-
elseif @capture(x, @horner a__)
89-
return SIMDPirates.horner(a...)
90-
elseif @capture(x, Base.Math.muladd(a_, b_, c_))
91-
return :( $mod.vmuladd($a, $b, $c) )
92-
elseif isa(x, Symbol) && !occursin("@", string(x))
93-
vec_mod, vec_sym = get(dict, x, (:not_found,:not_found))
94-
if vec_sym != :not_found
95-
return :($mod.$vec_mod.$vec_sym)
96-
else
97-
vec_sym = get(VECTOR_SYMBOLS, x, :not_found)
98-
return vec_sym == :not_found ? x : :($mod.SIMDPirates.$(vec_sym))
107+
if x isa Symbol
108+
vec_mod, vec_sym = get(dict, x) do
109+
mod, get(VECTOR_SYMBOLS, x) do
110+
x
111+
end
99112
end
113+
return x === vec_sym ? x : Expr(:(.), vec_mod === mod ? mod : Expr(:(.), mod, QuoteNode(vec_mod)), QuoteNode(vec_sym))
114+
end
115+
x isa Expr || return x
116+
xexpr::Expr = x
117+
# if xexpr.head === :macrocall && first(xexpr.args) === Symbol("@horner")
118+
# return SIMDPirates.horner(@view(xexpr.args[3:end])...)
119+
# end
120+
xexpr.head === :call || return x
121+
f = first(xexpr.args)
122+
if f == :(Base.FastMath.add_fast)
123+
vf = :vadd
124+
elseif f == :(Base.FastMath.sub_fast)
125+
vf = :vsub
126+
elseif f == :(Base.FastMath.mul_fast)
127+
vf = :vmul
128+
elseif f == :(Base.FastMath.div_fast)
129+
vf = :vfdiv
130+
elseif f == :(Base.FastMath.sqrt)
131+
vf = :vsqrt
132+
elseif f == :(Base.Math.muladd)
133+
vf = :vmuladd
100134
else
101-
return x
135+
return xexpr
102136
end
137+
return Expr(:call, Expr(:(.), mod, QuoteNode(vf)), @view(x.args[2:end])...)
103138
end
139+
# println(ex)
104140
macro_escape ? esc(ex) : ex
105141
end
106142

@@ -129,15 +165,15 @@ end
129165

130166
@noinline function vectorize_body(N, Tsym::Symbol, uf, n, body, vecdict = SLEEFPiratesDict, VType = SVec, gcpreserve::Bool = true , mod = :LoopVectorization)
131167
if Tsym == :Float32
132-
vectorize_body(N, Float32, uf, n, body, vecdict, VType, mod)
168+
vectorize_body(N, Float32, uf, n, body, vecdict, VType, gcpreserve, mod)
133169
elseif Tsym == :Float64
134-
vectorize_body(N, Float64, uf, n, body, vecdict, VType, mod)
170+
vectorize_body(N, Float64, uf, n, body, vecdict, VType, gcpreserve, mod)
135171
else
136172
throw("Type $Tsym is not supported.")
137173
end
138174
end
139175
@noinline function vectorize_body(
140-
N, ::Type{T}, unroll_factor::Int, n::Symbol, body::Array{Any},
176+
N, ::Type{T}, unroll_factor::Int, n::Symbol, body,
141177
vecdict::Dict{Symbol,Tuple{Symbol,Symbol}} = SLEEFPiratesDict,
142178
@nospecialize(VType = SVec), gcpreserve::Bool = true, mod = :LoopVectorization
143179
) where {T}
@@ -203,9 +239,12 @@ end
203239
## body preamble must define indexed symbols
204240
## we only need that for loads.
205241
dicts = (indexed_expressions, reduction_symbols, loaded_exprs, loop_constants_dict)
206-
push!(main_body.args,
207-
_vectorloads!(main_body, q, dicts, V, loop_constants_quote, b;
208-
itersym = itersym, declared_iter_sym = n, VectorizationDict = vecdict, mod = mod)
242+
push!(
243+
main_body.args,
244+
_vectorloads!(
245+
main_body, q, dicts, V, loop_constants_quote, b;
246+
itersym = itersym, declared_iter_sym = n, VectorizationDict = vecdict, mod = mod
247+
)
209248
)# |> x -> (@show(x), _pirate(x)))
210249
end
211250
# @show main_body
@@ -350,6 +389,7 @@ function insert_mask(x, masksym, reduction_symbols, default_module = :LoopVector
350389
local fs::Symbol, mf::Expr, f::Union{Symbol,Expr}, call::Expr, a::Symbol
351390
if x.head === :(=) # check for reductions
352391
x.args[2] isa Expr || return x
392+
# @show x
353393
a = x.args[1]
354394
call = x.args[2]
355395
f = first(call.args)
@@ -617,8 +657,10 @@ function vectorload!(
617657
else
618658
throw("Currently only supports up to 2 indices for some reason.")
619659
end
620-
elseif f === :zero || f === :one
621-
return Expr(:call, :vbroadcast, V, x)
660+
elseif f === :zero
661+
return Expr(:call, Expr(:(.), mod, QuoteNode(:vbroadcast)), V, zero(T))
662+
elseif f === :one
663+
return Expr(:call, Expr(:(.), mod, QuoteNode(:vbroadcast)), V, one(T))
622664
else
623665
return x
624666
end
@@ -675,13 +717,84 @@ Returns true if a substitution was made, false otherwise.
675717
subbed, expr
676718
end
677719

678-
679-
"""
680-
Arguments are
681-
@vectorize Type UnrollFactor forloop
682-
683-
The default type is Float64, and default UnrollFactor is 1 (no unrolling).
684-
"""
720+
# function loop_components(expr::Expr)
721+
# expr.head === :for || throw("Macro must be applied to a for loop.")
722+
# iterdef = expr.args[1]
723+
# itersym = iterdef.args[1]
724+
# iterrange = iterdef.args[2]
725+
# @assert iterrange isa Expr
726+
# @assert length(expr.args) == 2
727+
# body = expr.args[2]
728+
# iterlength = if iterrange.head === :call
729+
# if iterrange.args[1] === :(:)
730+
# if iterrange.args[2] == 1
731+
# iterrange.args[3]
732+
# else
733+
# Expr(:(-), iterrange.args[3], iterrange.args[2])
734+
# end
735+
# elseif iterrange.args[1] === :eachindex
736+
# if length(iterrange.args) == 2
737+
# Expr(:call, :length, iterrange.args[2])
738+
# else
739+
# il = Expr(:call, :min)
740+
# for i ∈ 2:length(iterrange.args)
741+
# push!(il.args, Expr(:call, :length, iterrange.args[i]))
742+
# end
743+
# il
744+
# end
745+
# else
746+
# throw("could not match loop expression.")
747+
# end
748+
# end
749+
# @show iterdef, itersym
750+
# iterlength, itersym, body
751+
# end
752+
753+
754+
# # Arguments are
755+
# # @vectorize Type UnrollFactor forloop
756+
757+
# # The default type is Float64, and default UnrollFactor is 1 (no unrolling).
758+
759+
760+
# for vec ∈ (false,true)
761+
# if vec
762+
# V = Vec
763+
# macroname = :vvectorize
764+
# else
765+
# V = SVec
766+
# macroname = :vectorize
767+
# end
768+
# for gcpreserve ∈ (true,false)
769+
# if !gcpreserve
770+
# macroname = Symbol(macroname, :_unsafe)
771+
# end
772+
# @eval macro $macroname(expr)
773+
# iterlength, itersym, body = loop_components(expr)
774+
# esc(vectorize_body(iterlength, Float64, 1, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
775+
# end
776+
# @eval macro $macroname(type, expr)
777+
# iterlength, itersym, body = loop_components(expr)
778+
# esc(vectorize_body(iterlength, type, 1, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
779+
# end
780+
# @eval macro $macroname(unroll_factor::Integer, expr)
781+
# iterlength, itersym, body = loop_components(expr)
782+
# esc(vectorize_body(iterlength, Float64, unroll_factor, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
783+
# end
784+
# @eval macro $macroname(type, unroll_factor::Integer, expr)
785+
# iterlength, itersym, body = loop_components(expr)
786+
# esc(vectorize_body(iterlength, type, unroll_factor, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
787+
# end
788+
# @eval macro $macroname(type, mod::Union{Symbol,Module}, expr)
789+
# iterlength, itersym, body = loop_components(expr)
790+
# esc(vectorize_body(iterlength, type, 1, itersym, body, SLEEFPiratesDict, $V, $gcpreserve, mod))
791+
# end
792+
# @eval macro $macroname(type, mod::Union{Symbol,Module}, unroll_factor::Integer, expr)
793+
# iterlength, itersym, body = loop_components(expr)
794+
# esc(vectorize_body(iterlength, type, unroll_factor, itersym, body, SLEEFPiratesDict, $V, $gcpreserve, mod))
795+
# end
796+
# end
797+
# end
685798

686799
for vec (false,true)
687800
if vec

src/precompile.jl

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,19 @@
11
function _precompile_()
22
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
3-
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),Float64})
4-
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),Int64})
5-
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),Module})
6-
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),QuoteNode})
7-
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),QuoteNode})
3+
isdefined(LoopVectorization, Symbol("#1#4")) && precompile(Tuple{getfield(LoopVectorization, Symbol("#1#4")),Expr})
4+
isdefined(LoopVectorization, Symbol("#1#4")) && precompile(Tuple{getfield(LoopVectorization, Symbol("#1#4")),Expr})
85
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),Symbol})
96
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),Symbol})
10-
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),Type})
11-
isdefined(MacroTools, Symbol("#19#20")) && precompile(Tuple{getfield(MacroTools, Symbol("#19#20")),Type})
12-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Float64})
13-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Float64})
14-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Int64})
15-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Int64})
16-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Int64})
17-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),LineNumberNode})
18-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),LineNumberNode})
19-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Module})
20-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Module})
21-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),QuoteNode})
22-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),QuoteNode})
23-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),QuoteNode})
24-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),QuoteNode})
257
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Symbol})
268
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Symbol})
27-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Symbol})
28-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Symbol})
29-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Type})
30-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Type})
31-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Type})
32-
isdefined(MacroTools, Symbol("#21#22")) && precompile(Tuple{getfield(MacroTools, Symbol("#21#22")),Type})
33-
precompile(Tuple{Core.kwftype(typeof(LoopVectorization._vectorloads!)),NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Module}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
34-
precompile(Tuple{Core.kwftype(typeof(LoopVectorization._vectorloads!)),NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Symbol}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
9+
precompile(Tuple{Core.kwftype(typeof(LoopVectorization._vectorloads!)),NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Module}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type{NTuple{8,VecElement{Float64}}},Expr,Expr})
10+
precompile(Tuple{Core.kwftype(typeof(LoopVectorization._vectorloads!)),NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Symbol}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type{NTuple{8,VecElement{Float64}}},Expr,Expr})
3511
precompile(Tuple{typeof(LoopVectorization.add_masks),Expr,Symbol,Dict{Tuple{Symbol,Symbol},Symbol},Module})
3612
precompile(Tuple{typeof(LoopVectorization.add_masks),Expr,Symbol,Dict{Tuple{Symbol,Symbol},Symbol},Symbol})
13+
precompile(Tuple{typeof(LoopVectorization.vectorize_assign_linear_index),Symbol,Expr,Symbol,Dict{Symbol,Symbol},Symbol,Symbol,Module})
3714
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Int64,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool,Module})
3815
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Int64,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool})
3916
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Symbol,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool})
17+
precompile(Tuple{typeof(LoopVectorization.vectorize_linear_index!),Expr,Dict{Expr,Symbol},Dict{Symbol,Symbol},Symbol,Expr,Symbol,Symbol,Symbol,Type})
18+
precompile(Tuple{typeof(LoopVectorization.vectorize_linear_index!),Expr,Dict{Expr,Symbol},Dict{Symbol,Symbol},Symbol,Symbol,Symbol,Symbol,Module,Type})
4019
end

0 commit comments

Comments
 (0)