@@ -137,6 +137,7 @@ function getroots!(rooted::Vector{Bool}, ls::LoopSet)
137
137
for op ∈ ops
138
138
isstore (op) && recursively_set_parents_true! (rooted, op)
139
139
end
140
+ length (ls. includedactualarrays) == 0 || remove_outer_reducts! (rooted, ls)
140
141
return rooted
141
142
end
142
143
function OperationStruct! (varnames:: Vector{Symbol} , ids:: Vector{Int} , ls:: LoopSet , op:: Operation )
@@ -461,28 +462,79 @@ function remove_outer_reducts!(roots::Vector{Bool}, ls::LoopSet)
461
462
end
462
463
end
463
464
464
- # function generate_call_split(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool = false)
465
- # ops = operations(ls)
466
- # for op ∈ ops
467
- # if (iscompute(op) && (instruction(op).instr === :ifelse)) && iszero(length(loopdependencies(first(parents(op)))))
468
- # # we want to eliminate
469
- # end
470
- # end
471
- # end
472
465
473
- # Try to condense in type stable manner
474
- function generate_call (ls:: LoopSet , (inline,u₁,u₂):: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool = false )
475
- extra_args = Expr (:tuple )
476
- preserve, shouldindbyind, roots = add_grouped_strided_pointer! (extra_args, ls)
477
466
467
+ function split_ifelse! (
468
+ ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr , k:: Int , inlineu₁u₂:: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool
469
+ )
470
+ roots[k] = false
471
+ op = operations (ls)[k]
472
+ op. instruction = DROPPEDCONSTANT
473
+ op. node_type = constant
474
+ # we want to eliminate
475
+ parents_op = parents (op)
476
+ condop = first (parents_op)
477
+ # create one loop where `opp` is true, and a second where it is `false`
478
+ prepre = ls. prepreamble; append! (prepre. args, ls. preamble. args)
479
+ ls. prepreamble = Expr (:block ); ls. preamble = Expr (:block );
480
+ ls_true = deepcopy (ls)
481
+ lsfalse = ls
482
+ true_ops = operations (ls_true)
483
+ falseops = operations (lsfalse)
484
+ true_op = parents (true_ops[k])[2 ]
485
+ falseop = parents_op[3 ]
486
+ condop_count = 0
487
+ for i ∈ eachindex (falseops)
488
+ fop = falseops[i]
489
+ parents_false = parents (fop)
490
+ for (j,opp) ∈ enumerate (parents_false)
491
+ if opp === op # then ops[i]'s jth parent is the ifelse
492
+ parents (true_ops[i])[j] = true_op
493
+ parents_false[j] = falseop
494
+ end
495
+ condop_count += roots[i] & (condop === opp)
496
+ end
497
+ end
498
+ roots[identifier (condop)] &= condop_count > 0
499
+ q = :(if $ (name (condop))
500
+ $ (generate_call_split (ls_true, preserve, shouldindbyind, roots, copy (extra_args), inlineu₁u₂, thread, debug))
501
+ else
502
+ $ (generate_call_split (lsfalse, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug))
503
+ end )
504
+ push! (prepre. args, q)
505
+ prepre
506
+ end
507
+
508
+ function generate_call (ls:: LoopSet , inlineu₁u₂:: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool )
509
+ extra_args = Expr (:tuple )
510
+ preserve, shouldindbyind, roots = add_grouped_strided_pointer! (extra_args, ls)
511
+ generate_call_split (ls, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug)
512
+ end
513
+ function generate_call_split (
514
+ ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr , inlineu₁u₂:: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool
515
+ )
516
+ if ! debug
517
+ for (k,op) ∈ enumerate (operations (ls))
518
+ parents_op = parents (op)
519
+ if (iscompute (op) && (instruction (op). instr === :ifelse )) && (length (parents_op) == 3 ) && isconstantop (first (parents_op))
520
+ return split_ifelse! (ls, preserve, shouldindbyind, roots, extra_args, k, inlineu₁u₂, thread, debug)
521
+ end
522
+ end
523
+ end
524
+ return generate_call_types (ls, preserve, shouldindbyind, roots, extra_args, inlineu₁u₂, thread, debug)
525
+ end
526
+ # Try to condense in type stable manner
527
+ function generate_call_types (
528
+ ls:: LoopSet , preserve:: Vector{Symbol} , shouldindbyind:: Vector{Bool} , roots:: Vector{Bool} , extra_args:: Expr , (inline,u₁,u₂):: Tuple{Bool,Int8,Int8} , thread:: UInt , debug:: Bool
529
+ )
530
+ # good place to check for split
478
531
operation_descriptions = Expr (:tuple )
479
532
varnames = Symbol[]; ids = Vector {Int} (undef, length (operations (ls)))
480
533
ops = operations (ls)
481
- length (ls. includedactualarrays) == 0 || remove_outer_reducts! (roots, ls)
482
534
for op ∈ ops
483
535
instr:: Instruction = instruction (op)
484
536
if (isconstant (op) && (instr == LOOPCONSTANT)) && (! roots[identifier (op)])
485
- instr = op. instruction = DROPPEDCONSTANT
537
+ instr = op. instruction = DROPPEDCONSTANT
486
538
end
487
539
push! (operation_descriptions. args, QuoteNode (instr. mod))
488
540
push! (operation_descriptions. args, QuoteNode (instr. instr))
@@ -505,7 +557,7 @@ function generate_call(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, t
505
557
configarg = (inline,u₁,u₂,ls. isbroadcast,thread)
506
558
unroll_param_tup = Expr (:call , lv (:avx_config_val ), :(Val {$configarg} ()), VECTORWIDTHSYMBOL)
507
559
q = Expr (:call , func, unroll_param_tup, val (operation_descriptions), val (arrayref_descriptions), val (argmeta), val (loop_syms))
508
-
560
+
509
561
add_reassigned_syms! (extra_args, ls) # counterpart to `add_ops!` constants
510
562
for (opid,sym) ∈ ls. preamble_symsym # counterpart to process_metadata! symsym extraction
511
563
if instruction (ops[opid]) ≠ DROPPEDCONSTANT
@@ -517,7 +569,13 @@ function generate_call(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, t
517
569
push! (q. args, Expr (:tuple , lbarg, extra_args))
518
570
vecwidthdefq = Expr (:block )
519
571
define_eltype_vec_width! (vecwidthdefq, ls, nothing )
520
- Expr (:block , vecwidthdefq, q), preserve
572
+ push! (vecwidthdefq. args, q)
573
+ if debug
574
+ pushpreamble! (ls,vecwidthdefq)
575
+ Expr (:block , ls. prepreamble, ls. preamble)
576
+ else
577
+ setup_call_final (ls, setup_outerreduct_preserve (ls, vecwidthdefq, preserve))
578
+ end
521
579
end
522
580
523
581
@@ -581,34 +639,32 @@ function gc_preserve(call::Expr, preserve::Vector{Symbol})
581
639
q
582
640
end
583
641
584
- function setup_call_inline (ls:: LoopSet , inline:: Bool , u₁:: Int8 , u₂:: Int8 , thread:: Int )
585
- call, preserve = generate_call (ls, (inline,u₁,u₂), thread % UInt, false )
586
- if iszero ( length (ls . outer_reductions) )
587
- pushpreamble! (ls, gc_preserve (call, preserve))
588
- push! (ls. preamble . args, nothing )
589
- return ls . preamble
590
- end
591
- retv = loopset_return_value (ls, Val ( false ))
592
- outer_reducts = Expr ( :local )
593
- q = Expr ( :block , gc_preserve ( Expr (:( = ), retv, call), preserve))
594
- for or ∈ ls . outer_reductions
595
- op = ls . operations[or]
596
- var = name (op)
597
- # push!(call.args, Symbol("##TYPEOF##", var) )
598
- mvar = mangledvar (op )
599
- instr = instruction (op )
600
- out = Symbol (mvar, " ##onevec## " )
601
- push! (outer_reducts . args, out)
602
- push! (q . args, Expr (:( = ), var, Expr ( :call , lv ( reduction_scalar_combine (instr)), Expr ( :call , lv ( :vecmemaybe ), out), var)))
603
- end
604
- pushpreamble! (ls, outer_reducts )
605
- append ! (ls. preamble. args, q . args )
606
- ls. preamble
642
+ # function setup_call_inline(ls::LoopSet, inline::Bool, u₁::Int8, u₂::Int8, thread::Int)
643
+ # call, preserve = generate_call_split (ls, (inline,u₁,u₂), thread % UInt, false)
644
+ # setup_call_ret!(ls, call, preserve )
645
+ # end
646
+ function setup_outerreduct_preserve (ls:: LoopSet , call :: Expr , preserve :: Vector{Symbol} )
647
+ iszero ( length (ls . outer_reductions)) && return gc_preserve (call, preserve)
648
+ retv = loopset_return_value (ls, Val ( false ))
649
+ q = Expr ( :block , gc_preserve ( Expr (:( = ), retv, call), preserve ))
650
+ for or ∈ ls . outer_reductions
651
+ op = ls . operations[or]
652
+ var = name (op)
653
+ # push!(call.args, Symbol("##TYPEOF##", var))
654
+ mvar = mangledvar (op)
655
+ instr = instruction (op )
656
+ out = Symbol (mvar, " ##onevec## " )
657
+ push! (q . args, Expr (:( = ), var, Expr ( :call , lv ( reduction_scalar_combine ( instr)), Expr ( :call , lv ( :vecmemaybe ), out), var)) )
658
+ end
659
+ q
660
+ end
661
+ function setup_call_final (ls :: LoopSet , q :: Expr )
662
+ pushpreamble! (ls, q )
663
+ push ! (ls. preamble. args, nothing )
664
+ return ls. preamble
607
665
end
608
666
function setup_call_debug (ls:: LoopSet )
609
- # avx_loopset(instr, ops, arf, AM, LB, vargs)
610
- pushpreamble! (ls, first (generate_call (ls, (false ,zero (Int8),zero (Int8)), zero (UInt), true )))
611
- Expr (:block , ls. prepreamble, ls. preamble)
667
+ generate_call (ls, (false ,zero (Int8),zero (Int8)), zero (UInt), true )
612
668
end
613
669
function setup_call (
614
670
ls:: LoopSet , q:: Expr , source:: LineNumberNode , inline:: Bool , check_empty:: Bool , u₁:: Int8 , u₂:: Int8 , thread:: Int
@@ -620,9 +676,9 @@ function setup_call(
620
676
# inlining the generated function into the loop preamble.
621
677
lnns = extract_all_lnns (q)
622
678
pushfirst! (lnns, source)
623
- call = setup_call_inline (ls, inline, u₁, u₂, thread)
679
+ call = generate_call (ls, ( inline, u₁, u₂) , thread% UInt, false )
624
680
call = check_empty ? check_if_empty (ls, call) : call
625
- result = Expr ( :block , ls . prepreamble , Expr (:if , check_args_call (ls), call, make_crashy (make_fast (q))))
626
- prepend_lnns! (result , lnns)
627
- return result
681
+ pushprepreamble! (ls , Expr (:if , check_args_call (ls), call, make_crashy (make_fast (q))))
682
+ prepend_lnns! (ls . prepreamble , lnns)
683
+ return ls . prepreamble
628
684
end
0 commit comments