@@ -3,7 +3,7 @@ module LoopVectorization
3
3
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
4
4
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr
5
5
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul
6
- using MacroTools: @capture , prewalk, postwalk
6
+ using MacroTools: prewalk, postwalk
7
7
8
8
export vectorizable, @vectorize , @vvectorize
9
9
@@ -61,46 +61,82 @@ const SLEEFPiratesDict = Dict{Symbol,Tuple{Symbol,Symbol}}(
61
61
62
62
63
63
64
-
64
+ # @noinline function _spirate(ex, dict, macro_escape = true, mod = :LoopVectorization)
65
+ # ex = postwalk(ex) do x
66
+ # if @capture(x, a_ += b_)
67
+ # return :($a = $mod.vadd($a, $b))
68
+ # elseif @capture(x, a_ -= b_)
69
+ # return :($a = $mod.vsub($a, $b))
70
+ # elseif @capture(x, a_ *= b_)
71
+ # return :($a = $mod.vmul($a, $b))
72
+ # elseif @capture(x, a_ /= b_)
73
+ # return :($a = $mod.vdiv($a, $b))
74
+ # elseif @capture(x, Base.FastMath.add_fast(a__))
75
+ # return :($mod.vadd($(a...)))
76
+ # elseif @capture(x, Base.FastMath.sub_fast(a__))
77
+ # return :($mod.vsub($(a...)))
78
+ # elseif @capture(x, Base.FastMath.mul_fast(a__))
79
+ # return :($mod.vmul($(a...)))
80
+ # elseif @capture(x, Base.FastMath.div_fast(a__))
81
+ # return :($mod.vfdiv($(a...)))
82
+ # elseif @capture(x, a_ / sqrt(b_))
83
+ # return :($a * $mod.rsqrt($b))
84
+ # elseif @capture(x, inv(sqrt(a_)))
85
+ # return :($mod.rsqrt($a))
86
+ # elseif @capture(x, @horner a__)
87
+ # return SIMDPirates.horner(a...)
88
+ # elseif @capture(x, Base.Math.muladd(a_, b_, c_))
89
+ # return :( $mod.vmuladd($a, $b, $c) )
90
+ # elseif isa(x, Symbol) && !occursin("@", string(x))
91
+ # vec_mod, vec_sym = get(dict, x, (:not_found,:not_found))
92
+ # if vec_sym != :not_found
93
+ # return :($mod.$vec_mod.$vec_sym)
94
+ # else
95
+ # vec_sym = get(VECTOR_SYMBOLS, x, :not_found)
96
+ # return vec_sym == :not_found ? x : :($mod.SIMDPirates.$(vec_sym))
97
+ # end
98
+ # else
99
+ # return x
100
+ # end
101
+ # end
102
+ # macro_escape ? esc(ex) : ex
103
+ # end
65
104
66
105
@noinline function _spirate (ex, dict, macro_escape = true , mod = :LoopVectorization )
67
106
ex = postwalk (ex) do x
68
- if @capture (x, a_ += b_)
69
- return :($ a = $ mod. vadd ($ a, $ b))
70
- elseif @capture (x, a_ -= b_)
71
- return :($ a = $ mod. vsub ($ a, $ b))
72
- elseif @capture (x, a_ *= b_)
73
- return :($ a = $ mod. vmul ($ a, $ b))
74
- elseif @capture (x, a_ /= b_)
75
- return :($ a = $ mod. vdiv ($ a, $ b))
76
- elseif @capture (x, Base. FastMath. add_fast (a__))
77
- return :($ mod. vadd ($ (a... )))
78
- elseif @capture (x, Base. FastMath. sub_fast (a__))
79
- return :($ mod. vsub ($ (a... )))
80
- elseif @capture (x, Base. FastMath. mul_fast (a__))
81
- return :($ mod. vmul ($ (a... )))
82
- elseif @capture (x, Base. FastMath. div_fast (a__))
83
- return :($ mod. vfdiv ($ (a... )))
84
- elseif @capture (x, a_ / sqrt (b_))
85
- return :($ a * $ mod. rsqrt ($ b))
86
- elseif @capture (x, inv (sqrt (a_)))
87
- return :($ mod. rsqrt ($ a))
88
- elseif @capture (x, @horner a__)
89
- return SIMDPirates. horner (a... )
90
- elseif @capture (x, Base. Math. muladd (a_, b_, c_))
91
- return :( $ mod. vmuladd ($ a, $ b, $ c) )
92
- elseif isa (x, Symbol) && ! occursin (" @" , string (x))
93
- vec_mod, vec_sym = get (dict, x, (:not_found ,:not_found ))
94
- if vec_sym != :not_found
95
- return :($ mod.$ vec_mod.$ vec_sym)
96
- else
97
- vec_sym = get (VECTOR_SYMBOLS, x, :not_found )
98
- return vec_sym == :not_found ? x : :($ mod. SIMDPirates.$ (vec_sym))
107
+ if x isa Symbol
108
+ vec_mod, vec_sym = get (dict, x) do
109
+ mod, get (VECTOR_SYMBOLS, x) do
110
+ x
111
+ end
99
112
end
113
+ return x === vec_sym ? x : Expr (:(.), vec_mod === mod ? mod : Expr (:(.), mod, QuoteNode (vec_mod)), QuoteNode (vec_sym))
114
+ end
115
+ x isa Expr || return x
116
+ xexpr:: Expr = x
117
+ # if xexpr.head === :macrocall && first(xexpr.args) === Symbol("@horner")
118
+ # return SIMDPirates.horner(@view(xexpr.args[3:end])...)
119
+ # end
120
+ xexpr. head === :call || return x
121
+ f = first (xexpr. args)
122
+ if f == :(Base. FastMath. add_fast)
123
+ vf = :vadd
124
+ elseif f == :(Base. FastMath. sub_fast)
125
+ vf = :vsub
126
+ elseif f == :(Base. FastMath. mul_fast)
127
+ vf = :vmul
128
+ elseif f == :(Base. FastMath. div_fast)
129
+ vf = :vfdiv
130
+ elseif f == :(Base. FastMath. sqrt)
131
+ vf = :vsqrt
132
+ elseif f == :(Base. Math. muladd)
133
+ vf = :vmuladd
100
134
else
101
- return x
135
+ return xexpr
102
136
end
137
+ return Expr (:call , Expr (:(.), mod, QuoteNode (vf)), @view (x. args[2 : end ]). .. )
103
138
end
139
+ # println(ex)
104
140
macro_escape ? esc (ex) : ex
105
141
end
106
142
@@ -129,15 +165,15 @@ end
129
165
130
166
@noinline function vectorize_body (N, Tsym:: Symbol , uf, n, body, vecdict = SLEEFPiratesDict, VType = SVec, gcpreserve:: Bool = true , mod = :LoopVectorization )
131
167
if Tsym == :Float32
132
- vectorize_body (N, Float32, uf, n, body, vecdict, VType, mod)
168
+ vectorize_body (N, Float32, uf, n, body, vecdict, VType, gcpreserve, mod)
133
169
elseif Tsym == :Float64
134
- vectorize_body (N, Float64, uf, n, body, vecdict, VType, mod)
170
+ vectorize_body (N, Float64, uf, n, body, vecdict, VType, gcpreserve, mod)
135
171
else
136
172
throw (" Type $Tsym is not supported." )
137
173
end
138
174
end
139
175
@noinline function vectorize_body (
140
- N, :: Type{T} , unroll_factor:: Int , n:: Symbol , body:: Array{Any} ,
176
+ N, :: Type{T} , unroll_factor:: Int , n:: Symbol , body,
141
177
vecdict:: Dict{Symbol,Tuple{Symbol,Symbol}} = SLEEFPiratesDict,
142
178
@nospecialize (VType = SVec), gcpreserve:: Bool = true , mod = :LoopVectorization
143
179
) where {T}
203
239
# # body preamble must define indexed symbols
204
240
# # we only need that for loads.
205
241
dicts = (indexed_expressions, reduction_symbols, loaded_exprs, loop_constants_dict)
206
- push! (main_body. args,
207
- _vectorloads! (main_body, q, dicts, V, loop_constants_quote, b;
208
- itersym = itersym, declared_iter_sym = n, VectorizationDict = vecdict, mod = mod)
242
+ push! (
243
+ main_body. args,
244
+ _vectorloads! (
245
+ main_body, q, dicts, V, loop_constants_quote, b;
246
+ itersym = itersym, declared_iter_sym = n, VectorizationDict = vecdict, mod = mod
247
+ )
209
248
)# |> x -> (@show(x), _pirate(x)))
210
249
end
211
250
# @show main_body
@@ -350,6 +389,7 @@ function insert_mask(x, masksym, reduction_symbols, default_module = :LoopVector
350
389
local fs:: Symbol , mf:: Expr , f:: Union{Symbol,Expr} , call:: Expr , a:: Symbol
351
390
if x. head === :(= ) # check for reductions
352
391
x. args[2 ] isa Expr || return x
392
+ # @show x
353
393
a = x. args[1 ]
354
394
call = x. args[2 ]
355
395
f = first (call. args)
@@ -617,8 +657,10 @@ function vectorload!(
617
657
else
618
658
throw (" Currently only supports up to 2 indices for some reason." )
619
659
end
620
- elseif f === :zero || f === :one
621
- return Expr (:call , :vbroadcast , V, x)
660
+ elseif f === :zero
661
+ return Expr (:call , Expr (:(.), mod, QuoteNode (:vbroadcast )), V, zero (T))
662
+ elseif f === :one
663
+ return Expr (:call , Expr (:(.), mod, QuoteNode (:vbroadcast )), V, one (T))
622
664
else
623
665
return x
624
666
end
@@ -675,13 +717,84 @@ Returns true if a substitution was made, false otherwise.
675
717
subbed, expr
676
718
end
677
719
678
-
679
- """
680
- Arguments are
681
- @vectorize Type UnrollFactor forloop
682
-
683
- The default type is Float64, and default UnrollFactor is 1 (no unrolling).
684
- """
720
+ # function loop_components(expr::Expr)
721
+ # expr.head === :for || throw("Macro must be applied to a for loop.")
722
+ # iterdef = expr.args[1]
723
+ # itersym = iterdef.args[1]
724
+ # iterrange = iterdef.args[2]
725
+ # @assert iterrange isa Expr
726
+ # @assert length(expr.args) == 2
727
+ # body = expr.args[2]
728
+ # iterlength = if iterrange.head === :call
729
+ # if iterrange.args[1] === :(:)
730
+ # if iterrange.args[2] == 1
731
+ # iterrange.args[3]
732
+ # else
733
+ # Expr(:(-), iterrange.args[3], iterrange.args[2])
734
+ # end
735
+ # elseif iterrange.args[1] === :eachindex
736
+ # if length(iterrange.args) == 2
737
+ # Expr(:call, :length, iterrange.args[2])
738
+ # else
739
+ # il = Expr(:call, :min)
740
+ # for i ∈ 2:length(iterrange.args)
741
+ # push!(il.args, Expr(:call, :length, iterrange.args[i]))
742
+ # end
743
+ # il
744
+ # end
745
+ # else
746
+ # throw("could not match loop expression.")
747
+ # end
748
+ # end
749
+ # @show iterdef, itersym
750
+ # iterlength, itersym, body
751
+ # end
752
+
753
+
754
+ # # Arguments are
755
+ # # @vectorize Type UnrollFactor forloop
756
+
757
+ # # The default type is Float64, and default UnrollFactor is 1 (no unrolling).
758
+
759
+
760
+ # for vec ∈ (false,true)
761
+ # if vec
762
+ # V = Vec
763
+ # macroname = :vvectorize
764
+ # else
765
+ # V = SVec
766
+ # macroname = :vectorize
767
+ # end
768
+ # for gcpreserve ∈ (true,false)
769
+ # if !gcpreserve
770
+ # macroname = Symbol(macroname, :_unsafe)
771
+ # end
772
+ # @eval macro $macroname(expr)
773
+ # iterlength, itersym, body = loop_components(expr)
774
+ # esc(vectorize_body(iterlength, Float64, 1, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
775
+ # end
776
+ # @eval macro $macroname(type, expr)
777
+ # iterlength, itersym, body = loop_components(expr)
778
+ # esc(vectorize_body(iterlength, type, 1, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
779
+ # end
780
+ # @eval macro $macroname(unroll_factor::Integer, expr)
781
+ # iterlength, itersym, body = loop_components(expr)
782
+ # esc(vectorize_body(iterlength, Float64, unroll_factor, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
783
+ # end
784
+ # @eval macro $macroname(type, unroll_factor::Integer, expr)
785
+ # iterlength, itersym, body = loop_components(expr)
786
+ # esc(vectorize_body(iterlength, type, unroll_factor, itersym, body, SLEEFPiratesDict, $V, $gcpreserve))
787
+ # end
788
+ # @eval macro $macroname(type, mod::Union{Symbol,Module}, expr)
789
+ # iterlength, itersym, body = loop_components(expr)
790
+ # esc(vectorize_body(iterlength, type, 1, itersym, body, SLEEFPiratesDict, $V, $gcpreserve, mod))
791
+ # end
792
+ # @eval macro $macroname(type, mod::Union{Symbol,Module}, unroll_factor::Integer, expr)
793
+ # iterlength, itersym, body = loop_components(expr)
794
+ # esc(vectorize_body(iterlength, type, unroll_factor, itersym, body, SLEEFPiratesDict, $V, $gcpreserve, mod))
795
+ # end
796
+ # end
797
+ # end
685
798
686
799
for vec ∈ (false ,true )
687
800
if vec
0 commit comments