Skip to content

Commit dc0252b

Browse files
committed
Minor progress on lowering stores.
1 parent a3e2a86 commit dc0252b

File tree

2 files changed

+120
-82
lines changed

2 files changed

+120
-82
lines changed

src/costs.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,21 @@ for (k, v) ∈ COST # so we can look up Symbol(typeof(function))
9696
COST[Symbol("typeof(", k, ")")] = v
9797
end
9898

99-
99+
const CORRESPONDING_REDUCTION = Dict{Symbol,Symbol}(
100+
:(+) => :vsum,
101+
:(-) => :vsum,
102+
:(*) => :vprod,
103+
:(&) => :vall,
104+
:(|) => :vany,
105+
:muladd => :vsum,
106+
:fma => :vsum,
107+
:vmuladd => :vsum,
108+
:vfma => :vsum,
109+
:vfmadd => :vsum,
110+
:vfmsub => :vsum,
111+
:vfnmadd => :vsum,
112+
:vfnmsub => :vsum
113+
)
100114
# const SIMDPIRATES_COST = Dict{Symbol,InstructionCost}()
101115
# const SLEEFPIRATES_COST = Dict{Symbol,InstructionCost}()
102116

src/graphs.jl

Lines changed: 105 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ end
493493

494494
function depends_on_assigned(op::Operation, assigned::Vector{Bool})
495495
for p parents(op)
496+
p === op && continue # don't fall into recursive loop when we have updates, eg a = a + b
496497
assigned[identifier(op)] && return true
497498
depends_on_assigned(p, assigned) && return true
498499
end
@@ -516,7 +517,7 @@ end
516517
# Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
517518
function lower_load!(
518519
q::Expr, op::Operation, W::Int, unrolled::Symbol,
519-
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##")
520+
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##"), mask = nothing
520521
)
521522
loopdeps = loopdependencies(op)
522523
var = op.variable
@@ -531,15 +532,19 @@ function lower_load!(
531532
push!(q.args, Expr(:(=), var, Expr(:call,:vload,ptr,memoff)))
532533
elseif T == -1
533534
for u 0:U-1
534-
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call,:vload, Val{W}(), ptr, u == 0 ? memoff : push!(copy(memoff), W*u))))
535+
instrcall = Expr(:call,:vload, Val{W}(), ptr, u == 0 ? memoff : push!(copy(memoff), W*u))
536+
mask === nothing || push!(instrcall.args, mask)
537+
push!(q.args, Expr(:(=), Symbol(var,:_,u), instrcall))
535538
end
536539
else # tiling
537540
for t 0:T-1
538541
replace_ind_inoffset!(memoff, op, tind, t)
539542
for u 0:U-1
540543
memoff2 = copy(memoff)
541544
u > 0 && push!(memoff2, W*u)
542-
push!(q.args, Expr(:(=), Symbol(var, :_, u, :_, t), Expr(:call, :vload, Val{W}(), ptr, memoff2)))
545+
instrcall = Expr(:call, :vload, Val{W}(), ptr, memoff2)
546+
mask === nothing || push!(instrcall.args, mask)
547+
push!(q.args, Expr(:(=), Symbol(var, :_, u, :_, t), instrcall))
543548
end
544549
end
545550
end
@@ -552,16 +557,22 @@ function lower_load!(
552557
for u 0:U-1
553558
memoff2 = copy(memoff)
554559
u > 0 && push!(memoff2, ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
555-
push!(q.args, Expr(:(=), Symbol(var,:_,u,:_,t), Expr(:call, :gather, ptr, Expr(:call, :vadd, memoff2, ustrides))))
560+
instrcall = Expr(:call, :gather, ptr, Expr(:call, :vadd, memoff2, ustrides))
561+
mask === nothing || push!(instrcall.args, mask)
562+
push!(q.args, Expr(:(=), Symbol(var,:_,u,:_,t), instrcall))
556563
end
557564
end
558565
# elseif unitstride(op, tiled) # TODO: we load tiled, and then shuffle
559566
elseif U == 1 # we gather, no tile, no extra unroll
560-
push!(q.args, Expr(:(=), var, Expr(:call,:gather,ptr,Expr(:call,:vadd,memoff,ustrides))))
567+
instrcall = Expr(:call,:gather,ptr,Expr(:call,:vadd,memoff,ustrides))
568+
mask === nothing || push!(instrcall.args, mask)
569+
push!(q.args, Expr(:(=), var, instrcall))
561570
else # we gather, no tile, but extra unroll
562571
for u 0:U-1
563572
memoff2 = u == 0 ? memoff : push!(copy(memoff), ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
564-
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call, :gather, ptr, Expr(:call,:vadd,memoff2,ustrides))))
573+
instrcall = Expr(:call, :gather, ptr, Expr(:call,:vadd,memoff2,ustrides))
574+
mask === nothing || push!(instrcall.args, mask)
575+
push!(q.args, Expr(:(=), Symbol(var,:_,u), instrcall))
565576
end
566577
end
567578
end
@@ -570,15 +581,18 @@ function lower_load!(
570581
# memoff2 = copy(memoff)
571582
for t 0:T-1
572583
replace_ind_inoffset!(memoff, op, tind, t)
573-
push!(q.args, Expr(:(=), Symbol(var,:_,t), Expr(:call, :load, ptr, copy(memoff))))
584+
instrcall = Expr(:call, :load, ptr, copy(memoff))
585+
# mask === nothing || push!(instrcall.args, mask)
586+
push!(q.args, Expr(:(=), Symbol(var,:_,t), instrcall))
574587
end
575588
else # load scalar; promotion should broadcast as/when neccesary
576589
push!(q.args, Expr(:(=), var, Expr(:call, :load, ptr, memoff)))
577590
end
578591
end
592+
# TODO: handle reductions correctly when we're storing non-unrolled parameters!
579593
function lower_store!(
580594
q::Expr, op::Operation, W::Int, unrolled::Symbol,
581-
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##")
595+
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##"), mask = nothing
582596
)
583597
loopdeps = loopdependencies(op)
584598
var = first(parents(op)).variable
@@ -593,15 +607,19 @@ function lower_store!(
593607
push!(q.args, Expr(:(=), var, Expr(:call,:vload,ptr,memoff)))
594608
elseif T == -1
595609
for u 0:U-1
596-
push!(q.args, Expr(:call,:vstore!, ptr, Symbol(var,:_,u), u == 0 ? memoff : push!(copy(memoff), W*u)))
610+
instrcall = Expr(:call,:vstore!, ptr, Symbol(var,:_,u), u == 0 ? memoff : push!(copy(memoff), W*u))
611+
mask === nothing || push!(instrcall.args, mask)
612+
push!(q.args, instrcall)
597613
end
598614
else # tiling
599615
for t 0:T-1
600616
replace_ind_inoffset!(memoff, op, tind, t)
601617
for u 0:U-1
602618
memoff2 = copy(memoff)
603619
u > 0 && push!(memoff2, W*u)
604-
push!(q.args, Expr(:call, :vstore!, ptr, Symbol(var, :_, u, :_, t), memoff2))
620+
instrcall = Expr(:call, :vstore!, ptr, Symbol(var, :_, u, :_, t), memoff2)
621+
mask === nothing || push!(instrcall.args, mask)
622+
push!(q.args, instrcall)
605623
end
606624
end
607625
end
@@ -614,69 +632,108 @@ function lower_store!(
614632
for u 0:U-1
615633
memoff2 = copy(memoff)
616634
u > 0 && push!(memoff2, ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
617-
push!(q.args, Expr(:call, :scatter!, ptr, Symbol(var,:_,u,:_,t), Expr(:call, :vadd, memoff2, ustrides)))
635+
instrcall = Expr(:call, :scatter!, ptr, Symbol(var,:_,u,:_,t), Expr(:call, :vadd, memoff2, ustrides))
636+
mask === nothing || push!(instrcall.args, mask)
637+
push!(q.args, instrcall)
618638
end
619639
end
620640
# elseif unitstride(op, tiled) # TODO: we load tiled, and then shuffle
621641
elseif U == 1 # we gather, no tile, no extra unroll
622-
push!(q.args, Expr(:call,:scatter!,ptr, var, Expr(:call,:vadd,memoff,ustrides)))
642+
instrcall = Expr(:call,:scatter!,ptr, var, Expr(:call,:vadd,memoff,ustrides))
643+
mask === nothing || push!(instrcall.args, mask)
644+
push!(q.args, instrcall)
623645
else # we gather, no tile, but extra unroll
624646
for u 0:U-1
625647
memoff2 = u == 0 ? memoff : push!(copy(memoff), ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
626-
push!(q.args, Expr(:call, :scatter!, ptr, Symbol(var,:_,u), Expr(:call,:vadd,memoff2,ustrides)))
648+
instrcall = Expr(:call, :scatter!, ptr, Symbol(var,:_,u), Expr(:call,:vadd,memoff2,ustrides))
649+
mask === nothing || push!(instrcall.args, mask)
650+
push!(q.args, instrcall)
627651
end
628652
end
629653
end
630-
elseif T != -1 && tiled loopdeps # load for each tile.
631-
# load per T.
632-
# memoff2 = copy(memoff)
633-
for t 0:T-1
634-
replace_ind_inoffset!(memoff, op, tind, t)
635-
push!(q.args, Expr(:call, :store!, ptr, Symbol(var,:_,t), copy(memoff)))
654+
else
655+
# need to find out reduction type
656+
reduct = CORRESPONDING_REDUCTION[first(parents(op)).instruction]
657+
if T != -1 && tiled loopdeps # no unroll, but tiled
658+
# store per T.
659+
# memoff2 = copy(memoff)
660+
for t 0:T-1
661+
replace_ind_inoffset!(memoff, op, tind, t)
662+
push!(q.args, Expr(:call, :store!, ptr, Symbol(var,:_,t), copy(memoff)))
663+
end
664+
else # no unroll
665+
push!(q.args, Expr(:call, :store!, var, ptr, memoff))
636666
end
637-
else # load scalar; promotion should broadcast as/when neccesary
638-
push!(q.args, Expr(:call, :store!, var, ptr, memoff))
639667
end
640668
end
641669
# A compute op needs to know the unrolling and tiling status of each of its parents.
642670
#
643671
function lower_compute!(
644672
q::Expr, op::Operation, W::Int, unrolled::Symbol,
645-
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##")
673+
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##"), mask = nothing
646674
)
647675
opunrolled = unrolled loopdependencies(op)
648676
optiled = tiled loopdependencies(op)
649677
var = op.variable
678+
instr = op.instruction
679+
650680
# cache unroll and tiling check of parents
651681
# not broadcasted, because we use frequent checks of individual bools
652682
# making BitArrays inefficient.
653-
parentsunrolled = [unrolled loopdependencies(opp) for oppp parents(op)]
654-
parentstiled = [tiled loopdependencies(opp) for oppp parents(op)]
655-
if opunrolled
656-
if optiled
657-
for t 0:T-1
658-
for u 0:U-1
659-
683+
parents_op = parents(op)
684+
nparents = length(parents_op)
685+
parentsunrolled = opunrolled ? [unrolled loopdependencies(opp) for opp parents_op] : fill(false, nparents)
686+
parentstiled = optiled ? [tiled loopdependencies(opp) for opp parents_op] : fill(false, nparents)
687+
# parentsyms = [opp.variable for opp ∈ parents(op)]
688+
Uiter = opunrolled ? U-1 : o
689+
Titer = optiled ? T-1 : 0
690+
maskreduct = mask !== nothing && any(opp -> opp.variable === var, parents_op)
691+
# if a parent is not unrolled, the compiler should handle broadcasting CSE.
692+
# because unrolled/tiled parents result in an unrolled/tiled dependendency,
693+
# we handle both the tiled and untiled case here.
694+
# bajillion branches that go the same way on each iteration
695+
# but smaller function is probably worthwhile. Compiler could theoreically split anyway
696+
# but I suspect that the branches are so cheap compared to the cost of everything else going on
697+
# that smaller size is more advantageous.
698+
for t 0:Titer
699+
for u 0:U-1
700+
intrcall = Expr(:call, instr)
701+
for n 1:nparents
702+
parent = parents_op.variable
703+
if parentsunrolled[n]
704+
parent = Symbol(parent,:_,u)
660705
end
706+
if parentstiled[n]
707+
parent = Symbol(parent,:_,t)
708+
end
709+
push!(intrcall.args, parent)
710+
end
711+
varsym = var
712+
if opunrolled
713+
varsym = Symbol(varsym,:_,u)
714+
end
715+
if optiled
716+
varsym = Symbol(varsym,:_,t)
717+
end
718+
if maskreduct
719+
push!(q.args, Expr(:(=), varsym, Expr(:call, :vifesle, mask, varsym, instrcall)))
720+
else
721+
push!(q.args, Expr(:(=), varsym, instrcall))
661722
end
662-
else # not tiled
663-
664-
end
665-
else # not unrolled
666-
if optiled # but not unrolled
667-
668-
else # not tiled and not unrolled
669-
670723
end
671724
end
672725
end
673-
function lower!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
726+
function lower!(
727+
q::Expr, op::Operation, W::Int, unrolled::Symbol,
728+
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##"),
729+
mask = nothing
730+
)
674731
if isload(op)
675-
lower_load!(q, op, unrolled, U, T)
732+
lower_load!(q, op, W, unrolled, U, T, tiled, mask)
676733
elseif isstore(op)
677-
lower_store!(q, op, unrolled, U, T)
734+
lower_store!(q, op, W, unrolled, U, T, tiled, mask)
678735
else
679-
lower_compute!(q, op, unrolled, U, T)
736+
lower_compute!(q, op, W, unrolled, U, T, tiled, mask)
680737
end
681738
end
682739

@@ -690,13 +747,12 @@ function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
690747
end
691748
function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
692749
# this function create the inner block
693-
args = Any[]
750+
# args = Any[]
694751
nloops = length(order)
695752
unrolled = first(order)
696753
# included_syms = Set( (unrolled,) )
697754
included_vars = fill(false, length(operations(ls)))
698755
# to go inside out, we just have to include all those not-yet included depending on the current sym
699-
700756
n = 0
701757
loopsym = last(order)
702758
blockq = Expr(:block, )
@@ -711,7 +767,7 @@ function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
711767
for n 1:nloops - 2
712768
loopsym = order[nloops - n]
713769
blockq = Expr(:block, )
714-
loopq = Expr(:for, Expr(:(=), itersym, looprange), blockq)
770+
postloop = Expr(:block, )
715771
for (id,op) enumerate(operations(ls))
716772
included_vars[id] && continue
717773
# We add an op the first time all loop dependencies are met
@@ -720,10 +776,13 @@ function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
720776
included_vars[id] = true
721777

722778
after_loop = depends_on_assigned(op, included_vars)
723-
724-
779+
after_loop || lower!(blockq, op, unrolled, U)
780+
after_loop && lower!(postloop, op, unrolled, U)
725781
end
782+
push!(blockq.args, loopq_old); append!(blockq.args, postloop.args)
783+
loopq = Expr(:for, Expr(:(=), itersym, looprange), blockq)
726784
end
785+
loopq
727786
end
728787
function lower_unroll_static(ls::LoopSet, order::Vector{Symbol}, U::Int)
729788

@@ -801,38 +860,3 @@ end
801860

802861

803862

804-
using BenchmarkTools, LoopVectorization, SLEEF
805-
θ = randn(1000); c = randn(1000);
806-
function sumsc_vectorized::AbstractArray{Float64}, coef::AbstractArray{Float64})
807-
s, c = 0.0, 0.0
808-
@vvectorize for i eachindex(θ, coef)
809-
sinθᵢ, cosθᵢ = sincos(θ[i])
810-
s += coef[i] * sinθᵢ
811-
c += coef[i] * cosθᵢ
812-
end
813-
s, c
814-
end
815-
function sumsc_serial::AbstractArray{Float64}, coef::AbstractArray{Float64})
816-
s, c = 0.0, 0.0
817-
@inbounds for i eachindex(θ, coef)
818-
sinθᵢ, cosθᵢ = sincos(θ[i])
819-
s += coef[i] * sinθᵢ
820-
c += coef[i] * cosθᵢ
821-
end
822-
s, c
823-
end
824-
function sumsc_sleef::AbstractArray{Float64}, coef::AbstractArray{Float64})
825-
s, c = 0.0, 0.0
826-
@inbounds @simd for i eachindex(θ, coef)
827-
sinθᵢ, cosθᵢ = SLEEF.sincos_fast(θ[i])
828-
s += coef[i] * sinθᵢ
829-
c += coef[i] * cosθᵢ
830-
end
831-
s, c
832-
end
833-
834-
@btime sumsc_serial($θ, $c)
835-
@btime sumsc_sleef($θ, $c)
836-
@btime sumsc_vectorized($θ, $c)
837-
838-

0 commit comments

Comments
 (0)