Skip to content

Commit 3555394

Browse files
committed
Starting to debug; still need to add exported reductions support.
1 parent 2d4c560 commit 3555394

File tree

8 files changed

+305
-68
lines changed

8 files changed

+305
-68
lines changed

src/LoopVectorization.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -893,11 +893,13 @@ for vec ∈ (false,true)
893893
end
894894
end
895895

896+
include("costs.jl")
897+
include("operations.jl")
898+
include("graphs.jl")
899+
include("determinestrategy.jl")
900+
include("lowering.jl")
901+
include("constructors.jl")
896902
include("precompile.jl")
897903
_precompile_()
898904

899-
function __init__()
900-
_precompile_()
901-
end
902-
903905
end # module

src/constructors.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,11 @@ function walk_body!(ls::LoopSet, body::Expr)
77
end
88
function Base.copyto!(ls::LoopSet, q::Expr)
99
q.head === :for || throw("Expression must be a for loop.")
10-
add_loop!(ls, q.args[1])
11-
body = q.args[2]
12-
10+
add_loop!(ls, q)
1311
end
1412

1513
function LoopSet(q::Expr)
16-
q = contract_pass(q)
14+
q = SIMDPirates.contract_pass(q)
1715
ls = LoopSet()
1816
copyto!(ls, q)
1917
resize!(ls.loop_order, num_loops(ls))

src/costs.jl

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct InstructionCost
1212
scalar_latency::Int
1313
register_pressure::Int
1414
end
15-
InstructionCost(sl::Int, srt::Float64, scaling::Float64 = -3.0) = InstructionCost(scaling, srt, sl, srt, 0)
15+
InstructionCost(sl::Int, srt::Float64, scaling::Float64 = -3.0) = InstructionCost(scaling, srt, sl, 0)
1616

1717
function scalar_cost(instruction::InstructionCost)#, ::Type{T} = Float64) where {T}
1818
@unpack scalar_reciprical_throughput, scalar_latency, register_pressure = instruction
@@ -38,13 +38,16 @@ function vector_cost(instruction::InstructionCost, Wshift, sizeof_T)
3838
end
3939
srt, sl, srp
4040
end
41+
instruction_cost(instruction::Symbol) = get(COST, instruction, OPAQUE_INSTRUCTION)
42+
scalar_cost(instr::Symbol) = scalar_cost(instruction_cost(instr))
43+
vector_cost(instr::Symbol, Wshift, sizeof_T) = vector_cost(instruction_cost(instr), Wshift, sizeof_T)
4144
function cost(instruction::InstructionCost, Wshift, sizeof_T)
4245
Wshift == 0 ? scalar_cost(instruction) : vector_cost(instruction, Wshift, sizeof_T)
4346
end
4447

4548
function cost(instruction::Symbol, Wshift, sizeof_T)
4649
cost(
47-
get(COST, instruction, OPAQUE_INSTRUCTION),
50+
instruction_cost(instruction),
4851
Wshift, sizeof_T
4952
)
5053
end
@@ -111,6 +114,42 @@ const CORRESPONDING_REDUCTION = Dict{Symbol,Symbol}(
111114
:vfnmadd => :vsum,
112115
:vfnmsub => :vsum
113116
)
117+
const REDUCTION_TRANSLATION = Dict{Symbol,Symbol}(
118+
:(+) => :evadd,
119+
:vadd => :evadd,
120+
:(*) => :evmul,
121+
:vmul => :evmul,
122+
:(-) => :evadd,
123+
:vsub => :evadd,
124+
:(/) => :evmul,
125+
:vdiv => :evmul,
126+
:muladd => :evadd,
127+
:fma => :evadd,
128+
:vmuladd => :evadd,
129+
:vfma => :evadd,
130+
:vfmadd => :evadd,
131+
:vfmsub => :evadd,
132+
:vfnmadd => :evadd,
133+
:vfnmsub => :evadd
134+
)
135+
const REDUCTION_ZERO = Dict{Symbol,Symbol}(
136+
:(+) => :zero,
137+
:vadd => :zero,
138+
:(*) => :one,
139+
:vmul => :one,
140+
:(-) => :zero,
141+
:vsub => :zero,
142+
:(/) => :one,
143+
:vdiv => :one,
144+
:muladd => :zero,
145+
:fma => :zero,
146+
:vmuladd => :zero,
147+
:vfma => :zero,
148+
:vfmadd => :zero,
149+
:vfmsub => :zero,
150+
:vfnmadd => :zero,
151+
:vfnmsub => :zero
152+
)
114153
# const SIMDPIRATES_COST = Dict{Symbol,InstructionCost}()
115154
# const SLEEFPIRATES_COST = Dict{Symbol,InstructionCost}()
116155

src/determinestrategy.jl

Lines changed: 33 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11

2+
# TODO: FIXME for general case
3+
unitstride(op, s) = first(loopdependencies(op)) === s
4+
25
function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int)
36
# Wshift == dependson(op, unrolled) ? Wshift : 0
47
# c = first(cost(instruction(op), Wshift, size_T))::Int
@@ -10,12 +13,12 @@ function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int)
1013
if opisunrolled
1114
if !unitstride(op, unrolled)# || !isdense(op) # need gather/scatter
1215
r = (1 << Wshift)
13-
c *= r
16+
srt *= r
1417
sl *= r
1518
# else # vmov(a/u)pd
1619
end
1720
elseif instr === :setindex! # broadcast or reductionstore; if store we want to penalize reduction
18-
c *= 2
21+
srt *= 2
1922
sl *= 2
2023
end
2124
end
@@ -33,7 +36,12 @@ end
3336
function VectorizationBase.pick_vector_width_shift(ls::LoopSet, u::Symbol)
3437
VectorizationBase.pick_vector_width_shift(length(ls, u), biggest_type_size(ls))
3538
end
36-
39+
function hasintersection(a, b)
40+
for aᵢ a, bᵢ b
41+
aᵢ === bᵢ && return true
42+
end
43+
false
44+
end
3745

