@@ -170,7 +170,7 @@ choose_num_blocks(nt, ::StaticInt{NC} = lv_max_num_threads()) where {NC} = @inbo
170
170
171
171
# choose_num_threads(::Val{C}, ::Val{NT}, x)
#
# Compute a thread count from the type-level parameters `C` and `NT` and the value `x`:
# `ceil(k * C * sqrt(x))` converted to an unsigned integer and capped at `NT` with `min`.
# `x` is converted to `Float64` with `Base.uitofp`, i.e. its bits are read as unsigned.
# NOTE(review): `C` looks like a cost estimate and `x` a combined loop length; confirm
# against the callers that build `Val{c}` / `Val{ntmax}` before relying on that reading.
#
# This diff hunk only changes the scale factor `k`: the removed (`-`) line uses
# 5.0852672001495816e-11 and the added (`+`) line uses 0.05460264079015985.
function choose_num_threads (:: Val{C} , :: Val{NT} , x) where {C,NT}
172
172
fx = Base. uitofp (Float64, x)
173
# Removed by this change: previous scale factor.
- min (Base. fptoui (UInt, Base. ceil_llvm (5.0852672001495816e-11 * C* Base. sqrt_llvm (fx))), NT)
173
# Added by this change: new scale factor; the result is still capped at NT.
+ min (Base. fptoui (UInt, Base. ceil_llvm (0.05460264079015985 * C* Base. sqrt_llvm (fx))), NT)
174
174
end
175
175
function push_loop_length_expr! (q:: Expr , ls:: LoopSet )
176
176
l = 1
@@ -209,17 +209,18 @@ function divrem_fast(numerator, denominator)
209
209
end
210
210
211
211
function outer_reduct_combine_expressions (ls:: LoopSet , retv)
212
- q = Expr (:block , :(var"#load#thread#ret#" = ThreadingUtilities. load (var"#thread#ptr#" , typeof ($ retv), 64 )))
212
+ gf = GlobalRef (Core, :getfield )
213
+ q = Expr (:block , :(var"#load#thread#ret#" = $ gf (ThreadingUtilities. load (var"#thread#ptr#" , typeof ($ retv), 64 ),2 ,false )))
213
214
for (i,or) ∈ enumerate (ls. outer_reductions)
214
215
op = ls. operations[or]
215
216
var = name (op)
216
217
mvar = mangledvar (op)
217
218
instr = instruction (op)
218
219
out = Symbol (mvar, " ##onevec##" )
219
- instrcall = callexpr ( instr)
220
+ instrcall = Expr ( :call , lv ( reduce_to_onevecunroll ( instr)) )
220
221
push! (instrcall. args, Expr (:call , lv (:vecmemaybe ), out))
221
222
if length (ls. outer_reductions) > 1
222
- push! (instrcall. args, Expr (:call , lv (:vecmemaybe ), Expr (:call , GlobalRef (Core, :getfield ) , Symbol (" #load#thread#ret#" ), i, false )))
223
+ push! (instrcall. args, Expr (:call , lv (:vecmemaybe ), Expr (:call , gf , Symbol (" #load#thread#ret#" ), i, false )))
223
224
else
224
225
push! (instrcall. args, Expr (:call , lv (:vecmemaybe ), Symbol (" #load#thread#ret#" )))
225
226
end
@@ -320,8 +321,15 @@ function thread_one_loops_expr(
320
321
ls:: LoopSet , ua:: UnrollArgs , valid_thread_loop:: Vector{Bool} , ntmax:: UInt , c:: Float64 ,
321
322
UNROLL:: Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt} , OPS:: Expr , ARF:: Expr , AM:: Expr , LPSYM:: Expr
322
323
)
323
- choose_nthread = :(choose_num_threads (Val {$c} (), Val {$ntmax} ()))
324
- push_loop_length_expr! (choose_nthread, ls)
324
+ if all (isstaticloop, ls. loops)
325
+ _num_threads = choose_num_threads (Val (c), Val (ntmax), 1 ):: UInt
326
+ _num_threads > 1 || return avx_body (ls, UNROLL)
327
+ choose_nthread = Expr (:(= ), Symbol (" #nthreads#" ), _num_threads)
328
+ else
329
+ choose_nthread = :(choose_num_threads (Val {$(c/looplengthprod(ls))} (), Val {$ntmax} ()))
330
+ push_loop_length_expr! (choose_nthread, ls)
331
+ choose_nthread = Expr (:(= ), Symbol (" #nthreads#" ), choose_nthread)
332
+ end
325
333
threadedid = findfirst (valid_thread_loop):: Int
326
334
threadedloop = getloop (ls, threadedid)
327
335
define_len, define_num_unrolls, loopstart, iterstop, looprange, lastrange = thread_loop_summary! (ls, ua, threadedloop, false )
@@ -336,7 +344,7 @@ function thread_one_loops_expr(
336
344
loop_boundary! (lastboundexpr, loop)
337
345
end
338
346
end
339
- _avx_call_ = :(_avx_! (Val {$UNROLL} (), $ OPS, $ ARF, $ AM, $ LPSYM, $ lastboundexpr, var"#vargs#" ))
347
+ _avx_call_ = :(_avx_! (Val {$UNROLL} (), $ OPS, $ ARF, $ AM, $ LPSYM, ( $ lastboundexpr, var"#vargs#" ) ))
340
348
update_return_values = if length (ls. outer_reductions) > 0
341
349
retv = loopset_return_value (ls, Val (false ))
342
350
_avx_call_ = Expr (:(= ), retv, _avx_call_)
@@ -347,7 +355,7 @@ function thread_one_loops_expr(
347
355
# @unpack u₁loop, u₂loop, vloop, u₁, u₂max = ua
348
356
iterdef = define_block_size (threadedloop, ua. vloop, 0 , ls. vector_width[])
349
357
q = quote
350
- var"#nthreads#" = $ choose_nthread # UInt
358
+ $ choose_nthread # UInt
351
359
$ define_len
352
360
$ define_num_unrolls
353
361
var"#nthreads#" = Base. min (var"#nthreads#" , var"#num#unrolls#thread#0#" )
@@ -365,8 +373,8 @@ function thread_one_loops_expr(
365
373
VectorizationBase. assume (var"#thread#mask#" ≠ zero (var"#thread#mask#" ))
366
374
var"#trailzing#zeros#" = Base. trailing_zeros (var"#thread#mask#" ) % UInt32
367
375
var"#nblock#size#thread#0#" = Core. ifelse (
368
- var"#thread#launch#count#" < (var"#nrem#thread#" % UInt32),
369
- var"#base#block#size#thread#0#" + var"#block#rem#step#" ,
376
+ var"#thread#launch#count#" < (var"#nrem#thread#0# " % UInt32),
377
+ var"#base#block#size#thread#0#" + var"#block#rem#step#0# " ,
370
378
var"#base#block#size#thread#0#"
371
379
)
372
380
var"#trailzing#zeros#" += 0x00000001
@@ -381,7 +389,7 @@ function thread_one_loops_expr(
381
389
var"#thread#mask#" >>>= var"#trailzing#zeros#"
382
390
383
391
var"#iter#start#0#" = var"#iter#stop#0#"
384
- var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001 ) ≠ var"$ nrequest#"
392
+ var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001 ) ≠ var"# nrequest#"
385
393
end
386
394
$ _avx_call_
387
395
var"#thread#id#" = 0x00000000
@@ -427,8 +435,15 @@ function thread_two_loops_expr(
427
435
ls:: LoopSet , ua:: UnrollArgs , valid_thread_loop:: Vector{Bool} , ntmax:: UInt , c:: Float64 ,
428
436
UNROLL:: Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt} , OPS:: Expr , ARF:: Expr , AM:: Expr , LPSYM:: Expr
429
437
)
430
- choose_nthread = :(choose_num_threads (Val {$c} (), Val {$ntmax} ()))
431
- push_loop_length_expr! (choose_nthread, ls)
438
+ if all (isstaticloop, ls. loops)
439
+ _num_threads = choose_num_threads (Val (c), Val (ntmax), 1 ):: UInt
440
+ _num_threads > 1 || return avx_body (ls, UNROLL)
441
+ choose_nthread = Expr (:(= ), Symbol (" #nthreads#" ), _num_threads)
442
+ else
443
+ choose_nthread = :(choose_num_threads (Val {$(c/looplengthprod(ls))} (), Val {$ntmax} ()))
444
+ push_loop_length_expr! (choose_nthread, ls)
445
+ choose_nthread = Expr (:(= ), Symbol (" #nthreads#" ), choose_nthread)
446
+ end
432
447
threadedid1 = threadedid2 = 0
433
448
for (i,v) ∈ enumerate (valid_thread_loop)
434
449
v || continue
@@ -472,7 +487,7 @@ function thread_two_loops_expr(
472
487
iterdef1 = define_block_size (threadedloop1, vloop, 0 , ls. vector_width[])
473
488
iterdef2 = define_block_size (threadedloop2, vloop, 1 , ls. vector_width[])
474
489
q = quote
475
- var"#nthreads#" = $ choose_nthread # UInt
490
+ $ choose_nthread # UInt
476
491
$ define_len1
477
492
$ define_len2
478
493
$ define_num_unrolls1
0 commit comments