Skip to content

Commit 5d09179

Browse files
committed
Some progress, wrote lowering for load, working on lowering for store (currently a copy paste of the load).
1 parent ba1930b commit 5d09179

File tree

2 files changed

+179
-21
lines changed

2 files changed

+179
-21
lines changed

src/graphs.jl

Lines changed: 177 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ isdense(::Type{<:DenseArray}) = true
2626

2727

2828

29-
@enum NodeType begin
29+
@enum OperationType begin
3030
memload
3131
memstore
3232
compute_new
@@ -37,35 +37,42 @@ end
3737
# const ID = Threads.Atomic{UInt}(0)
3838

3939
"""
40-
if node_type == memstore || node_type == compute_new || node_type == compute_store
40+
if ooperation_type == memstore || operation_type == memstore# || operation_type == compute_new || operation_type == compute_update
4141
symbolic metadata contains info on direct dependencies / placement within loop.
4242
43-
43+
if accesses_memory(op)
44+
Symbol(:vptr_, op.variable)
45+
is how we access the memory.
46+
If numerical_metadata[i] == -1
47+
Symbol(:stride_, op.variable, :_, op.symbolic_metadata[i])
48+
is the stride for loop index
49+
symbolic_metadata[i]
4450
"""
4551
struct Operation
4652
identifier::UInt
4753
variable::Symbol
4854
elementbytes::Int
4955
instruction::Symbol
50-
node_type::NodeType
56+
node_type::OperationType
5157
# dependencies::Vector{Symbol}
5258
dependencies::Set{Symbol}
5359
# dependencies::Set{Symbol}
5460
parents::Vector{Operation}
5561
children::Vector{Operation}
56-
numerical_metadata::Vector{Int}
62+
numerical_metadata::Vector{Int} # stride of -1 indicates dynamic
5763
symbolic_metadata::Vector{Symbol}
64+
# strides::Dict{Symbol,Union{Symbol,Int}}
5865
function Operation(
66+
identifier,
5967
elementbytes,
6068
instruction,
6169
node_type,
62-
identifier,
6370
variable = gensym()
6471
)
6572
# identifier = Threads.atomic_add!(ID, one(UInt))
6673
new(
6774
identifier, variable, elementbytes, instruction, node_type,
68-
Set{Symbol}(), Operation[], Operation[], Int[], Symbol[]
75+
Set{Symbol}(), Operation[], Operation[], Int[], Symbol[]#, Dict{Symbol,Union{Symbol,Int}}()
6976
)
7077
end
7178
end
@@ -85,16 +92,45 @@ identifier(op::Operation) = op.identifier
8592
name(op::Operation) = op.variable
8693
instruction(op::Operation) = op.instruction
8794

95+
function symposition(op::Operation, sym::Symbol)
96+
findfirst(s -> s === sym, op.symbolic_metadata)
97+
end
8898
function stride(op::Operation, sym::Symbol)
8999
@assert accesses_memory(op) "This operation does not access memory!"
90100
# access stride info?
91-
op.numerical_metadata[findfirst(s -> s === sym, op.symbolic_metadata)]
101+
op.numerical_metadata[symposition(op,sym)]
92102
end
93103
# function
94104
function unitstride(op::Operation, sym::Symbol)
95105
(first(op.symbolic_metadata) === sym) && (first(op.numerical_metadata) == 1)
96106
end
97-
107+
function mem_offset(op::Operation, incr::Int = 0)::Union{Symbol,Expr}
108+
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
109+
@unpack numerical_metadata, symbolic_metadata = op
110+
if incr == 0 && length(numerical_metadata) == 1
111+
firstsym = first(symbolic_metadata)
112+
if first(numerical_metadata) == 1
113+
return firstsym
114+
elseif first(numerical_metadata) == -1
115+
return Expr(:call, :*, Symbol(:stride_, op.variable, :_, firstsym), firstsym)
116+
else
117+
return Expr(:call, :*, first(numerical_metadata), firstsym)
118+
end
119+
end
120+
ret = Expr(:call, :+, )
121+
for i eachindex(numerical_metadata)
122+
sym = symbolic_metadata[i]; num = numerical_metadata[i]
123+
if num == 1
124+
push!(ret.args, sym)
125+
elseif num == -1
126+
push!(ret.args, Expr(:call, :*, Symbol(:stride_, op.variable, :_, firstsym), sym))
127+
else
128+
push!(ret.args, Expr(:call, :*, num, sym))
129+
end
130+
end
131+
incr == 0 || push!(ret.args, incr)
132+
ret
133+
end
98134