3846
# evaluates cost of evaluating loop in given order
3947
# heuristically, could simplify analysis by just unrolling outer loop?
@@ -42,7 +50,7 @@ function evaluate_cost_unroll(
4250
)
4351
# included_vars = Set{UInt}()
4452
included_vars = fill(false, length(operations(ls)))
45-
nested_loop_syms = Set{Symbol}()
53+
nested_loop_syms = Symbol[]#Set{Symbol}()
4654
total_cost = 0.0
4755
iter = 1.0
4856
# Need to check if fusion is possible
@@ -122,10 +130,12 @@ end
122130
function tile_cost(X, U, T)
123131
X[1] + X[4] + X[2] / T + X[3] / U
124132
end
125-
function solve_tilsize(X, R)
133+
function solve_tilesize(X, R)
126134
# We use lagrange multiplier to finding floating point values for U and T
127135
# first solving for U via quadratic formula
128136
# X is vector of costs, and R is of register pressures
137+
@show X
138+
@show R
129139
RR = VectorizationBase.REGISTER_COUNT - R[3] - R[4]
130140
a = (R[1])^2*X[2] - (R[2])^2*R[1]*X[3]/RR
131141
b = 2*R[1]*R[2]*X[3]
@@ -196,7 +206,7 @@ function evaluate_cost_tile(
196206
tiled = order[1]
197207
unrolled = order[2]
198208
included_vars = fill(false, length(operations(ls)))
199-
nested_loop_syms = Set{Symbol}()
209+
nested_loop_syms = Symbol[]# Set{Symbol}()
200210
iter = 1.0
201211
# Need to check if fusion is possible
202212
size_T = biggest_type_size(ls)
@@ -225,24 +235,28 @@ function evaluate_cost_tile(
225235
included_vars[id] && continue
226236
# it must also be a subset of defined symbols
227237
loopdependencies(op) nested_loop_syms || continue
228-
hasintersection(reduceddependencies(op), nested_loop_syms) && return 0,0,Inf
238+
# @show nested_loop_syms
239+
# @show reduceddependencies(op)
240+
rd = reduceddependencies(op)
241+
hasintersection(rd, nested_loop_syms[1:end-length(rd)]) && return 0,0,Inf
229242
included_vars[id] = true
230243
rt, lat, rp = cost(op, unrolled, Wshift, size_T)
244+
@show instruction(op), rt, lat, rp, iter
231245
rt *= iter
232246
isunrolled = unrolled loopdependencies(op)
233247
istiled = tiled loopdependencies(op)
234248
if isunrolled && istiled # no cost decrease; cost must be repeated
235-
cost_vec[1] = rt
236-
reg_pressure[1] = rp
249+
cost_vec[1] += rt
250+
reg_pressure[1] += rp
237251
elseif isunrolled # cost decreased by tiling
238-
cost_vec[2] = rt
239-
reg_pressure[2] = rp
252+
cost_vec[2] += rt
253+
reg_pressure[2] += rp
240254
elseif istiled # cost decreased by unrolling
241-
cost_vec[3] = rt
242-
reg_pressure[3] = rp
255+
cost_vec[3] += rt
256+
reg_pressure[3] += rp
243257
else# neither unrolled or tiled
244-
cost_vec[4] = rt
245-
reg_pressure[4] = rp
258+
cost_vec[4] += rt
259+
reg_pressure[4] += rp
246260
end
247261
end
248262
end
@@ -252,13 +266,13 @@ function evaluate_cost_tile(
252266
if Ustatic
253267
solve_tilesize(cost_vec, reg_pressure, looprangehint(ls, tiled), looprangehint(ls, unrolled))
254268
else
255-
solve_tilesize(cost_vec, reg_pressure, looprangehint(ls, tiled), nothing)
269+
solve_tilesize(cost_vec, reg_pressure, looprangehint(ls, tiled), typemax(Int))
256270
end
257271
else
258272
if Ustatic
259-
solve_tilesize(cost_vec, reg_pressure, nothing, looprangehint(ls, unrolled))
273+
solve_tilesize(cost_vec, reg_pressure, typemax(Int), looprangehint(ls, unrolled))
260274
else
261-
solve_tilesize(cost_vec, reg_pressure)
275+
solve_tilesize(cost_vec, reg_pressure)#, typemax(Int), typemax(Int))
262276
end
263277
end
264278
end
@@ -270,7 +284,7 @@ struct LoopOrders
270284
end
271285
function LoopOrders(ls::LoopSet)
272286
syms = [s for s keys(ls.loops)]
273-
LoopOrders(syms, similar(buff))
287+
LoopOrders(syms, similar(syms))
274288
end
275289
function Base.iterate(lo::LoopOrders)
276290
lo.syms, zeros(Int, length(lo.syms))# - 1)

0 commit comments

Comments
 (0)