Skip to content

Commit ba1930b

Browse files
committed
Still working on lowering code.
1 parent a0f3448 commit ba1930b

File tree

3 files changed

+48
-13
lines changed

3 files changed

+48
-13
lines changed

src/LoopVectorization.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,8 +202,9 @@ end
202202
b = macroexpand(LoopVectorization, b)
203203
## body preamble must define indexed symbols
204204
## we only need that for loads.
205+
dicts = (indexed_expressions, reduction_symbols, loaded_exprs, loop_constants_dict)
205206
push!(main_body.args,
206-
_vectorloads!(main_body, q, indexed_expressions, reduction_symbols, loaded_exprs, V, W, T, loop_constants_quote, loop_constants_dict, b;
207+
_vectorloads!(main_body, q, dicts, V, W, T, loop_constants_quote, b;
207208
itersym = itersym, declared_iter_sym = n, VectorizationDict = vecdict, mod = mod)
208209
)# |> x -> (@show(x), _pirate(x)))
209210
end
@@ -382,8 +383,9 @@ end
382383
end
383384

384385

385-
@noinline function _vectorloads!(main_body, pre_quote, indexed_expressions, reduction_symbols, loaded_exprs, V, W, VET, loop_constants_quote, loop_constants_dict, expr;
386-
itersym = :iter, declared_iter_sym = nothing, VectorizationDict = SLEEFPiratesDict, mod = :LoopVectorization)
386+
@noinline function _vectorloads!(main_body, pre_quote, dicts, V, W, VET, loop_constants_quote, expr;
387+
itersym = :iter, declared_iter_sym = nothing, VectorizationDict = SLEEFPiratesDict, mod = :LoopVectorization)
388+
(indexed_expressions, reduction_symbols, loaded_exprs, loop_constants_dict) = dicts
387389
_spirate(prewalk(expr) do x
388390
# @show x
389391
# @show main_body

src/graphs.jl

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ struct Operation
7171
end
7272

7373
function isreduction(op::Operation)
74-
(op.node_type == memstore) && (length(op.symbolic_metadata) < length(op.dependencies)) && issubset(op.symbolic_metadata, op.dependencies)
74+
(op.node_type == memstore) && (length(op.symbolic_metadata) < length(op.dependencies))# && issubset(op.symbolic_metadata, op.dependencies)
7575
end
7676
isload(op::Operation) = op.node_type == memload
7777
isstore(op::Operation) = op.node_type == memstore
@@ -91,11 +91,11 @@ function stride(op::Operation, sym::Symbol)
9191
op.numerical_metadata[findfirst(s -> s === sym, op.symbolic_metadata)]
9292
end
9393
# function
94-
95-
struct Node
96-
type::DataType
94+
function unitstride(op::Operation, sym::Symbol)
95+
(first(op.symbolic_metadata) === sym) && (first(op.numerical_metadata) == 1)
9796
end
9897

98+
9999
struct Loop
100100
itersymbol::Symbol
101101
rangehint::Int
@@ -116,7 +116,8 @@ struct LoopSet
116116
loadops::Vector{Operation} # Split them to make it easier to iterate over just a subset
117117
computeops::Vector{Operation}
118118
storeops::Vector{Operation}
119-
reductions::Set{Operation}
119+
reductions::Set{UInt} # IDs of reduction operations that need to be reduced at end.
120+
strideset::Vector{}
120121
end
121122
num_loops(ls::LoopSet) = length(ls.loops)
122123
isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
@@ -147,7 +148,7 @@ function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int)
147148
if accesses_memory(op)
148149
# either vbroadcast/reductionstore, vmov(a/u)pd, or gather/scatter
149150
if opisunrolled
150-
if (stride(op, unrolled) != 1) || !isdense(op) # need gather/scatter
151+
if !unitstride(op, unrolled)# || !isdense(op) # need gather/scatter
151152
r = (1 << Wshift)
152153
c *= r
153154
sl *= r
@@ -456,6 +457,35 @@ function depends_on_assigned(op::Operation, assigned::Vector{Bool})
456457
end
457458
false
458459
end
460+
function lower_load!(q::Expr, op::Operation, unrolled::Symbol, U, Umax, T = nothing, Tmax = nothing)
461+
loopdeps = loopdependencies(op)
462+
if unrolled loopdeps # we need a vector
463+
if unitstride(op, unrolled) # vload
464+
465+
else # gather
466+
467+
end
468+
else # load scalar; promotion should broadcast as/when neccesary
469+
Expr(:call, :(VectorizationBase.load), )
470+
end
471+
end
472+
function lower_store!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
473+
474+
end
475+
function lower_compute!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
476+
for t T, u U
477+
478+
end
479+
end
480+
function lower!(q::Expr, op::Operation, unrolled::Symbol, U, T = 1)
481+
if isload(op)
482+
lower_load!(q, op, unrolled, U, T)
483+
elseif isstore(op)
484+
lower_store!(q, op, unrolled, U, T)
485+
else
486+
lower_compute!(q, op, unrolled, U, T)
487+
end
488+
end
459489

460490
# construction requires ops inserted into operations vectors in dependency order.
461491
function lower_unroll(ls::LoopSet, order::Vector{Symbol}, U::Int)
@@ -476,13 +506,14 @@ function lower_unroll_inner_block(ls::LoopSet, order::Vector{Symbol}, U::Int)
476506

477507
n = 0
478508
loopsym = last(order)
509+
blockq = Expr(:block, )
510+
loopq = Expr(:for, Expr(:(=), itersym, looprange(ls, loopsym)), blockq)
479511
for (id,op) enumerate(operations(ls))
480512
# We add an op the first time all loop dependencies are met
481513
# when working through loops backwords, that equates to the first time we encounter a loop dependency
482514
loopsym dependencies(op) || continue
483515
included_vars[id] = true
484-
485-
516+
lower!(blockq, op, unrolled, U)
486517
end
487518
for n 1:nloops - 2
488519
loopsym = order[nloops - n]

src/precompile.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ function _precompile_()
22
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
33
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Int64,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool,Module})
44
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Int64,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool})
5-
precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Symbol}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Type,Int64,Type,Expr,Dict{Expr,Symbol},Expr})
6-
precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Module}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Type,Int64,Type,Expr,Dict{Expr,Symbol},Expr})
5+
precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Module}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
76
precompile(Tuple{typeof(LoopVectorization.vectorize_body),Symbol,Type{Float64},Int64,Symbol,Array{Any,1},Dict{Symbol,Tuple{Symbol,Symbol}},Any,Bool})
7+
precompile(Tuple{LoopVectorization.var"#_vectorloads!##kw",NamedTuple{(:itersym, :declared_iter_sym, :VectorizationDict, :mod),Tuple{Symbol,Symbol,Dict{Symbol,Tuple{Symbol,Symbol}},Symbol}},typeof(LoopVectorization._vectorloads!),Expr,Expr,Tuple{Dict{Symbol,Symbol},Dict{Tuple{Symbol,Symbol},Symbol},Dict{Expr,Symbol},Dict{Expr,Symbol}},Type,Int64,Type,Expr,Expr})
8+
precompile(Tuple{typeof(LoopVectorization.add_masks),Expr,Symbol,Dict{Tuple{Symbol,Symbol},Symbol},Module})
9+
precompile(Tuple{typeof(LoopVectorization.add_masks),Expr,Symbol,Dict{Tuple{Symbol,Symbol},Symbol},Symbol})
810
end

0 commit comments

Comments
 (0)