493
493
494
494
function depends_on_assigned (op:: Operation , assigned:: Vector{Bool} )
495
495
for p ∈ parents (op)
496
+ p === op && continue # don't fall into recursive loop when we have updates, eg a = a + b
496
497
assigned[identifier (op)] && return true
497
498
depends_on_assigned (p, assigned) && return true
498
499
end
516
517
# Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
517
518
function lower_load! (
518
519
q:: Expr , op:: Operation , W:: Int , unrolled:: Symbol ,
519
- U:: Int , T:: Int = - 1 , tiled:: Symbol = Symbol (" ##UNDEFINED##" )
520
+ U:: Int , T:: Int = - 1 , tiled:: Symbol = Symbol (" ##UNDEFINED##" ), mask = nothing
520
521
)
521
522
loopdeps = loopdependencies (op)
522
523
var = op. variable
@@ -531,15 +532,19 @@ function lower_load!(
531
532
push! (q. args, Expr (:(= ), var, Expr (:call ,:vload ,ptr,memoff)))
532
533
elseif T == - 1
533
534
for u ∈ 0 : U- 1
534
- push! (q. args, Expr (:(= ), Symbol (var,:_ ,u), Expr (:call ,:vload , Val {W} (), ptr, u == 0 ? memoff : push! (copy (memoff), W* u))))
535
+ instrcall = Expr (:call ,:vload , Val {W} (), ptr, u == 0 ? memoff : push! (copy (memoff), W* u))
536
+ mask === nothing || push! (instrcall. args, mask)
537
+ push! (q. args, Expr (:(= ), Symbol (var,:_ ,u), instrcall))
535
538
end
536
539
else # tiling
537
540
for t ∈ 0 : T- 1
538
541
replace_ind_inoffset! (memoff, op, tind, t)
539
542
for u ∈ 0 : U- 1
540
543
memoff2 = copy (memoff)
541
544
u > 0 && push! (memoff2, W* u)
542
- push! (q. args, Expr (:(= ), Symbol (var, :_ , u, :_ , t), Expr (:call , :vload , Val {W} (), ptr, memoff2)))
545
+ instrcall = Expr (:call , :vload , Val {W} (), ptr, memoff2)
546
+ mask === nothing || push! (instrcall. args, mask)
547
+ push! (q. args, Expr (:(= ), Symbol (var, :_ , u, :_ , t), instrcall))
543
548
end
544
549
end
545
550
end
@@ -552,16 +557,22 @@ function lower_load!(
552
557
for u ∈ 0 : U- 1
553
558
memoff2 = copy (memoff)
554
559
u > 0 && push! (memoff2, ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
555
- push! (q. args, Expr (:(= ), Symbol (var,:_ ,u,:_ ,t), Expr (:call , :gather , ptr, Expr (:call , :vadd , memoff2, ustrides))))
560
+ instrcall = Expr (:call , :gather , ptr, Expr (:call , :vadd , memoff2, ustrides))
561
+ mask === nothing || push! (instrcall. args, mask)
562
+ push! (q. args, Expr (:(= ), Symbol (var,:_ ,u,:_ ,t), instrcall))
556
563
end
557
564
end
558
565
# elseif unitstride(op, tiled) # TODO : we load tiled, and then shuffle
559
566
elseif U == 1 # we gather, no tile, no extra unroll
560
- push! (q. args, Expr (:(= ), var, Expr (:call ,:gather ,ptr,Expr (:call ,:vadd ,memoff,ustrides))))
567
+ instrcall = Expr (:call ,:gather ,ptr,Expr (:call ,:vadd ,memoff,ustrides))
568
+ mask === nothing || push! (instrcall. args, mask)
569
+ push! (q. args, Expr (:(= ), var, instrcall))
561
570
else # we gather, no tile, but extra unroll
562
571
for u ∈ 0 : U- 1
563
572
memoff2 = u == 0 ? memoff : push! (copy (memoff), ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
564
- push! (q. args, Expr (:(= ), Symbol (var,:_ ,u), Expr (:call , :gather , ptr, Expr (:call ,:vadd ,memoff2,ustrides))))
573
+ instrcall = Expr (:call , :gather , ptr, Expr (:call ,:vadd ,memoff2,ustrides))
574
+ mask === nothing || push! (instrcall. args, mask)
575
+ push! (q. args, Expr (:(= ), Symbol (var,:_ ,u), instrcall))
565
576
end
566
577
end
567
578
end
@@ -570,15 +581,18 @@ function lower_load!(
570
581
# memoff2 = copy(memoff)
571
582
for t ∈ 0 : T- 1
572
583
replace_ind_inoffset! (memoff, op, tind, t)
573
- push! (q. args, Expr (:(= ), Symbol (var,:_ ,t), Expr (:call , :load , ptr, copy (memoff))))
584
+ instrcall = Expr (:call , :load , ptr, copy (memoff))
585
+ # mask === nothing || push!(instrcall.args, mask)
586
+ push! (q. args, Expr (:(= ), Symbol (var,:_ ,t), instrcall))
574
587
end
575
588
else # load scalar; promotion should broadcast as/when neccesary
576
589
push! (q. args, Expr (:(= ), var, Expr (:call , :load , ptr, memoff)))
577
590
end
578
591
end
592
+ # TODO : handle reductions correctly when we're storing non-unrolled parameters!
579
593
function lower_store! (
580
594
q:: Expr , op:: Operation , W:: Int , unrolled:: Symbol ,
581
- U:: Int , T:: Int = - 1 , tiled:: Symbol = Symbol (" ##UNDEFINED##" )
595
+ U:: Int , T:: Int = - 1 , tiled:: Symbol = Symbol (" ##UNDEFINED##" ), mask = nothing
582
596
)
583
597
loopdeps = loopdependencies (op)
584
598
var = first (parents (op)). variable
@@ -593,15 +607,19 @@ function lower_store!(
593
607
push! (q. args, Expr (:(= ), var, Expr (:call ,:vload ,ptr,memoff)))
594
608
elseif T == - 1
595
609
for u ∈ 0 : U- 1
596
- push! (q. args, Expr (:call ,:vstore! , ptr, Symbol (var,:_ ,u), u == 0 ? memoff : push! (copy (memoff), W* u)))
610
+ instrcall = Expr (:call ,:vstore! , ptr, Symbol (var,:_ ,u), u == 0 ? memoff : push! (copy (memoff), W* u))
611
+ mask === nothing || push! (instrcall. args, mask)
612
+ push! (q. args, instrcall)
597
613
end
598
614
else # tiling
599
615
for t ∈ 0 : T- 1
600
616
replace_ind_inoffset! (memoff, op, tind, t)
601
617
for u ∈ 0 : U- 1
602
618
memoff2 = copy (memoff)
603
619
u > 0 && push! (memoff2, W* u)
604
- push! (q. args, Expr (:call , :vstore! , ptr, Symbol (var, :_ , u, :_ , t), memoff2))
620
+ instrcall = Expr (:call , :vstore! , ptr, Symbol (var, :_ , u, :_ , t), memoff2)
621
+ mask === nothing || push! (instrcall. args, mask)
622
+ push! (q. args, instrcall)
605
623
end
606
624
end
607
625
end
@@ -614,69 +632,108 @@ function lower_store!(
614
632
for u ∈ 0 : U- 1
615
633
memoff2 = copy (memoff)
616
634
u > 0 && push! (memoff2, ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
617
- push! (q. args, Expr (:call , :scatter! , ptr, Symbol (var,:_ ,u,:_ ,t), Expr (:call , :vadd , memoff2, ustrides)))
635
+ instrcall = Expr (:call , :scatter! , ptr, Symbol (var,:_ ,u,:_ ,t), Expr (:call , :vadd , memoff2, ustrides))
636
+ mask === nothing || push! (instrcall. args, mask)
637
+ push! (q. args, instrcall)
618
638
end
619
639
end
620
640
# elseif unitstride(op, tiled) # TODO : we load tiled, and then shuffle
621
641
elseif U == 1 # we gather, no tile, no extra unroll
622
- push! (q. args, Expr (:call ,:scatter! ,ptr, var, Expr (:call ,:vadd ,memoff,ustrides)))
642
+ instrcall = Expr (:call ,:scatter! ,ptr, var, Expr (:call ,:vadd ,memoff,ustrides))
643
+ mask === nothing || push! (instrcall. args, mask)
644
+ push! (q. args, instrcall)
623
645
else # we gather, no tile, but extra unroll
624
646
for u ∈ 0 : U- 1
625
647
memoff2 = u == 0 ? memoff : push! (copy (memoff), ustride > 1 ? u* W* ustride : Expr (:call ,:* ,op. symbolic_metadata[upos],u* W) )
626
- push! (q. args, Expr (:call , :scatter! , ptr, Symbol (var,:_ ,u), Expr (:call ,:vadd ,memoff2,ustrides)))
648
+ instrcall = Expr (:call , :scatter! , ptr, Symbol (var,:_ ,u), Expr (:call ,:vadd ,memoff2,ustrides))
649
+ mask === nothing || push! (instrcall. args, mask)
650
+ push! (q. args, instrcall)
627
651
end
628
652
end
629
653
end
630
- elseif T != - 1 && tiled ∈ loopdeps # load for each tile.
631
- # load per T.
632
- # memoff2 = copy(memoff)
633
- for t ∈ 0 : T- 1
634
- replace_ind_inoffset! (memoff, op, tind, t)
635
- push! (q. args, Expr (:call , :store! , ptr, Symbol (var,:_ ,t), copy (memoff)))
654
+ else
655
+ # need to find out reduction type
656
+ reduct = CORRESPONDING_REDUCTION[first (parents (op)). instruction]
657
+ if T != - 1 && tiled ∈ loopdeps # no unroll, but tiled
658
+ # store per T.
659
+ # memoff2 = copy(memoff)
660
+ for t ∈ 0 : T- 1
661
+ replace_ind_inoffset! (memoff, op, tind, t)
662
+ push! (q. args, Expr (:call , :store! , ptr, Symbol (var,:_ ,t), copy (memoff)))
663
+ end
664
+ else # no unroll
665
+ push! (q. args, Expr (:call , :store! , var, ptr, memoff))
636
666
end
637
- else # load scalar; promotion should broadcast as/when neccesary
638
- push! (q. args, Expr (:call , :store! , var, ptr, memoff))
639
667
end
640
668
end
641
669
# A compute op needs to know the unrolling and tiling status of each of its parents.
642
670
#
643
671
"""
    lower_compute!(q, op, W, unrolled, U, T = -1, tiled = <undefined sentinel>, mask = nothing)

Append to `q.args` the assignment expressions that evaluate the compute
operation `op`.

One assignment `var = instr(parents...)` is emitted per (tile, unroll) pair.
When `op` depends on the `unrolled` loop the assigned symbol gets a `_u`
suffix (`u ∈ 0:U-1`); when it depends on the `tiled` loop it gets a `_t`
suffix (`t ∈ 0:T-1`). Each parent symbol receives the same suffixes, but only
if that parent itself depends on the corresponding loop.

# Arguments
- `q::Expr`: expression whose `args` receive the generated assignments.
- `op::Operation`: the compute operation being lowered.
- `W::Int`: accepted for signature parity with `lower_load!`/`lower_store!`;
  not used here.
- `unrolled::Symbol`, `U::Int`: unrolled loop symbol and unroll factor.
- `T::Int`, `tiled::Symbol`: tile count and tiled loop symbol. The defaults are
  sentinels meaning "no tiling".
- `mask`: if not `nothing` and `op` updates its own variable (one of its
  parents shares `op.variable`), the call is wrapped as
  `vifelse(mask, var, call)`.

Returns `nothing`.
"""
function lower_compute!(
    q::Expr, op::Operation, W::Int, unrolled::Symbol,
    U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##"), mask = nothing
)
    opunrolled = unrolled ∈ loopdependencies(op)
    optiled = tiled ∈ loopdependencies(op)
    var = op.variable
    instr = op.instruction

    # Cache the unroll and tiling check of each parent.
    # Not broadcasted, because we make frequent checks of individual bools,
    # which makes BitArrays inefficient.
    parents_op = parents(op)
    nparents = length(parents_op)
    parentsunrolled = if opunrolled
        [unrolled ∈ loopdependencies(opp) for opp ∈ parents_op]
    else
        fill(false, nparents)
    end
    parentstiled = if optiled
        [tiled ∈ loopdependencies(opp) for opp ∈ parents_op]
    else
        fill(false, nparents)
    end

    # A single iteration (index 0) is used along any axis `op` does not depend on.
    Uiter = opunrolled ? U - 1 : 0
    Titer = optiled ? T - 1 : 0
    # True when a mask was supplied and `op` is an update of its own variable,
    # e.g. `a = a + b`.
    maskreduct = mask !== nothing && any(opp -> opp.variable === var, parents_op)

    # If a parent is not unrolled, the compiler should handle broadcasting/CSE.
    # Because unrolled/tiled parents result in an unrolled/tiled dependency,
    # both the tiled and untiled cases are handled by the same loop nest.
    # The branches go the same way on every iteration; they are cheap compared
    # to everything else going on, so the smaller function is preferred.
    for t ∈ 0:Titer
        for u ∈ 0:Uiter
            instrcall = Expr(:call, instr)
            for n ∈ 1:nparents
                parent = parents_op[n].variable
                if parentsunrolled[n]
                    parent = Symbol(parent, :_, u)
                end
                if parentstiled[n]
                    parent = Symbol(parent, :_, t)
                end
                push!(instrcall.args, parent)
            end
            varsym = var
            if opunrolled
                varsym = Symbol(varsym, :_, u)
            end
            if optiled
                varsym = Symbol(varsym, :_, t)
            end
            if maskreduct
                # NOTE(review): argument order (mask, old value, new value) is kept
                # from the original call; confirm it matches `vifelse`'s select
                # semantics for masked-off lanes.
                push!(q.args, Expr(:(=), varsym, Expr(:call, :vifelse, mask, varsym, instrcall)))
            else
                push!(q.args, Expr(:(=), varsym, instrcall))
            end
        end
    end
    return nothing
end
673
"""
    lower!(q, op, W, unrolled, U, T = -1, tiled = <undefined sentinel>, mask = nothing)

Lower the operation `op` into `q` by forwarding every argument unchanged to
`lower_load!` when `isload(op)`, to `lower_store!` when `isstore(op)`, and to
`lower_compute!` otherwise. Returns whatever the selected function returns.

`W` is a required positional argument; there is no method for the shorter
`lower!(q, op, unrolled, U)` call form.
"""
function lower!(
    q::Expr, op::Operation, W::Int, unrolled::Symbol,
    U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##"),
    mask = nothing
)
    if isload(op)
        lower_load!(q, op, W, unrolled, U, T, tiled, mask)
    elseif isstore(op)
        lower_store!(q, op, W, unrolled, U, T, tiled, mask)
    else
        lower_compute!(q, op, W, unrolled, U, T, tiled, mask)
    end
end
682
739
@@ -690,13 +747,12 @@ function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
690
747
end
691
748
function lower_unroll_inner_block (ls:: LoopSet , order:: Vector{Symbol} , U:: Int )
692
749
# this function create the inner block
693
- args = Any[]
750
+ # args = Any[]
694
751
nloops = length (order)
695
752
unrolled = first (order)
696
753
# included_syms = Set( (unrolled,) )
697
754
included_vars = fill (false , length (operations (ls)))
698
755
# to go inside out, we just have to include all those not-yet included depending on the current sym
699
-
700
756
n = 0
701
757
loopsym = last (order)
702
758
blockq = Expr (:block , )
@@ -711,7 +767,7 @@ function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
711
767
for n ∈ 1 : nloops - 2
712
768
loopsym = order[nloops - n]
713
769
blockq = Expr (:block , )
714
- loopq = Expr (:for , Expr (:( = ), itersym, looprange), blockq )
770
+ postloop = Expr (:block , )
715
771
for (id,op) ∈ enumerate (operations (ls))
716
772
included_vars[id] && continue
717
773
# We add an op the first time all loop dependencies are met
@@ -720,10 +776,13 @@ function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
720
776
included_vars[id] = true
721
777
722
778
after_loop = depends_on_assigned (op, included_vars)
723
-
724
-
779
+ after_loop || lower! (blockq, op, unrolled, U)
780
+ after_loop && lower! (postloop, op, unrolled, U)
725
781
end
782
+ push! (blockq. args, loopq_old); append! (blockq. args, postloop. args)
783
+ loopq = Expr (:for , Expr (:(= ), itersym, looprange), blockq)
726
784
end
785
+ loopq
727
786
end
728
787
function lower_unroll_static (ls:: LoopSet , order:: Vector{Symbol} , U:: Int )
729
788
801
860
802
861
803
862
804
- using BenchmarkTools, LoopVectorization, SLEEF
805
- θ = randn (1000 ); c = randn (1000 );
806
- function sumsc_vectorized (θ:: AbstractArray{Float64} , coef:: AbstractArray{Float64} )
807
- s, c = 0.0 , 0.0
808
- @vvectorize for i ∈ eachindex (θ, coef)
809
- sinθᵢ, cosθᵢ = sincos (θ[i])
810
- s += coef[i] * sinθᵢ
811
- c += coef[i] * cosθᵢ
812
- end
813
- s, c
814
- end
815
- function sumsc_serial (θ:: AbstractArray{Float64} , coef:: AbstractArray{Float64} )
816
- s, c = 0.0 , 0.0
817
- @inbounds for i ∈ eachindex (θ, coef)
818
- sinθᵢ, cosθᵢ = sincos (θ[i])
819
- s += coef[i] * sinθᵢ
820
- c += coef[i] * cosθᵢ
821
- end
822
- s, c
823
- end
824
- function sumsc_sleef (θ:: AbstractArray{Float64} , coef:: AbstractArray{Float64} )
825
- s, c = 0.0 , 0.0
826
- @inbounds @simd for i ∈ eachindex (θ, coef)
827
- sinθᵢ, cosθᵢ = SLEEF. sincos_fast (θ[i])
828
- s += coef[i] * sinθᵢ
829
- c += coef[i] * cosθᵢ
830
- end
831
- s, c
832
- end
833
-
834
- @btime sumsc_serial ($ θ, $ c)
835
- @btime sumsc_sleef ($ θ, $ c)
836
- @btime sumsc_vectorized ($ θ, $ c)
837
-
838
-
0 commit comments