@@ -62,21 +62,6 @@ const SLEEFPiratesDict = Dict{Symbol,Tuple{Symbol,Symbol}}(
62
62
63
63
@noinline function _spirate (ex, dict, macro_escape = true , mod = :LoopVectorization )
64
64
ex = postwalk (ex) do x
65
- # @show x
66
- # if @capture(x, LoopVectorization.SIMDPirates.vadd(LoopVectorization.SIMDPirates.vmul(a_, b_), c_)) || @capture(x, LoopVectorization.SIMDPirates.vadd(c_, LoopVectorization.SIMDPirates.vmul(a_, b_)))
67
- # return :(LoopVectorization.SIMDPirates.vmuladd($a, $b, $c))
68
- # elseif @capture(x, LoopVectorization.SIMDPirates.vadd(LoopVectorization.SIMDPirates.vmul(a_, b_), LoopVectorization.SIMDPirates.vmul(c_, d_), e_)) || @capture(x, LoopVectorization.SIMDPirates.vadd(LoopVectorization.SIMDPirates.vmul(a_, b_), e_, LoopVectorization.SIMDPirates.vmul(c_, d_))) || @capture(x, LoopVectorization.SIMDPirates.vadd(e_, LoopVectorization.SIMDPirates.vmul(a_, b_), LoopVectorization.SIMDPirates.vmul(c_, d_)))
69
- # return :(LoopVectorization.SIMDPirates.vmuladd($a, $b, LoopVectorization.SIMDPirates.vmuladd($c, $d, $e)))
70
- # elseif @capture(x, LoopVectorization.SIMDPirates.vadd(LoopVectorization.SIMDPirates.vmul(b_, c_), LoopVectorization.SIMDPirates.vmul(d_, e_), LoopVectorization.SIMDPirates.vmul(f_, g_), a_)) ||
71
- # @capture(x, LoopVectorization.SIMDPirates.vadd(LoopVectorization.SIMDPirates.vmul(b_, c_), LoopVectorization.SIMDPirates.vmul(d_, e_), a_, LoopVectorization.SIMDPirates.vmul(f_, g_))) ||
72
- # @capture(x, LoopVectorization.SIMDPirates.vadd(LoopVectorization.SIMDPirates.vmul(b_, c_), a_, LoopVectorization.SIMDPirates.vmul(d_, e_), LoopVectorization.SIMDPirates.vmul(f_, g_))) ||
73
- # @capture(x, LoopVectorization.SIMDPirates.vadd(a_, LoopVectorization.SIMDPirates.vmul(b_, c_), LoopVectorization.SIMDPirates.vmul(d_, e_), LoopVectorization.SIMDPirates.vmul(f_, g_)))
74
- # return :(LoopVectorization.SIMDPirates.vmuladd($g, $f, LoopVectorization.SIMDPirates.vmuladd($e, $d, LoopVectorization.SIMDPirates.vmuladd($c, $b, $a))))
75
- # elseif @capture(x, a_ * b_ + c_ - c_) || @capture(x, c_ + a_ * b_ - c_) || @capture(x, a_ * b_ - c_ + c_) || @capture(x, - c_ + a_ * b_ + c_)
76
- # return :(LoopVectorization.SIMDPirates.vmul($a, $b))
77
- # elseif @capture(x, a_ * b_ + c_ - d_) || @capture(x, c_ + a_ * b_ - d_) || @capture(x, a_ * b_ - d_ + c_) || @capture(x, - d_ + a_ * b_ + c_) || @capture(x, LoopVectorization.SIMDPirates.vsub(LoopVectorization.SIMDPirates.vmuladd(a_, b_, c_), d_))
78
- # return :(LoopVectorization.SIMDPirates.vmuladd($a, $b, LoopVectorization.SIMDPirates.vsub($c, $d)))
79
- # elseif @capture(x, a_ += b_)
80
65
if @capture (x, a_ += b_)
81
66
return :($ a = $ mod. vadd ($ a, $ b))
82
67
elseif @capture (x, a_ -= b_)
@@ -122,27 +107,12 @@ end
122
107
123
108
124
109
125
- # mask_expr(W, r) = :($(Expr(:tuple, [i > r ? Core.VecElement{Bool}(false) : Core.VecElement{Bool}(true) for i ∈ 1:W]...)))
126
-
127
110
"""
128
111
Returns the stride needed to iterate across rows, i.e. `size(A, 1)`.
129
112
Needs `@inferred` testing to confirm that the compiler optimizes it away
130
113
whenever size(A) is known at compile time. Seems to be the case for Julia 1.1.
131
114
"""
132
115
# Return `size(A, 1)`, the extent of `A` along its first dimension.
# Note: no whitespace may separate the function name from `(` here; with a
# space, the macro invocation no longer receives a single method definition.
@inline function stride_row(A::AbstractArray)
    return size(A, 1)
end
133
- # @inline function num_row_strides(A::AbstractArray)
134
- # s = size(A)
135
- # N = s[2]
136
- # for i ∈ 3:length(s)
137
- # N *= s[i]
138
- # end
139
- # N
140
- # end
141
- # @inline function stride_row_iter(A::AbstractArray)
142
- # N = num_row_strides(A)
143
- # stride = stride_row(A)
144
- # ntuple(i -> (i-1) * stride, Val(N))
145
- # end
146
116
147
117
function replace_syms_i (expr, set, i)
148
118
postwalk (expr) do ex
159
129
vectorize_body (N, Float32, uf, n, body, vecdict, VType, mod)
160
130
elseif Tsym == :Float64
161
131
vectorize_body (N, Float64, uf, n, body, vecdict, VType, mod)
162
- # elseif Tsym == :ComplexF32
163
- # vectorize_body(N, ComplexF32, uf, n, body, vecdict, VType)
164
- # elseif Tsym == :ComplexF64
165
- # vectorize_body(N, ComplexF64, uf, n, body, vecdict, VType)
166
132
else
167
133
throw (" Type $Tsym is not supported." )
168
134
end
@@ -183,25 +149,14 @@ end
183
149
else
184
150
log2unroll = 0
185
151
end
186
- #
187
- # W *= unroll_factor
188
- # @show W, REGISTER_SIZE, T_size
189
- # @show T
190
152
WT = W * T_size
191
153
V = VType{W,T}
192
154
193
- # @show body
194
-
195
- # body = _pirate(body)
196
-
197
- # indexed_expressions = Dict{Symbol,Expr}()
198
155
indexed_expressions = Dict {Symbol,Symbol} () # Symbol, gensymbol
199
156
200
157
itersym = gensym (:i )
201
158
# walk the expression, searching for all get index patterns.
202
159
# these will be replaced with
203
- # Plan: definition of q will create vectors
204
-
205
160
main_body = quote end
206
161
reduction_symbols = Dict {Tuple{Symbol,Symbol},Symbol} ()
207
162
loaded_exprs = Dict {Expr,Symbol} ()
0 commit comments