158
158
end
159
159
WT = W * T_size
160
160
V = VType{W,T}
161
-
161
+ vectorize_body (N, Nsym, VType{W,T}, unroll_factor, n, body, vecdict, gcpreserve, Wshift, log2unroll, mod)
162
+ end
163
+ @noinline function vectorize_body (
164
+ N, Nsym, :: Type{V} , unroll_factor, n, body, vecdict, gcpreserve, Wshift, log2unroll, mod
165
+ ) where {W,T,V <: Union{SVec{W,T},Vec{W,T}} }
162
166
indexed_expressions = Dict {Symbol,Symbol} () # Symbol, gensymbol
163
167
164
168
itersym = gensym (:i )
@@ -276,34 +280,53 @@ end
276
280
end
277
281
end
278
282
# ## now we walk the body to look for reductions
283
+ add_reductions! (q, V, reduction_symbols, unroll_factor, mod)
284
+ # display(q)
285
+ # We are using pointers, so better add a GC.@preserve.
286
+ # gcpreserve = true
287
+ # gcpreserve = false
288
+ if gcpreserve
289
+ return quote
290
+ $ (Expr (:macrocall ,
291
+ Expr (:., :GC , QuoteNode (Symbol (" @preserve" ))),
292
+ LineNumberNode (@__LINE__ ), (keys (indexed_expressions)). .. , q
293
+ ))
294
+ nothing
295
+ end
296
+ else
297
+ return q
298
+ end
299
+ end
300
+
301
+ function add_reductions! (q, :: Type{V} , reduction_symbols, unroll_factor, mod) where {W,T,V <: Union{SVec{W,T},Vec{W,T}} }
279
302
if unroll_factor == 1
280
303
for ((sym,op),gsym) ∈ reduction_symbols
281
- if op == :+ || op == :-
304
+ if op === :+ || op = == :-
282
305
pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,zero ($ T))))
283
- elseif op == :* || op == :/
306
+ elseif op === :* || op = == :/
284
307
pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,one ($ T))))
285
308
end
286
- if op == :+
309
+ if op === :+
287
310
push! (q. args, :($ sym = Base. FastMath. add_fast ($ sym, $ mod. vsum ($ gsym))))
288
- elseif op == :-
311
+ elseif op === :-
289
312
push! (q. args, :($ sym = Base. FastMath. sub_fast ($ sym, $ mod. vsum ($ gsym))))
290
- elseif op == :*
313
+ elseif op === :*
291
314
push! (q. args, :($ sym = Base. FastMath. mul_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
292
- elseif op == :/
315
+ elseif op === :/
293
316
push! (q. args, :($ sym = Base. FastMath. div_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
294
317
end
295
318
end
296
319
else
297
320
for ((sym,op),gsym_base) ∈ reduction_symbols
298
321
for uf ∈ 0 : unroll_factor- 1
299
322
gsym = Symbol (gsym_base, :_ , uf)
300
- if op == :+ || op == :-
323
+ if op === :+ || op = == :-
301
324
pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,zero ($ T))))
302
- elseif op == :* || op == :/
325
+ elseif op === :* || op = == :/
303
326
pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,one ($ T))))
304
327
end
305
328
end
306
- func = ((op == :* ) | (op == :/ )) ? :($ mod. evmul) : :($ mod. evadd)
329
+ func = ((op === :* ) | (op = == :/ )) ? :($ mod. evmul) : :($ mod. evadd)
307
330
uf_new = unroll_factor
308
331
while uf_new > 1
309
332
uf_new, uf_prev = uf_new >> 1 , uf_new
@@ -316,33 +339,19 @@ end
316
339
end
317
340
end
318
341
gsym = Symbol (gsym_base, :_ , 0 )
319
- if op == :+
342
+ if op === :+
320
343
push! (q. args, :($ sym = Base. FastMath. add_fast ($ sym, $ mod. vsum ($ gsym))))
321
- elseif op == :-
344
+ elseif op === :-
322
345
push! (q. args, :($ sym = Base. FastMath. sub_fast ($ sym, $ mod. vsum ($ gsym))))
323
- elseif op == :*
346
+ elseif op === :*
324
347
push! (q. args, :($ sym = Base. FastMath. mul_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
325
- elseif op == :/
348
+ elseif op === :/
326
349
push! (q. args, :($ sym = Base. FastMath. div_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
327
350
end
328
351
end
329
352
end
330
353
push! (q. args, nothing )
331
- # display(q)
332
- # We are using pointers, so better add a GC.@preserve.
333
- # gcpreserve = true
334
- # gcpreserve = false
335
- if gcpreserve
336
- return quote
337
- $ (Expr (:macrocall ,
338
- Expr (:., :GC , QuoteNode (Symbol (" @preserve" ))),
339
- LineNumberNode (@__LINE__ ), (keys (indexed_expressions)). .. , q
340
- ))
341
- nothing
342
- end
343
- else
344
- return q
345
- end
354
+ nothing
346
355
end
347
356
348
357
function insert_mask (x, masksym, reduction_symbols, default_module = :LoopVectorization )
@@ -617,8 +626,10 @@ function vectorload!(
617
626
else
618
627
throw (" Currently only supports up to 2 indices for some reason." )
619
628
end
620
- elseif f === :zero || f === :one
621
- return Expr (:call , :vbroadcast , V, x)
629
+ elseif f === :zero
630
+ return Expr (:call , Expr (:(.), mod, QuoteNode (:vbroadcast )), V, zero (T))
631
+ elseif f === :one
632
+ return Expr (:call , Expr (:(.), mod, QuoteNode (:vbroadcast )), V, one (T))
622
633
else
623
634
return x
624
635
end
635
646
itersym = :iter , declared_iter_sym = nothing , VectorizationDict = SLEEFPiratesDict, mod = :LoopVectorization
636
647
) where {W,T,V <: Union{Vec{W,T},SVec{W,T}} }
637
648
q = prewalk (expr) do x
649
+ # @show x
638
650
if x isa Symbol
639
651
if x === declared_iter_sym
640
652
isymvec = gensym (itersym)
0 commit comments