99135
struct Loop
100136
itersymbol::Symbol
@@ -457,25 +493,147 @@ function depends_on_assigned(op::Operation, assigned::Vector{Bool})
457493
end
458494
false
459495
end
460-
function lower_load!(q::Expr, op::Operation, unrolled::Symbol, U, Umax, T = nothing, Tmax = nothing)
496+
function replace_ind_in_offset!(offset::Vector, op::Operation, ind::Int, dynamic::Bool, t)
497+
t == 0 && return nothing
498+
var = op.variable
499+
siter = op.symbolic_metadata[ind]
500+
striden = op.numerical_metadata[ind]
501+
strides = Symbol(:stride_, var)
502+
offset[ind] = if tstriden == -1
503+
Expr(:call, :*, Expr(:call, :+, strides, t), siter)
504+
else
505+
Expr(:call, :*, striden + t, siter)
506+
end
507+
nothing
508+
end
509+
510+
# TODO: this code should be rewritten to be more "orthogonal", so that we're just combining separate pieces.
511+
# Using sentinel values (eg, T = -1 for non tiling) in part to avoid recompilation.
512+
function lower_load!(
513+
q::Expr, op::Operation, W::Int, unrolled::Symbol,
514+
U::Int, T::Int = -1, tiled::Symbol = :undef
515+
)
461516
loopdeps = loopdependencies(op)
517+
var = op.variable
518+
ptr = Symbol(:vptr_, var)
519+
memoff = mem_offset(op)
520+
tind = T == -1 ? -1 : findfirst(s -> s === tiled, op.symbolic_metadata)
521+
upos = symposition(op, unrolled)
522+
ustride = op.numerical_metadata[upos]
462523
if unrolled loopdeps # we need a vector
463-
if unitstride(op, unrolled) # vload
464-
465-
else # gather
466-
524+
if ustride == 1 # vload
525+
if T == -1 && U == 1
526+
push!(q.args, Expr(:(=), var, Expr(:call,:vload,ptr,memoff)))
527+
elseif T == -1
528+
for u 0:U-1
529+
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call,:vload, Val(W), ptr, u == 0 ? memoff : push!(copy(memoff), W*u))))
530+
end
531+
else # tiling
532+
for t 0:T-1
533+
replace_ind_inoffset!(memoff, op, tind, t)
534+
for u 0:U-1
535+
memoff2 = copy(memoff)
536+
u > 0 && push!(memoff2, W*u)
537+
push!(q.args, Expr(:(=), Symbol(var, :_, u, :_, t), Expr(:call, :vload, Val(W), ptr, memoff2)))
538+
end
539+
end
540+
end
541+
else
542+
# ustep = ustride > 1 ? ustride : op.symbolic_metadata[upos]
543+
ustrides = Expr(:tuple, (ustride > 1 ? [Core.VecElement{Int}(ustride*w) for w 0:W-1] : [:(Core.VecElement{Int}($(op.symbolic_metadata[upos])*$w)) for w 0:W-1])...)
544+
if T != -1 # gather tile
545+
for t 0:T-1
546+
replace_ind_inoffset!(memoff, op, tind, t)
547+
for u 0:U-1
548+
memoff2 = copy(memoff)
549+
u > 0 && push!(memoff2, ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
550+
push!(q.args, Expr(:(=), Symbol(var,:_,u,:_,t), Expr(:call, :gather, ptr, Expr(:call, :vadd, memoff2, ustrides))))
551+
end
552+
end
553+
# elseif unitstride(op, tiled) # TODO: we load tiled, and then shuffle
554+
elseif U == 1 # we gather, no tile, no extra unroll
555+
push!(q.args, Expr(:(=), var, Expr(:call,:gather,ptr,Expr(:call,:vadd,memoff,ustrides))))
556+
else # we gather, no tile, but extra unroll
557+
for u 0:U-1
558+
memoff2 = u == 0 ? memoff : push!(copy(memoff), ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
559+
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call, :gather, ptr, Expr(:call,:vadd,memoff2,ustrides))))
560+
end
561+
end
562+
end
563+
elseif T != -1 && tiled loopdeps # load for each tile.
564+
# load per T.
565+
# memoff2 = copy(memoff)
566+
for t 0:T-1
567+
replace_ind_inoffset!(memoff, op, tind, t)
568+
push!(q.args, Expr(:(=), Symbol(var,:_,t), Expr(:call, :load, ptr, copy(memoff))))
467569
end
468570
else # load scalar; promotion should broadcast as/when neccesary
469-
Expr(:call, :(VectorizationBase.load), )
571+
push!(q.args, Expr(:(=), var, Expr(:call, :load, ptr, memoff)))
470572
end
471573
end
472574
function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
473-
575+
q::Expr, op::Operation, W::Int, unrolled::Symbol,
576+
U::Int, T::Int = -1, tiled::Symbol = :undef
577+
)
578+
loopdeps = loopdependencies(op)
579+
var = first(parents(op)).variable
580+
ptr = Symbol(:vptr_, op.variable)
581+
memoff = mem_offset(op)
582+
tind = T == -1 ? -1 : findfirst(s -> s === tiled, op.symbolic_metadata)
583+
upos = symposition(op, unrolled)
584+
ustride = op.numerical_metadata[upos]
585+
if unrolled loopdeps # we need a vector
586+
if ustride == 1 # vload
587+
if T == -1 && U == 1
588+
push!(q.args, Expr(:(=), var, Expr(:call,:vload,ptr,memoff)))
589+
elseif T == -1
590+
for u 0:U-1
591+
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call,:vstore, Val(W), ptr, u == 0 ? memoff : push!(copy(memoff), W*u))))
592+
end
593+
else # tiling
594+
for t 0:T-1
595+
replace_ind_inoffset!(memoff, op, tind, t)
596+
for u 0:U-1
597+
memoff2 = copy(memoff)
598+
u > 0 && push!(memoff2, W*u)
599+
push!(q.args, Expr(:(=), Symbol(var, :_, u, :_, t), Expr(:call, :vload, Val(W), ptr, memoff2)))
600+
end
601+
end
602+
end
603+
else
604+
# ustep = ustride > 1 ? ustride : op.symbolic_metadata[upos]
605+
ustrides = Expr(:tuple, (ustride > 1 ? [Core.VecElement{Int}(ustride*w) for w 0:W-1] : [:(Core.VecElement{Int}($(op.symbolic_metadata[upos])*$w)) for w 0:W-1])...)
606+
if T != -1 # gather tile
607+
for t 0:T-1
608+
replace_ind_inoffset!(memoff, op, tind, t)
609+
for u 0:U-1
610+
memoff2 = copy(memoff)
611+
u > 0 && push!(memoff2, ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
612+
push!(q.args, Expr(:(=), Symbol(var,:_,u,:_,t), Expr(:call, :gather, ptr, Expr(:call, :vadd, memoff2, ustrides))))
613+
end
614+
end
615+
# elseif unitstride(op, tiled) # TODO: we load tiled, and then shuffle
616+
elseif U == 1 # we gather, no tile, no extra unroll
617+
push!(q.args, Expr(:(=), var, Expr(:call,:gather,ptr,Expr(:call,:vadd,memoff,ustrides))))
618+
else # we gather, no tile, but extra unroll
619+
for u 0:U-1
620+
memoff2 = u == 0 ? memoff : push!(copy(memoff), ustride > 1 ? u*W*ustride : Expr(:call,:*,op.symbolic_metadata[upos],u*W) )
621+
push!(q.args, Expr(:(=), Symbol(var,:_,u), Expr(:call, :gather, ptr, Expr(:call,:vadd,memoff2,ustrides))))
622+
end
623+
end
624+
end
625+
elseif T != -1 && tiled loopdeps # load for each tile.
626+
# load per T.
627+
# memoff2 = copy(memoff)
628+
for t 0:T-1
629+
replace_ind_inoffset!(memoff, op, tind, t)
630+
push!(q.args, Expr(:(=), Symbol(var,:_,t), Expr(:call, :load, ptr, copy(memoff))))
631+
end
632+
else # load scalar; promotion should broadcast as/when neccesary
633+
push!(q.args, Expr(:(=), var, Expr(:call, :load, ptr, memoff)))
634+
end
474635
end
475636
function lower_compute!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
476-
for t T, u U
477-
478-
end
479637
end
480638
function lower!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
481639
if isload(op)

src/precompile.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ function _precompile_()
22
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
33
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Int64,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool,Module})
44
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Int64,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool})
5-
precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Module}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
5+
# precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Module}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
66
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Symbol,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool})
7-
precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Symbol}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
7+
# precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Symbol}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
88
precompile(Tuple{typeof(LoopVectorization.add_masks),Expr,Symbol,Dict{Tuple{Symbol,Symbol},Symbol},Module})
99
precompile(Tuple{typeof(LoopVectorization.add_masks),Expr,Symbol,Dict{Tuple{Symbol,Symbol},Symbol},Symbol})
1010
end

0 commit comments

Comments
 (0)