Skip to content

Commit a3e2a86

Browse files
committed
Minor updates; working on lower_compute.
1 parent 5d09179 commit a3e2a86

File tree

1 file changed

+51
-16
lines changed

1 file changed

+51
-16
lines changed

src/graphs.jl

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ end
3636

3737
# const ID = Threads.Atomic{UInt}(0)
3838

39+
# TODO: can some computations be cached in the operations?
3940
"""
4041
if ooperation_type == memstore || operation_type == memstore# || operation_type == compute_new || operation_type == compute_update
4142
symbolic metadata contains info on direct dependencies / placement within loop.
@@ -77,6 +78,8 @@ struct Operation
7778
end
7879
end
7980

81+
82+
8083
function isreduction(op::Operation)
8184
(op.node_type == memstore) && (length(op.symbolic_metadata) < length(op.dependencies))# && issubset(op.symbolic_metadata, op.dependencies)
8285
end
@@ -92,6 +95,7 @@ identifier(op::Operation) = op.identifier
9295
name(op::Operation) = op.variable
9396
instruction(op::Operation) = op.instruction
9497

98+
9599
function symposition(op::Operation, sym::Symbol)
96100
findfirst(s -> s === sym, op.symbolic_metadata)
97101
end
@@ -139,7 +143,7 @@ struct Loop
139143
hintexact::Bool # if true, rangesym ignored and rangehint used for final lowering
140144
end
141145
function Loop(itersymbol::Symbol, rangehint::Int)
142-
Loop( itersymbol, rangehint, :undef, true )
146+
Loop( itersymbol, rangehint, Symbol("##UNDEFINED##"), true )
143147
end
144148
function Loop(itersymbol::Symbol, rangesym::Symbol, rangehint::Int = 1_024)
145149
Loop( itersymbol, rangehint, rangesym, false )
@@ -152,8 +156,9 @@ struct LoopSet
152156
loadops::Vector{Operation} # Split them to make it easier to iterate over just a subset
153157
computeops::Vector{Operation}
154158
storeops::Vector{Operation}
155-
reductions::Set{UInt} # IDs of reduction operations that need to be reduced at end.
156-
strideset::Vector{}
159+
inner_reductions::Set{UInt} # IDs of reduction operations nested within loops and stored.
160+
outer_reductions::Set{UInt} # IDs of reduction operations that need to be reduced at end.
161+
# strideset::Vector{}
157162
end
158163
num_loops(ls::LoopSet) = length(ls.loops)
159164
isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
@@ -511,7 +516,7 @@ end
511516
# Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
512517
function lower_load!(
513518
q::Expr, op::Operation, W::Int, unrolled::Symbol,
514-
U::Int, T::Int = -1, tiled::Symbol = :undef
519+
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##")
515520
)
516521
loopdeps = loopdependencies(op)
517522
var = op.variable
@@ -526,15 +531,15 @@ function lower_load!(
526531
push!(q.args, Expr(:(=), var, Expr(:call,:vload,ptr,memoff)))
527532
elseif T == -1
528533
for u 0:U-1
529-
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call,:vload, Val(W), ptr, u == 0 ? memoff : push!(copy(memoff), W*u))))
534+
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call,:vload, Val{W}(), ptr, u == 0 ? memoff : push!(copy(memoff), W*u))))
530535
end
531536
else # tiling
532537
for t 0:T-1
533538
replace_ind_inoffset!(memoff, op, tind, t)
534539
for u 0:U-1
535540
memoff2 = copy(memoff)
536541
u > 0 && push!(memoff2, W*u)
537-
push!(q.args, Expr(:(=), Symbol(var, :_, u, :_, t), Expr(:call, :vload, Val(W), ptr, memoff2)))
542+
push!(q.args, Expr(:(=), Symbol(var, :_, u, :_, t), Expr(:call, :vload, Val{W}(), ptr, memoff2)))
538543
end
539544
end
540545
end
@@ -571,9 +576,9 @@ function lower_load!(
571576
push!(q.args, Expr(:(=), var, Expr(:call, :load, ptr, memoff)))
572577
end
573578
end
574-
function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
579+
function lower_store!(
575580
q::Expr, op::Operation, W::Int, unrolled::Symbol,
576-
U::Int, T::Int = -1, tiled::Symbol = :undef
581+
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##")
577582
)
578583
loopdeps = loopdependencies(op)
579584
var = first(parents(op)).variable
@@ -588,15 +593,15 @@ function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
588593
push!(q.args, Expr(:(=), var, Expr(:call,:vload,ptr,memoff)))
589594
elseif T == -1
590595
for u 0:U-1
591-
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call,:vstore, Val(W), ptr, u == 0 ? memoff : push!(copy(memoff), W*u))))
596+
push!(q.args, Expr(:call,:vstore!, ptr, Symbol(var,:_,u), u == 0 ? memoff : push!(copy(memoff), W*u)))
592597
end
593598
else # tiling
594599
for t 0:T-1
595600
replace_ind_inoffset!(memoff, op, tind, t)
596601
for u 0:U-1
597602
memoff2 = copy(memoff)
598603
u > 0 && push!(memoff2, W*u)
599-
push!(q.args, Expr(:(=), Symbol(var, :_, u, :_, t), Expr(:call, :vload, Val(W), ptr, memoff2)))
604+
push!(q.args, Expr(:call, :vstore!, ptr, Symbol(var, :_, u, :_, t), memoff2))
600605
end
601606
end
602607
end
@@ -609,16 +614,16 @@ function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
609614
for u 0:U-1
610615
memoff2 = copy(memoff)
611616
u > 0 && push!(memoff2, ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
612-
push!(q.args, Expr(:(=), Symbol(var,:_,u,:_,t), Expr(:call, :gather, ptr, Expr(:call, :vadd, memoff2, ustrides))))
617+
push!(q.args, Expr(:call, :scatter!, ptr, Symbol(var,:_,u,:_,t), Expr(:call, :vadd, memoff2, ustrides)))
613618
end
614619
end
615620
# elseif unitstride(op, tiled) # TODO: we load tiled, and then shuffle
616621
elseif U == 1 # we gather, no tile, no extra unroll
617-
push!(q.args, Expr(:(=), var, Expr(:call,:gather,ptr,Expr(:call,:vadd,memoff,ustrides))))
622+
push!(q.args, Expr(:call,:scatter!,ptr, var, Expr(:call,:vadd,memoff,ustrides)))
618623
else # we gather, no tile, but extra unroll
619624
for u 0:U-1
620625
memoff2 = u == 0 ? memoff : push!(copy(memoff), ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
621-
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call, :gather, ptr, Expr(:call,:vadd,memoff2,ustrides))))
626+
push!(q.args, Expr(:call, :scatter!, ptr, Symbol(var,:_,u), Expr(:call,:vadd,memoff2,ustrides)))
622627
end
623628
end
624629
end
@@ -627,13 +632,43 @@ function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
627632
# memoff2 = copy(memoff)
628633
for t 0:T-1
629634
replace_ind_inoffset!(memoff, op, tind, t)
630-
push!(q.args, Expr(:(=), Symbol(var,:_,t), Expr(:call, :load, ptr, copy(memoff))))
635+
push!(q.args, Expr(:call, :store!, ptr, Symbol(var,:_,t), copy(memoff)))
631636
end
632637
else # load scalar; promotion should broadcast as/when neccesary
633-
push!(q.args, Expr(:(=), var, Expr(:call, :load, ptr, memoff)))
638+
push!(q.args, Expr(:call, :store!, var, ptr, memoff))
634639
end
635640
end
636-
function lower_compute!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
641+
# A compute op needs to know the unrolling and tiling status of each of its parents.
642+
#
643+
function lower_compute!(
644+
q::Expr, op::Operation, W::Int, unrolled::Symbol,
645+
U::Int, T::Int = -1, tiled::Symbol = Symbol("##UNDEFINED##")
646+
)
647+
opunrolled = unrolled loopdependencies(op)
648+
optiled = tiled loopdependencies(op)
649+
var = op.variable
650+
# cache unroll and tiling check of parents
651+
# not broadcasted, because we use frequent checks of individual bools
652+
# making BitArrays inefficient.
653+
parentsunrolled = [unrolled loopdependencies(opp) for oppp parents(op)]
654+
parentstiled = [tiled loopdependencies(opp) for oppp parents(op)]
655+
if opunrolled
656+
if optiled
657+
for t 0:T-1
658+
for u 0:U-1
659+
660+
end
661+
end
662+
else # not tiled
663+
664+
end
665+
else # not unrolled
666+
if optiled # but not unrolled
667+
668+
else # not tiled and not unrolled
669+
670+
end
671+
end
637672
end
638673
function lower!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
639674
if isload(op)

0 commit comments

Comments
 (0)