Skip to content

Commit 23724da

Browse files
committed
Remove extra call in threaded code.
1 parent ea843c7 commit 23724da

File tree

2 files changed

+85
-74
lines changed

2 files changed

+85
-74
lines changed

src/codegen/lower_threads.jl

Lines changed: 84 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,15 @@ function thread_one_loops_expr(
349349
loop_boundary!(lastboundexpr, loop)
350350
end
351351
end
352-
_avx_call_core_ = :(_avx_!(Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM, ($lastboundexpr, var"#vargs#")))
353-
_avx_call_ = _avx_call_core_
352+
_avx_call_ = :(_avx_!(Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM, ($lastboundexpr, var"#vargs#")))
354353
update_return_values = if length(ls.outer_reductions) > 0
355354
retv = loopset_return_value(ls, Val(false))
356355
_avx_call_ = Expr(:(=), retv, _avx_call_)
357356
outer_reduct_combine_expressions(ls, retv)
358357
else
359358
nothing
360359
end
360+
retexpr = length(ls.outer_reductions) > 0 ? :(return $retv) : :(return nothing)
361361
# @unpack u₁loop, u₂loop, vloop, u₁, u₂max = ua
362362
iterdef = define_block_size(threadedloop, ua.vloop, 0, ls.vector_width[])
363363
q = quote
@@ -367,37 +367,43 @@ function thread_one_loops_expr(
367367
var"#nthreads#" = Base.min(var"#nthreads#", var"#num#unrolls#thread#0#")
368368
var"#nrequest#" = (var"#nthreads#" % UInt32) - 0x00000001
369369
$loopstart
370-
var"#nrequest#" == 0x00000000 && return $_avx_call_core_
371-
var"#threads#", var"#torelease#" = CheapThreads.request_threads(Threads.threadid()%UInt32, var"#nrequest#")
372-
var"#thread#factor#0#" = var"#nthreads#"
373-
$iterdef
374-
var"#thread#launch#count#" = 0x00000000
375-
var"#thread#id#" = 0x00000000
376-
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
377-
var"#threads#remain#" = true
378-
while var"#threads#remain#"
379-
VectorizationBase.assume(var"#thread#mask#" ≠ zero(var"#thread#mask#"))
380-
var"#trailzing#zeros#" = Base.trailing_zeros(var"#thread#mask#") % UInt32
381-
var"#nblock#size#thread#0#" = Core.ifelse(
382-
var"#thread#launch#count#" < (var"#nrem#thread#0#" % UInt32),
383-
var"#base#block#size#thread#0#" + var"#block#rem#step#0#",
384-
var"#base#block#size#thread#0#"
385-
)
386-
var"#trailzing#zeros#" += 0x00000001
387-
$iterstop
388-
var"#thread#id#" += var"#trailzing#zeros#"
370+
var"##do#thread##" = var"#nrequest#" ≠ 0x00000000
371+
if var"##do#thread##"
372+
var"#threads#", var"#torelease#" = CheapThreads.request_threads(Threads.threadid()%UInt32, var"#nrequest#")
373+
var"#thread#factor#0#" = var"#nthreads#"
374+
$iterdef
375+
var"#thread#launch#count#" = 0x00000000
376+
var"#thread#id#" = 0x00000000
377+
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
378+
var"#threads#remain#" = true
379+
while var"#threads#remain#"
380+
VectorizationBase.assume(var"#thread#mask#" ≠ zero(var"#thread#mask#"))
381+
var"#trailzing#zeros#" = Base.trailing_zeros(var"#thread#mask#") % UInt32
382+
var"#nblock#size#thread#0#" = Core.ifelse(
383+
var"#thread#launch#count#" < (var"#nrem#thread#0#" % UInt32),
384+
var"#base#block#size#thread#0#" + var"#block#rem#step#0#",
385+
var"#base#block#size#thread#0#"
386+
)
387+
var"#trailzing#zeros#" += 0x00000001
388+
$iterstop
389+
var"#thread#id#" += var"#trailzing#zeros#"
389390

390-
avx_launch(
391-
Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM,
392-
$loopboundexpr, var"#vargs#", var"#thread#id#"
393-
)
391+
avx_launch(
392+
Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM,
393+
$loopboundexpr, var"#vargs#", var"#thread#id#"
394+
)
394395

395-
var"#thread#mask#" >>>= var"#trailzing#zeros#"
396+
var"#thread#mask#" >>>= var"#trailzing#zeros#"
396397

397-
var"#iter#start#0#" = var"#iter#stop#0#"
398-
var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001) ≠ var"#nrequest#"
398+
var"#iter#start#0#" = var"#iter#stop#0#"
399+
var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001) ≠ var"#nrequest#"
400+
end
401+
else# eliminate undef var errors that the compiler should be able to figure out are unreachable, but doesn't
402+
var"#torelease#" = zero(CheapThreads.worker_type())
403+
var"#threads#" = CheapThreads.UnsignedIteratorEarlyStop(var"#torelease#", 0x00000000)
399404
end
400405
$_avx_call_
406+
var"##do#thread##" || $retexpr
401407
var"#thread#id#" = 0x00000000
402408
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
403409
var"#threads#remain#" = true
@@ -413,8 +419,8 @@ function thread_one_loops_expr(
413419
var"#threads#remain#" = var"#thread#mask#" ≠ 0x00000000
414420
end
415421
CheapThreads.free_threads!(var"#torelease#")
422+
$retexpr
416423
end
417-
length(ls.outer_reductions) > 0 ? push!(q.args, retv) : push!(q.args, nothing)
418424
Expr(:block, ls.preamble, q)
419425
end
420426
function define_vthread_blocks(vloop, u₁loop, u₂loop, u₁, u₂, ntmax, tn)
@@ -484,8 +490,7 @@ function thread_two_loops_expr(
484490
loop_boundary!(lastboundexpr, loop)
485491
end
486492
end
487-
_avx_call_core_ = :(_avx_!(Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM, ($lastboundexpr, var"#vargs#")))
488-
_avx_call_ = _avx_call_core_
493+
_avx_call_ = :(_avx_!(Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM, ($lastboundexpr, var"#vargs#")))
489494
update_return_values = if length(ls.outer_reductions) > 0
490495
retv = loopset_return_value(ls, Val(false))
491496
_avx_call_ = Expr(:(=), retv, _avx_call_)
@@ -496,6 +501,7 @@ function thread_two_loops_expr(
496501
blockdef = define_thread_blocks(threadedloop1, threadedloop2, vloop, u₁loop, u₂loop, u₁, u₂, ntmax)
497502
iterdef1 = define_block_size(threadedloop1, vloop, 0, ls.vector_width[])
498503
iterdef2 = define_block_size(threadedloop2, vloop, 1, ls.vector_width[])
504+
retexpr = length(ls.outer_reductions) > 0 ? :(return $retv) : :(return nothing)
499505
q = quote
500506
$choose_nthread # UInt
501507
$define_len1
@@ -515,54 +521,59 @@ function thread_two_loops_expr(
515521
$loopstart1
516522
var"#loop#1#start#init#" = var"#iter#start#0#"
517523
$loopstart2
518-
var"#nrequest#" == 0x00000000 && return $_avx_call_core_
519-
var"#threads#", var"#torelease#" = CheapThreads.request_threads(Threads.threadid(), var"#nrequest#")
524+
var"##do#thread##" = var"#nrequest#" ≠ 0x00000000
525+
if var"##do#thread##"
526+
var"#threads#", var"#torelease#" = CheapThreads.request_threads(Threads.threadid(), var"#nrequest#")
527+
$iterdef1
528+
$iterdef2
529+
# @show var"#base#block#size#thread#0#", var"#block#rem#step#0#" var"#base#block#size#thread#1#", var"#block#rem#step#1#"
530+
var"#thread#launch#count#" = 0x00000000
531+
var"#thread#launch#count#0#" = 0x00000000
532+
var"#thread#launch#count#1#" = 0x00000000
533+
var"#thread#id#" = 0x00000000
534+
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
535+
var"#threads#remain#" = true
536+
while var"#threads#remain#"
537+
VectorizationBase.assume(var"#thread#mask#" ≠ zero(var"#thread#mask#"))
538+
var"#trailzing#zeros#" = Base.trailing_zeros(var"#thread#mask#") % UInt32
539+
var"#nblock#size#thread#0#" = Core.ifelse(
540+
var"#thread#launch#count#0#" < (var"#nrem#thread#0#" % UInt32),
541+
var"#base#block#size#thread#0#" + var"#block#rem#step#0#",
542+
var"#base#block#size#thread#0#"
543+
)
544+
var"#nblock#size#thread#1#" = Core.ifelse(
545+
var"#thread#launch#count#1#" < (var"#nrem#thread#1#" % UInt32),
546+
var"#base#block#size#thread#1#" + var"#block#rem#step#1#",
547+
var"#base#block#size#thread#1#"
548+
)
549+
var"#trailzing#zeros#" += 0x00000001
550+
$iterstop1
551+
$iterstop2
552+
var"#thread#id#" += var"#trailzing#zeros#"
553+
# @show var"#thread#id#" $loopboundexpr
554+
avx_launch(
555+
Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM,
556+
$loopboundexpr, var"#vargs#", var"#thread#id#"
557+
)
520558

521-
$iterdef1
522-
$iterdef2
523-
# @show var"#base#block#size#thread#0#", var"#block#rem#step#0#" var"#base#block#size#thread#1#", var"#block#rem#step#1#"
524-
var"#thread#launch#count#" = 0x00000000
525-
var"#thread#launch#count#0#" = 0x00000000
526-
var"#thread#launch#count#1#" = 0x00000000
527-
var"#thread#id#" = 0x00000000
528-
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
529-
var"#threads#remain#" = true
530-
while var"#threads#remain#"
531-
VectorizationBase.assume(var"#thread#mask#" ≠ zero(var"#thread#mask#"))
532-
var"#trailzing#zeros#" = Base.trailing_zeros(var"#thread#mask#") % UInt32
533-
var"#nblock#size#thread#0#" = Core.ifelse(
534-
var"#thread#launch#count#0#" < (var"#nrem#thread#0#" % UInt32),
535-
var"#base#block#size#thread#0#" + var"#block#rem#step#0#",
536-
var"#base#block#size#thread#0#"
537-
)
538-
var"#nblock#size#thread#1#" = Core.ifelse(
539-
var"#thread#launch#count#1#" < (var"#nrem#thread#1#" % UInt32),
540-
var"#base#block#size#thread#1#" + var"#block#rem#step#1#",
541-
var"#base#block#size#thread#1#"
542-
)
543-
var"#trailzing#zeros#" += 0x00000001
544-
$iterstop1
545-
$iterstop2
546-
var"#thread#id#" += var"#trailzing#zeros#"
547-
# @show var"#thread#id#" $loopboundexpr
548-
avx_launch(
549-
Val{$UNROLL}(), $OPS, $ARF, $AM, $LPSYM,
550-
$loopboundexpr, var"#vargs#", var"#thread#id#"
551-
)
559+
var"#thread#mask#" >>>= var"#trailzing#zeros#"
552560

553-
var"#thread#mask#" >>>= var"#trailzing#zeros#"
561+
var"##end#inner##" = var"#thread#launch#count#0#" == (var"#thread#factor#0#"-0x00000001)
562+
var"#thread#launch#count#0#" = Core.ifelse(var"##end#inner##", 0x00000000, var"#thread#launch#count#0#" + 0x00000001)
563+
var"#thread#launch#count#1#" = Core.ifelse(var"##end#inner##", var"#thread#launch#count#1#" + 0x00000001, var"#thread#launch#count#1#")
554564

555-
var"##end#inner##" = var"#thread#launch#count#0#" == (var"#thread#factor#0#"-0x00000001)
556-
var"#thread#launch#count#0#" = Core.ifelse(var"##end#inner##", 0x00000000, var"#thread#launch#count#0#" + 0x00000001)
557-
var"#thread#launch#count#1#" = Core.ifelse(var"##end#inner##", var"#thread#launch#count#1#" + 0x00000001, var"#thread#launch#count#1#")
565+
var"#iter#start#0#" = Core.ifelse(var"##end#inner##", var"#loop#1#start#init#", var"#iter#stop#0#")
566+
var"#iter#start#1#" = Core.ifelse(var"##end#inner##", var"#iter#stop#1#", var"#iter#start#1#")
558567

559-
var"#iter#start#0#" = Core.ifelse(var"##end#inner##", var"#loop#1#start#init#", var"#iter#stop#0#")
560-
var"#iter#start#1#" = Core.ifelse(var"##end#inner##", var"#iter#stop#1#", var"#iter#start#1#")
561-
562-
var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001) ≠ var"#nrequest#"
568+
var"#threads#remain#" = (var"#thread#launch#count#" += 0x00000001) ≠ var"#nrequest#"
569+
end
570+
else# eliminate undef var errors that the compiler should be able to figure out are unreachable, but doesn't
571+
var"#torelease#" = zero(CheapThreads.worker_type())
572+
var"#threads#" = CheapThreads.UnsignedIteratorEarlyStop(var"#torelease#", 0x00000000)
563573
end
564574
# @show $lastboundexpr
565575
$_avx_call_
576+
var"##do#thread##" || $retexpr
566577
# @show $retv
567578
var"#thread#id#" = 0x00000000
568579
var"#thread#mask#" = CheapThreads.mask(var"#threads#")
@@ -579,8 +590,8 @@ function thread_two_loops_expr(
579590
var"#threads#remain#" = var"#thread#mask#" ≠ 0x00000000
580591
end
581592
CheapThreads.free_threads!(var"#torelease#")
593+
$retexpr
582594
end
583-
length(ls.outer_reductions) > 0 ? push!(q.args, retv) : push!(q.args, nothing)
584595
# @show
585596
Expr(:block, ls.preamble, q)
586597
end

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ Execute an `@avx` block. The block's code is represented via the arguments:
623623
@generated function _avx_!(
624624
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, var"#lv#tuple#args#"::Tuple{LB,V}
625625
) where {UNROLL, OPS, ARF, AM, LPSYM, LB, V}
626-
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
626+
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
627627
ls = _avx_loopset(OPS, ARF, AM, LPSYM, LB.parameters, V.parameters, UNROLL)
628628
# return @show avx_body(ls, UNROLL)
629629
if last(UNROLL) > 1

0 commit comments

Comments
 (0)