Skip to content

Commit 7d44eac

Browse files
committed
single loop & reduction threading fixes
1 parent 692488c commit 7d44eac

File tree

2 files changed

+30
-15
lines changed

2 files changed

+30
-15
lines changed

src/codegen/lower_threads.jl

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ choose_num_blocks(nt, ::StaticInt{NC} = lv_max_num_threads()) where {NC} = @inbo
170170

171171
function choose_num_threads(::Val{C}, ::Val{NT}, x) where {C,NT}
172172
fx = Base.uitofp(Float64, x)
173-
min(Base.fptoui(UInt, Base.ceil_llvm(5.0852672001495816e-11*C*Base.sqrt_llvm(fx))), NT)
173+
min(Base.fptoui(UInt, Base.ceil_llvm(0.05460264079015985*C*Base.sqrt_llvm(fx))), NT)
174174
end
175175
function push_loop_length_expr!(q::Expr, ls::LoopSet)
176176
l = 1
@@ -209,17 +209,18 @@ function divrem_fast(numerator, denominator)
209209
end
210210

211211
function outer_reduct_combine_expressions(ls::LoopSet, retv)
212-
q = Expr(:block, :(var"#load#thread#ret#" = ThreadingUtilities.load(var"#thread#ptr#", typeof($retv), 64)))
212+
gf = GlobalRef(Core, :getfield)
213+
q = Expr(:block, :(var"#load#thread#ret#" = $gf(ThreadingUtilities.load(var"#thread#ptr#", typeof($retv), 64),2,false)))
213214
for (i,or) ∈ enumerate(ls.outer_reductions)
214215
op = ls.operations[or]
215216
var = name(op)
216217
mvar = mangledvar(op)
217218
instr = instruction(op)
218219
out = Symbol(mvar, "##onevec##")
219-
instrcall = callexpr(instr)
220+
instrcall = Expr(:call, lv(reduce_to_onevecunroll(instr)))
220221
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), out))
221222
if length(ls.outer_reductions) > 1
222-
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Expr(:call, GlobalRef(Core, :getfield), Symbol("#load#thread#ret#"), i, false)))
223+
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Expr(:call, gf, Symbol("#load#thread#ret#"), i, false)))
223224
else
224225
push!(instrcall.args, Expr(:call, lv(:vecmemaybe), Symbol("#load#thread#ret#")))
225226
end
@@ -320,8 +321,15 @@ function thread_one_loops_expr(
320321
ls::LoopSet, ua::UnrollArgs, valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64,
321322
UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
322323
)
323-
choose_nthread = :(choose_num_threads(Val{$c}(), Val{$ntmax}()))
324-
push_loop_length_expr!(choose_nthread, ls)
324+
if all(isstaticloop, ls.loops)
325+
_num_threads = choose_num_threads(Val(c), Val(ntmax), 1)::UInt
326+
_num_threads > 1 || return avx_body(ls, UNROLL)
327+
choose_nthread = Expr(:(=), Symbol("#nthreads#"), _num_threads)
328+
else
329+
choose_nthread = :(choose_num_threads(Val{$(c/looplengthprod(ls))}(), Val{$ntmax}()))
330+
push_loop_length_expr!(choose_nthread, ls)
331+
choose_nthread = Expr(:(=), Symbol("#nthreads#"), choose_nthread)
332+
end
325333
threadedid = findfirst(valid_thread_loop)::Int
326334
threadedloop = getloop(ls, threadedid)
327335
define_len, define_num_unrolls, loopstart, iterstop, looprange, lastrange = thread_loop_summary!(ls, ua, threadedloop, false)
@@ -336,7 +344,7 @@ function thread_one_loops_expr(
336344
loop_boundary!(lastboundexpr, loop)
337345
end
338346
end
339-
_avx_call_ = :(_avx_!(Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM, $lastboundexpr, var"#vargs#"))
347+
_avx_call_ = :(_avx_!(Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM, ($lastboundexpr, var"#vargs#")))
340348
update_return_values = if length(ls.outer_reductions) > 0
341349
retv = loopset_return_value(ls, Val(false))
342350
_avx_call_ = Expr(:(=), retv, _avx_call_)
@@ -347,7 +355,7 @@ function thread_one_loops_expr(
347355
# @unpack u₁loop, u₂loop, vloop, u₁, u₂max = ua
348356
iterdef = define_block_size(threadedloop, ua.vloop, 0, ls.vector_width[])
349357
q = quote
350-
var"#nthreads#" = $choose_nthread # UInt
358+
$choose_nthread # UInt
351359
$define_len
352360
$define_num_unrolls
353361
var"#nthreads#" = Base.min(var"#nthreads#", var"#num#unrolls#thread#0#")
@@ -365,8 +373,8 @@ function thread_one_loops_expr(
365373
VectorizationBase.assume(var"#thread#mask#" ≠ zero(var"#thread#mask#"))
366374
var"#trailzing#zeros#" = Base.trailing_zeros(var"#thread#mask#") % UInt32
367375
var"#nblock#size#thread#0#" = Core.ifelse(
368-
var"#thread#launch#count#" < (var"#nrem#thread#" % UInt32),
369-
var"#base#block#size#thread#0#" + var"#block#rem#step#",
376+
var"#thread#launch#count#" < (var"#nrem#thread#0#" % UInt32),
377+
var"#base#block#size#thread#0#" + var"#block#rem#step#0#",
370378
var"#base#block#size#thread#0#"
371379
)
372380
var"#trailzing#zeros#" += 0x00000001
@@ -381,7 +389,7 @@ function thread_one_loops_expr(
381389
var"#thread#mask#" >>>= var"#trailzing#zeros#"
382390

383391
var"#iter#start#0#" = var"#iter#stop#0#"
384-
var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001) ≠ var"$nrequest#"
392+
var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001) ≠ var"#nrequest#"
385393
end
386394
$_avx_call_
387395
var"#thread#id#" = 0x00000000
@@ -427,8 +435,15 @@ function thread_two_loops_expr(
427435
ls::LoopSet, ua::UnrollArgs, valid_thread_loop::Vector{Bool}, ntmax::UInt, c::Float64,
428436
UNROLL::Tuple{Bool,Int8,Int8,Int,Int,Int,Int,Int,Int,Int,UInt}, OPS::Expr, ARF::Expr, AM::Expr, LPSYM::Expr
429437
)
430-
choose_nthread = :(choose_num_threads(Val{$c}(), Val{$ntmax}()))
431-
push_loop_length_expr!(choose_nthread, ls)
438+
if all(isstaticloop, ls.loops)
439+
_num_threads = choose_num_threads(Val(c), Val(ntmax), 1)::UInt
440+
_num_threads > 1 || return avx_body(ls, UNROLL)
441+
choose_nthread = Expr(:(=), Symbol("#nthreads#"), _num_threads)
442+
else
443+
choose_nthread = :(choose_num_threads(Val{$(c/looplengthprod(ls))}(), Val{$ntmax}()))
444+
push_loop_length_expr!(choose_nthread, ls)
445+
choose_nthread = Expr(:(=), Symbol("#nthreads#"), choose_nthread)
446+
end
432447
threadedid1 = threadedid2 = 0
433448
for (i,v) ∈ enumerate(valid_thread_loop)
434449
v || continue
@@ -472,7 +487,7 @@ function thread_two_loops_expr(
472487
iterdef1 = define_block_size(threadedloop1, vloop, 0, ls.vector_width[])
473488
iterdef2 = define_block_size(threadedloop2, vloop, 1, ls.vector_width[])
474489
q = quote
475-
var"#nthreads#" = $choose_nthread # UInt
490+
$choose_nthread # UInt
476491
$define_len1
477492
$define_len2
478493
$define_num_unrolls1

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ Execute an `@avx` block. The block's code is represented via the arguments:
582582
@generated function _avx_!(
583583
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, var"#lv#tuple#args#"::Tuple{LB,V}
584584
) where {UNROLL, OPS, ARF, AM, LPSYM, LB, V}
585-
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
585+
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
586586
ls = _avx_loopset(OPS, ARF, AM, LPSYM, LB.parameters, V.parameters, UNROLL)
587587
# return @show avx_body(ls, UNROLL)
588588
if last(UNROLL) > 1

0 commit comments

Comments
 (0)