@@ -144,6 +144,16 @@ whenever size(A) is known at compile time. Seems to be the case for Julia 1.1.
144
144
# ntuple(i -> (i-1) * stride, Val(N))
145
145
# end
146
146
147
# replace_syms_i(expr, set, i)
#
# Walks `expr` with `postwalk` and returns the result in which every node that
# is a Symbol and a member of `set` is replaced by a new Symbol built from the
# original symbol, an underscore, and `i`. All other nodes are returned as-is.
#
# Arguments:
#   expr - the expression to walk.
#   set  - collection tested with `∈`; the visible caller passes a
#          `Set(values(reduction_symbols))`.
#   i    - suffix value; the visible caller passes the unroll index
#          (0 through unroll_factor-1).
#
# NOTE(review): `postwalk` is not defined in the visible source; presumably it
# is MacroTools.postwalk - confirm against the file's imports.
+ function replace_syms_i (expr, set, i)
148
+ postwalk (expr) do ex
149
+ if ex isa Symbol && ex ∈ set
150
+ return Symbol (ex, :_ , i) # e.g. ex = :s with i = 2 yields :s_2
151
+ else
152
+ return ex # anything that is not a listed symbol is left unchanged
153
+ end
154
+ end
155
+ end
156
+
147
157
@noinline function vectorize_body (N, Tsym:: Symbol , uf, n, body, vecdict = SLEEFPiratesDict, VType = SVec, mod = :LoopVectorization )
148
158
if Tsym == :Float32
149
159
vectorize_body (N, Float32, uf, n, body, vecdict, VType, mod)
247
257
$ main_body
248
258
$ itersym += $ W
249
259
end
250
- # Eventually we should unroll reductions as well. Should also check if the compiler ever does that for us (doubt it)
251
- for u in 1 : unroll_factor
260
+ if unroll_factor == 1
261
+ push! (unrolled_loop_body_expr. args, unrolled_loop_body_iter)
262
+ else
263
+ ulb = unrolled_loop_body_iter
264
+ rep_syms = Set (values (reduction_symbols))
265
+ unrolled_loop_body_iter = replace_syms_i (ulb, rep_syms, 0 )
252
266
push! (unrolled_loop_body_expr. args, unrolled_loop_body_iter)
267
+ for u in 1 : unroll_factor- 1
268
+ push! (unrolled_loop_body_expr. args, replace_syms_i (ulb, rep_syms, u))
269
+ end
253
270
end
254
271
unadjitersym = gensym (:unadjitersym )
255
272
if loop_max_expr isa Integer && loop_max_expr <= 1
@@ -296,22 +313,55 @@ end
296
313
end
297
314
end
298
315
# ## now we walk the body to look for reductions
299
- for ((sym,op),gsym) ∈ reduction_symbols
300
- if op == :+ || op == :-
301
- pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,zero ($ T))))
302
- elseif op == :* || op == :/
303
- pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,one ($ T))))
316
+ if unroll_factor == 1
317
+ for ((sym,op),gsym) ∈ reduction_symbols
318
+ if op == :+ || op == :-
319
+ pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,zero ($ T))))
320
+ elseif op == :* || op == :/
321
+ pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,one ($ T))))
322
+ end
323
+ if op == :+
324
+ push! (q. args, :($ sym = Base. FastMath. add_fast ($ sym, $ mod. vsum ($ gsym))))
325
+ elseif op == :-
326
+ push! (q. args, :($ sym = Base. FastMath. sub_fast ($ sym, $ mod. vsum ($ gsym))))
327
+ elseif op == :*
328
+ push! (q. args, :($ sym = Base. FastMath. mul_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
329
+ elseif op == :/
330
+ push! (q. args, :($ sym = Base. FastMath. div_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
331
+ end
304
332
end
305
- if op == :+
306
- # push!(q.args, :(@show $sym, $gsym))
307
- push! (q. args, :($ sym = Base. FastMath. add_fast ($ sym, $ mod. vsum ($ gsym))))
308
- # push!(q.args, :(@show $sym, $gsym))
309
- elseif op == :-
310
- push! (q. args, :($ sym = Base. FastMath. sub_fast ($ sym, $ mod. vsum ($ gsym))))
311
- elseif op == :*
312
- push! (q. args, :($ sym = Base. FastMath. mul_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
313
- elseif op == :/
314
- push! (q. args, :($ sym = Base. FastMath. div_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
333
+ else
334
+ for ((sym,op),gsym_base) ∈ reduction_symbols
335
+ for uf ∈ 0 : unroll_factor- 1
336
+ gsym = Symbol (gsym_base, :_ , uf)
337
+ if op == :+ || op == :-
338
+ pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,zero ($ T))))
339
+ elseif op == :* || op == :/
340
+ pushfirst! (q. args, :($ gsym = $ mod. vbroadcast ($ V,one ($ T))))
341
+ end
342
+ end
343
+ func = ((op == :* ) | (op == :/ )) ? :($ mod. vmul) : :($ mod. vadd)
344
+ uf_new = unroll_factor
345
+ while uf_new > 1
346
+ uf_new, uf_prev = uf_new >> 1 , uf_new
347
+ for uf ∈ 0 : uf_new - 1 # reduce half divisible by two
348
+ push! (q. args, Expr (:(= ), Symbol (gsym_base, :_ , uf), Expr (:call , func, Symbol (gsym_base, :_ , 2 uf), Symbol (gsym_base, :_ , 2 uf + 1 ))))
349
+ end
350
+ uf_firstrem = 2 uf_new
351
+ for uf ∈ uf_firstrem: uf_prev - 1
352
+ push! (q. args, Expr (:(= ), Symbol (gsym_base, :_ , uf - uf_firstrem), Expr (:call , func, Symbol (gsym_base, :_ , uf - uf_firstrem), Symbol (gsym_base, :_ , uf))))
353
+ end
354
+ end
355
+ gsym = Symbol (gsym_base, :_ , 0 )
356
+ if op == :+
357
+ push! (q. args, :($ sym = Base. FastMath. add_fast ($ sym, $ mod. vsum ($ gsym))))
358
+ elseif op == :-
359
+ push! (q. args, :($ sym = Base. FastMath. sub_fast ($ sym, $ mod. vsum ($ gsym))))
360
+ elseif op == :*
361
+ push! (q. args, :($ sym = Base. FastMath. mul_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
362
+ elseif op == :/
363
+ push! (q. args, :($ sym = Base. FastMath. div_fast ($ sym, $ mod. SIMDPirates. vprod ($ gsym))))
364
+ end
315
365
end
316
366
end
317
367
push! (q. args, nothing )
0 commit comments