Skip to content

Commit 2572896

Browse files
committed
Fixed a lot of bugs.
1 parent b39791e commit 2572896

File tree

8 files changed

+476
-198
lines changed

8 files changed

+476
-198
lines changed

src/LoopVectorization.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
module LoopVectorization
22

33
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
4-
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr
5-
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul
4+
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr, mask
5+
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod
66
using MacroTools: prewalk, postwalk
77

8-
export vectorizable, @vectorize, @vvectorize
8+
export vectorizable, @vectorize, @vvectorize, @avx
99

1010
function isdense end #
1111

src/constructors.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11

22
### This file contains convenience functions for constructing LoopSets.
33

4-
function walk_body!(ls::LoopSet, body::Expr)
5-
6-
7-
end
84
function Base.copyto!(ls::LoopSet, q::Expr)
95
q.head === :for || throw("Expression must be a for loop.")
106
add_loop!(ls, q)
@@ -18,5 +14,7 @@ function LoopSet(q::Expr)
1814
ls
1915
end
2016

21-
17+
macro avx(q)
18+
esc(lower(LoopSet(q)))
19+
end
2220

src/costs.jl

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ const REDUCTION_TRANSLATION = Dict{Symbol,Symbol}(
123123
:(-) => :evadd,
124124
:vsub => :evadd,
125125
:(/) => :evmul,
126-
:vdiv => :evmul,
126+
:vfdiv => :evmul,
127127
:muladd => :evadd,
128128
:fma => :evadd,
129129
:vmuladd => :evadd,
@@ -141,7 +141,7 @@ const REDUCTION_ZERO = Dict{Symbol,Symbol}(
141141
:(-) => :zero,
142142
:vsub => :zero,
143143
:(/) => :one,
144-
:vdiv => :one,
144+
:vfdiv => :one,
145145
:muladd => :zero,
146146
:fma => :zero,
147147
:vmuladd => :zero,
@@ -151,12 +151,46 @@ const REDUCTION_ZERO = Dict{Symbol,Symbol}(
151151
:vfnmadd => :zero,
152152
:vfnmsub => :zero
153153
)
154-
# const SIMDPIRATES_COST = Dict{Symbol,InstructionCost}()
155-
# const SLEEFPIRATES_COST = Dict{Symbol,InstructionCost}()
154+
# Fast functions, because common pattern is
155+
const REDUCTION_SCALAR_COMBINE = Dict{Symbol,Expr}(
156+
:(+) => :(LoopVectorization.reduced_add),
157+
:vadd => :(LoopVectorization.reduced_add),
158+
:(*) => :(LoopVectorization.reduced_prod),
159+
:vmul => :(LoopVectorization.reduced_prod),
160+
:(-) => :(LoopVectorization.reduced_add),
161+
:vsub => :(LoopVectorization.reduced_add),
162+
:(/) => :(LoopVectorization.reduced_prod),
163+
:vfdiv => :(LoopVectorization.reduced_prod),
164+
:muladd => :(LoopVectorization.reduced_add),
165+
:fma => :(LoopVectorization.reduced_add),
166+
:vmuladd => :(LoopVectorization.reduced_add),
167+
:vfma => :(LoopVectorization.reduced_add),
168+
:vfmadd => :(LoopVectorization.reduced_add),
169+
:vfmsub => :(LoopVectorization.reduced_add),
170+
:vfnmadd => :(LoopVectorization.reduced_add),
171+
:vfnmsub => :(LoopVectorization.reduced_add)
172+
)
173+
174+
const FUNCTION_MODULES = Dict{Symbol,Expr}(
175+
:vadd => :(LoopVectorization.vadd),
176+
:vmul => :(LoopVectorization.vmul),
177+
:vsub => :(LoopVectorization.vsub),
178+
:vfdiv => :(LoopVectorization.vfdiv),
179+
:vmuladd => :(LoopVectorization.vmuladd),
180+
:vfma => :(LoopVectorization.vfma),
181+
:vfmadd => :(LoopVectorization.vfmadd),
182+
:vfmsub => :(LoopVectorization.vfmsub),
183+
:vfnmadd => :(LoopVectorization.vfnmadd),
184+
:vfnmsub => :(LoopVectorization.vfnmsub),
185+
:vsqrt => :(LoopVectorization.vsqrt),
186+
:log => :(LoopVectorization.SIMDPirates.vlog),
187+
:exp => :(LoopVectorization.SIMDPirates.vexp),
188+
:sin => :(LoopVectorization.SLEEFPirates.sin),
189+
:cos => :(LoopVectorization.SLEEFPirates.cos),
190+
:sincos => :(LoopVectorization.SLEEFPirates.sincos)
191+
)
192+
function callfun(f::Symbol)
193+
Expr(:call, get(FUNCTION_MODULES, f, f))::Expr
194+
end
156195

157-
# const MODULE_LOOKUP = Dict{Symbol,Dict{Symbol,InstructionCost}}(
158-
# :Base => BASE_COST,
159-
# :SIMDPirates => SIMDPIRATES_COST,
160-
# :SLEEFPirates => SLEEFPIRATES_COST
161-
# )
162196

src/determinestrategy.jl

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,22 +88,21 @@ end
8888

8989
# only covers unrolled ops; everything else considered lifted?
9090
function depchain_cost!(
91-
skip::Vector{Bool}, op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int, sl::Int = 0, rt::Float64 = 0.0
91+
skip::Vector{Bool}, op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int, rt::Float64 = 0.0, sl::Int = 0
9292
)
9393
skip[identifier(op)] = true
9494
# depth first search
9595
for opp parents(op)
9696
skip[identifier(opp)] && continue
97-
sl, rt = depchain_cost!(skip, opp, unrolled, Wshift, size_T, sl, rt)
97+
rt, sl = depchain_cost!(skip, opp, unrolled, Wshift, size_T, rt, sl)
9898
end
9999
# Basically assuming memory and compute don't conflict, but everything else does
100100
# Ie, ignoring the fact that integer and floating point operations likely don't either
101-
if accesses_memory(op)
102-
return sl, rt
101+
if iscompute(op)
102+
rtᵢ, slᵢ = cost(op, unrolled, Wshift, size_T)
103+
rt += rtᵢ; sl += slᵢ
103104
end
104-
# @show instruction(op)
105-
rtᵢ, slᵢ = cost(op, unrolled, Wshift, size_T)
106-
sl + slᵢ, rt + rtᵢ
105+
rt, sl
107106
end
108107

109108
function determine_unroll_factor(
@@ -116,21 +115,35 @@ function determine_unroll_factor(
116115
# The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
117116
num_reductions = sum(isreduction, operations(ls))
118117
# @show num_reductions
119-
iszero(num_reductions) && return 1
118+
if iszero(num_reductions) # the 4 is a hack, based on the idea that there is some cost to moving through columns
119+
return length(order) == 1 ? 1 : 4
120+
end
120121
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
121122
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
122123
# We also make sure register pressure is not too high.
123124
latency = 0
124-
recip_throughput = 0.0
125+
compute_recip_throughput = 0.0
125126
visited_nodes = fill(false, length(operations(ls)))
127+
load_recip_throughput = 0.0
128+
store_recip_throughput = 0.0
126129
for op operations(ls)
127-
if isreduction(op) && dependson(op, unrolled)
128-
sl, rt = depchain_cost!(visited_nodes, op, unrolled, Wshift, size_T)
130+
dependson(op, unrolled) || continue
131+
if isreduction(op)
132+
rt, sl = depchain_cost!(visited_nodes, op, unrolled, Wshift, size_T)
129133
latency = max(sl, latency)
130-
recip_throughput += rt
134+
compute_recip_throughput += rt
135+
elseif isload(op)
136+
load_recip_throughput += first(cost(op, unrolled, Wshift, size_T))
137+
elseif isstore(op)
138+
store_recip_throughput += first(cost(op, unrolled, Wshift, size_T))
131139
end
132140
end
133-
max(1, round(Int, latency / (recip_throughput * num_reductions) ) )
141+
recip_throughput = max(
142+
compute_recip_throughput,
143+
load_recip_throughput,
144+
store_recip_throughput
145+
)
146+
max(1, round(Int, latency / (recip_throughput * num_reductions) ) )
134147
end
135148

136149
function tile_cost(X, U, T)

src/graphs.jl

Lines changed: 63 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
lv(x) = Expr(:(.), :LoopVectorization, QuoteNode(x))
22

33
isdense(::Type{<:DenseArray}) = true
44

@@ -72,6 +72,7 @@ Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i::Int) = lo.oporder[i]
7272
Base.@propagate_inbounds Base.getindex(lo::LoopOrder, i...) = lo.oporder[LinearIndices(size(lo))[i...]]
7373

7474
# Must make it easy to iterate
75+
# outer_reductions is a vector of indixes (within operation vectors) of the reduction operation, eg the vmuladd op in a dot product
7576
struct LoopSet
7677
loops::Dict{Symbol,Loop} # sym === loops[sym].itersymbol
7778
opdict::Dict{Symbol,Operation}
@@ -80,7 +81,13 @@ struct LoopSet
8081
loop_order::LoopOrder
8182
# stridesets::Dict{ShortVector{Symbol},ShortVector{Symbol}}
8283
preamble::Expr # TODO: add preamble to lowering
83-
includedarrays::Vector{Symbol}
84+
includedarrays::Vector{Tuple{Symbol,Int}}
85+
end
86+
function includesarray(ls::LoopSet, array::Symbol)
87+
for (a,i) ls.includedarrays
88+
a === array && return i
89+
end
90+
-1
8491
end
8592
function LoopSet()
8693
LoopSet(
@@ -90,7 +97,7 @@ function LoopSet()
9097
Int[],
9198
LoopOrder(),
9299
Expr(:block,),
93-
Symbol[]
100+
Tuple{Symbol,Int}[]
94101
)
95102
end
96103
num_loops(ls::LoopSet) = length(ls.loops)
@@ -106,13 +113,13 @@ looprangesym(ls::LoopSet, s::Symbol) = ls.loops[s].rangesym
106113
getop(ls::LoopSet, s::Symbol) = ls.opdict[s]
107114
getop(ls::LoopSet, i::Int) = ls.operations[i + 1]
108115

109-
function looprange(ls::LoopSet, s::Symbol, incr::Int = 1)
116+
function looprange(ls::LoopSet, s::Symbol, incr::Int = 1, mangledname::Symbol = s)
110117
loop = ls.loops[s]
111118
incr -= 1
112119
if iszero(incr)
113-
Expr(:call, :<, s, loop.hintexact ? loop.rangehint : loop.rangesym)
120+
Expr(:call, :<, mangledname, loop.hintexact ? loop.rangehint : loop.rangesym)
114121
else
115-
Expr(:call, :<, s, loop.hintexact ? loop.rangehint - incr : Expr(:call, :-, loop.rangesym, incr))
122+
Expr(:call, :<, mangledname, loop.hintexact ? loop.rangehint - incr : Expr(:call, :-, loop.rangesym, incr))
116123
end
117124
end
118125
function Base.length(ls::LoopSet, is::Symbol)
@@ -193,10 +200,10 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
193200
Base.push!(ls, q, elementbytes)
194201
end
195202
end
196-
function add_vptr!(ls::LoopSet, indexed::Symbol)
197-
if indexed ls.includedarrays
198-
push!(ls.includedarrays, indexed)
199-
push!(ls.preamble.args, Expr(:(=), Symbol(:vptr_, indexed), Expr(:call, Expr(:(.), :VectorizationBase, QuoteNode(:stridedpointer)), indexed)))
203+
function add_vptr!(ls::LoopSet, indexed::Symbol, id::Int)
204+
if includesarray(ls, indexed) < 0
205+
push!(ls.includedarrays, (indexed, id))
206+
push!(ls.preamble.args, Expr(:(=), Symbol("##vptr##_", indexed), Expr(:call, lv(:stridedpointer), indexed)))
200207
end
201208
nothing
202209
end
@@ -205,7 +212,7 @@ function add_load!(
205212
ls::LoopSet, var::Symbol, indexed::Symbol, indices::AbstractVector, elementbytes::Int = 8
206213
)
207214
op = Operation( length(operations(ls)), var, elementbytes, :getindex, memload, indices, [indexed], NOPARENTS )
208-
add_vptr!(ls, indexed)
215+
add_vptr!(ls, indexed, identifier(op))
209216
pushop!(ls, op, var)
210217
end
211218
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
@@ -240,12 +247,12 @@ end
240247
### if it is a literal, that literal is either var"##ZERO#Float##", var"##ONE#Float##", or has to have been assigned to var in the preamble.
241248
# if it is a literal, that literal has to have been assigned to var in the preamble.
242249
function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
243-
pushop!(ls, Operation(length(operations(ls)), var, elementbytes, Symbol("##CONSTANT##"), constant, NODEPENDENCY, NODEPENDENCY, NOPARENTS), var)
250+
pushop!(ls, Operation(length(operations(ls)), var, elementbytes, Symbol("##CONSTANT##"), constant, NODEPENDENCY, Symbol[], NOPARENTS), var)
244251
end
245252
function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
246253
sym = gensym(:temp)
247254
push!(ls.preamble.args, Expr(:(=), sym, var))
248-
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, Symbol("##CONSTANT##"), constant, NODEPENDENCY, NODEPENDENCY, NOPARENTS), sym)
255+
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, Symbol("##CONSTANT##"), constant, NODEPENDENCY, Symbol[], NOPARENTS), sym)
249256
end
250257
# This version has loop dependencies. var gets assigned to sym when lowering.
251258
function add_constant!(ls::LoopSet, var::Symbol, deps::Vector{Symbol}, sym::Symbol = gensym(:constant), elementbytes::Int = 8)
@@ -260,9 +267,28 @@ end
260267
function pushparent!(parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, parent::Operation)
261268
push!(parents, parent)
262269
mergesetv!(deps, loopdependencies(parent))
263-
isload(parent) || mergesetv!(reduceddeps, reduceddependencies(parent))
270+
if !(isload(parent) || isconstant(parent))
271+
mergesetv!(reduceddeps, reduceddependencies(parent))
272+
end
264273
nothing
265274
end
275+
function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
276+
if expr.head === :ref
277+
array = first(expr.args)::Symbol
278+
args = @view expr.args[2:end]
279+
elseif expr.head === :call && first(expr.args) === :getindex
280+
array = (expr.args[2])::Symbol
281+
args = @view expr.args[3:end]
282+
else
283+
return add_operation!(ls, gensym(:temporary), expr, elementbytes)
284+
end
285+
id = includesarray(ls, array)
286+
if id > 0
287+
ls.operations[id]
288+
else
289+
add_load!( ls, gensym(:temporary), array, args, elementbytes )
290+
end
291+
end
266292
function add_parent!(
267293
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var, elementbytes::Int = 8
268294
)
@@ -271,9 +297,8 @@ function add_parent!(
271297
# might add constant
272298
add_constant!(ls, var, elementbytes)
273299
end
274-
elseif var isa Expr
275-
temp = gensym(:temporary)
276-
add_operation!(ls, temp, var, elementbytes)
300+
elseif var isa Expr #CSE candidate
301+
maybe_cse_load!(ls, var, elementbytes)
277302
else # assumed constant
278303
add_constant!(ls, var, elementbytes)
279304
end
@@ -283,13 +308,23 @@ function add_reduction!(
283308
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var::Symbol, elementbytes::Int = 8
284309
)
285310
get!(ls.opdict, var) do
286-
p = add_constant!(ls, var, elementbytes)
287-
push!(ls.outer_reductions, identifier(p))
288-
p
311+
add_constant!(ls, var, elementbytes)
312+
# push!(ls.outer_reductions, identifier(p))
313+
# p
289314
end
290315
# pushparent!(parents, deps, reduceddeps, parent)
291316
end
292-
317+
function add_reduction_update_parent!(
318+
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet,
319+
var::Symbol, instr::Symbol, elementbytes::Int = 8
320+
)
321+
parent = getop(ls, var)
322+
setdiffv!(reduceddeps, deps, loopdependencies(parent))
323+
pushparent!(parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
324+
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
325+
parent.instruction === Symbol("##CONSTANT##") && push!(ls.outer_reductions, identifier(op))
326+
pushop!(ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
327+
end
293328
function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
294329
@assert ex.head === :call
295330
instr = instruction(first(ex.args))::Symbol
@@ -308,20 +343,18 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
308343
end
309344
end
310345
if reduction # arg[reduction] is the reduction
311-
parent = getop(ls, var)
312-
setdiffv!(reduceddeps, deps, loopdependencies(parent))
313-
pushparent!(parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
314-
# append!(reduceddependencies(parent), reduceddeps)
346+
add_reduction_update_parent!(parents, deps, reduceddeps, ls, var, instr, elementbytes)
347+
else
348+
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
349+
pushop!(ls, op, var)
315350
end
316-
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
317-
pushop!(ls, op, var) # note this overwrites the entry in the operations dict, but not the vector
318351
end
319352
function add_store!(
320353
ls::LoopSet, indexed::Symbol, var::Symbol, indices::AbstractVector, elementbytes::Int = 8
321354
)
322355
parent = getop(ls, var)
323356
op = Operation( length(operations(ls)), indexed, elementbytes, :setindex!, memstore, indices, reduceddependencies(parent), [parent] )
324-
add_vptr!(ls, indexed)
357+
add_vptr!(ls, indexed, identifier(op))
325358
pushop!(ls, op, var)
326359
end
327360
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
@@ -414,11 +447,11 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
414447
isunrolled = (unrolled loopdependencies(op)) + 1
415448
istiled = (loopistiled ? (tiled loopdependencies(op)) : false) + 1
416449
optype = Int(op.node_type) + 1
417-
after_loop = (length(reduceddependencies(op)) > 0) + 1
450+
after_loop = isload(op) ? 1 : (length(reduceddependencies(op)) > 0) + 1
418451
push!(lo[optype,isunrolled,istiled,after_loop,_n], op)
419452
end
420453
end
421-
@show 3, ro, order
454+
# 3, ro, order
422455
end
423456

424457

0 commit comments

Comments
 (0)