Skip to content

Commit d0627d3

Browse files
committed
Further WIP on being able to evaluate more complicated loops. Triangular logdets can now be evaluated.
1 parent becbf17 commit d0627d3

File tree

8 files changed

+388
-312
lines changed

8 files changed

+388
-312
lines changed

src/LoopVectorization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ module LoopVectorization
22

33
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
5-
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub
6-
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod
5+
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub, _MM
6+
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod
77
using Base.Broadcast: Broadcasted, DefaultArrayStyle
88
using LinearAlgebra: Adjoint, Transpose
99
using MacroTools: prewalk, postwalk

src/broadcast.jl

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,12 @@ function add_broadcast!(
4747
K = gensym(:K)
4848
mA = gensym(:Aₘₖ)
4949
mB = gensym(:Bₖₙ)
50-
pushpreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
51-
pushpreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
52-
pushpreamble!(ls, Expr(:(=), K, Expr(:call, :size, mB, 1)))
50+
pushprepreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
51+
pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
52+
pushprepreamble!(ls, Expr(:(=), K, Expr(:call, :size, mB, 1)))
5353

5454
k = gensym(:k)
55-
ls.loops[k] = Loop(k, K)
55+
ls.loops[k] = Loop(k, 0, K)
5656
m = loopsyms[1];
5757
if ndims(B) == 1
5858
bloopsyms = Symbol[k]
@@ -74,9 +74,9 @@ function add_broadcast!(
7474
# set Cₘₙ = 0
7575
# setC = add_constant!(ls, zero(promote_type(recursive_eltype(A), recursive_eltype(B))), cloopsyms, mC, elementbytes)
7676
setC = if elementbytes == 4
77-
add_constant!(ls, 0f0, cloopsyms, mC, elementbytes)
77+
add_constant!(ls, 0f0, cloopsyms, mC, Symbol(""), elementbytes)
7878
else#if elementbytes == 4
79-
add_constant!(ls, 0.0, cloopsyms, mC, elementbytes)
79+
add_constant!(ls, 0.0, cloopsyms, mC, Symbol(""), elementbytes)
8080
end
8181
# compute Cₘₙ += Aₘₖ * Bₖₙ
8282
reductop = Operation(
@@ -111,7 +111,7 @@ function add_broadcast_adjoint_array!(
111111
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{A}, elementbytes::Int = 8
112112
) where {T,N,A<:AbstractArray{T,N}}
113113
parent = gensym(:parent)
114-
pushpreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
114+
pushprepreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
115115
ref = ArrayReference(parent, Union{Symbol,Int}[loopsyms[N + 1 - n] for n 1:N])
116116
add_simple_load!( ls, destname, ref, elementbytes )::Operation
117117
end
@@ -143,7 +143,7 @@ function add_broadcast!(
143143
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int = 8
144144
) where {T<:Union{Integer,Float32,Float64}}
145145
op = add_constant!(ls, destname, elementbytes) # or replace elementbytes with sizeof(T) ? u
146-
pushpreamble!(ls, Expr(:(=), mangledvar(op), bcname))
146+
pushprepreamble!(ls, Expr(:(=), mangledvar(op), bcname))
147147
op
148148
end
149149
function add_broadcast!(
@@ -172,7 +172,7 @@ function add_broadcast!(
172172
reduceddeps = Symbol[]
173173
for (i,arg) enumerate(args)
174174
argname = gensym(:arg)
175-
pushpreamble!(ls, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,@__FILE__), Expr(:(=), argname, Expr(:ref, bcargs, i))))
175+
pushprepreamble!(ls, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,@__FILE__), Expr(:(=), argname, Expr(:ref, bcargs, i))))
176176
# dynamic dispatch
177177
parent = add_broadcast!(ls, gensym(:temp), argname, loopsyms, arg, elementbytes)::Operation
178178
pushparent!(parents, deps, reduceddeps, parent)
@@ -195,10 +195,10 @@ end
195195
sizes = Expr(:tuple)
196196
for (n,itersym) enumerate(loopsyms)
197197
Nsym = gensym(:N)
198-
ls.loops[itersym] = Loop(itersym, Nsym)
198+
ls.loops[itersym] = Loop(itersym, 0, Nsym)
199199
push!(sizes.args, Nsym)
200200
end
201-
pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest)))
201+
pushprepreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest)))
202202
elementbytes = sizeof(T)
203203
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
204204
add_simple_store!(ls, :dest, ArrayReference(:dest, loopsyms), elementbytes)
@@ -216,14 +216,14 @@ end
216216
# need to construct the LoopSet
217217
loopsyms = [gensym(:n) for n 1:N]
218218
ls = LoopSet()
219-
pushpreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
219+
pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
220220
sizes = Expr(:tuple)
221221
for (n,itersym) enumerate(loopsyms)
222222
Nsym = gensym(:N)
223-
ls.loops[itersym] = Loop(itersym, Nsym)
223+
ls.loops[itersym] = Loop(itersym, 0, Nsym)
224224
push!(sizes.args, Nsym)
225225
end
226-
pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest′)))
226+
pushprepreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest′)))
227227
elementbytes = sizeof(T)
228228
add_broadcast!(ls, :dest, :bc, loopsyms, BC, elementbytes)
229229
add_simple_store!(ls, :dest, ArrayReference(:dest, reverse(loopsyms)), elementbytes)

src/costs.jl

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,15 @@ const COST = Dict{Instruction,InstructionCost}(
8484
Instruction(:vadd) => InstructionCost(4,0.5),
8585
Instruction(:vsub) => InstructionCost(4,0.5),
8686
Instruction(:vmul) => InstructionCost(4,0.5),
87-
Instruction(:vdiv) => InstructionCost(13,4.0,-2.0),
87+
Instruction(:vfdiv) => InstructionCost(13,4.0,-2.0),
88+
Instruction(:evadd) => InstructionCost(4,0.5),
89+
Instruction(:evsub) => InstructionCost(4,0.5),
90+
Instruction(:evmul) => InstructionCost(4,0.5),
91+
Instruction(:evfdiv) => InstructionCost(13,4.0,-2.0),
92+
Instruction(:reduced_add) => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
93+
Instruction(:reduced_prod) => InstructionCost(4,0.5),# ignoring reduction part of cost, might be nop
94+
Instruction(:reduce_to_add) => InstructionCost(0,0.0,0.0,0),
95+
Instruction(:reduce_to_prod) => InstructionCost(0,0.0,0.0,0),
8896
Instruction(:abs2) => InstructionCost(4,0.5),
8997
Instruction(:vabs2) => InstructionCost(4,0.5),
9098
Instruction(:(==)) => InstructionCost(1, 0.5),
@@ -110,14 +118,20 @@ const COST = Dict{Instruction,InstructionCost}(
110118
Instruction(:vfnmadd_fast) => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
111119
Instruction(:vfnmsub_fast) => InstructionCost(4,0.5), # - and -* will fuse into this, so much of the time they're not twice as expensive
112120
Instruction(:sqrt) => InstructionCost(15,4.0,-2.0),
121+
Instruction(:sqrt_fast) => InstructionCost(15,4.0,-2.0),
113122
Instruction(:log) => InstructionCost(20,20.0,40.0,20),
114123
Instruction(:exp) => InstructionCost(20,20.0,20.0,18),
115124
Instruction(:sin) => InstructionCost(18,15.0,68.0,23),
116125
Instruction(:cos) => InstructionCost(18,15.0,68.0,26),
117126
Instruction(:sincos) => InstructionCost(25,22.0,70.0,26),
127+
Instruction(:log_fast) => InstructionCost(20,20.0,40.0,20),
128+
Instruction(:exp_fast) => InstructionCost(20,20.0,20.0,18),
129+
Instruction(:sin_fast) => InstructionCost(18,15.0,68.0,23),
130+
Instruction(:cos_fast) => InstructionCost(18,15.0,68.0,26),
131+
Instruction(:sincos_fast) => InstructionCost(25,22.0,70.0,26),
118132
Instruction(:identity) => InstructionCost(0,0.0,0.0,0),
119133
Instruction(:adjoint) => InstructionCost(0,0.0,0.0,0),
120-
Instruction(:transpose) => InstructionCost(0,0.0,0.0,0),
134+
Instruction(:transpose) => InstructionCost(0,0.0,0.0,0)
121135
# Symbol("##CONSTANT##") => InstructionCost(0,0.0)
122136
)
123137
# for (k, v) ∈ COST # so we can look up Symbol(typeof(function))
@@ -131,6 +145,9 @@ const CORRESPONDING_REDUCTION = Dict{Instruction,Instruction}(
131145
Instruction(:vadd) => Instruction(:vsum),
132146
Instruction(:vsub) => Instruction(:vsum),
133147
Instruction(:vmul) => Instruction(:vprod),
148+
Instruction(:evadd) => Instruction(:vsum),
149+
Instruction(:evsub) => Instruction(:vsum),
150+
Instruction(:evmul) => Instruction(:vprod),
134151
Instruction(:&) => Instruction(:vall),
135152
Instruction(:|) => Instruction(:vany),
136153
Instruction(:muladd) => Instruction(:vsum),
@@ -140,7 +157,11 @@ const CORRESPONDING_REDUCTION = Dict{Instruction,Instruction}(
140157
Instruction(:vfmadd) => Instruction(:vsum),
141158
Instruction(:vfmsub) => Instruction(:vsum),
142159
Instruction(:vfnmadd) => Instruction(:vsum),
143-
Instruction(:vfnmsub) => Instruction(:vsum)
160+
Instruction(:vfnmsub) => Instruction(:vsum),
161+
Instruction(:vfmadd_fast) => Instruction(:vsum),
162+
Instruction(:vfmsub_fast) => Instruction(:vsum),
163+
Instruction(:vfnmadd_fast) => Instruction(:vsum),
164+
Instruction(:vfnmsub_fast) => Instruction(:vsum)
144165
)
145166
const REDUCTION_TRANSLATION = Dict{Instruction,Instruction}(
146167
Instruction(:+) => Instruction(:evadd),
@@ -158,25 +179,37 @@ const REDUCTION_TRANSLATION = Dict{Instruction,Instruction}(
158179
Instruction(:vfmadd) => Instruction(:evadd),
159180
Instruction(:vfmsub) => Instruction(:evadd),
160181
Instruction(:vfnmadd) => Instruction(:evadd),
161-
Instruction(:vfnmsub) => Instruction(:evadd)
182+
Instruction(:vfnmsub) => Instruction(:evadd),
183+
Instruction(:vfmadd_fast) => Instruction(:evadd),
184+
Instruction(:vfmsub_fast) => Instruction(:evadd),
185+
Instruction(:vfnmadd_fast) => Instruction(:evadd),
186+
Instruction(:vfnmsub_fast) => Instruction(:evadd)
162187
)
163188
const REDUCTION_ZERO = Dict{Instruction,Symbol}(
164189
Instruction(:+) => :zero,
165190
Instruction(:vadd) => :zero,
191+
Instruction(:evadd) => :zero,
166192
Instruction(:*) => :one,
167193
Instruction(:vmul) => :one,
194+
Instruction(:evmul) => :one,
168195
Instruction(:-) => :zero,
169196
Instruction(:vsub) => :zero,
197+
Instruction(:evsub) => :zero,
170198
Instruction(:/) => :one,
171199
Instruction(:vfdiv) => :one,
200+
Instruction(:evfdiv) => :one,
172201
Instruction(:muladd) => :zero,
173202
Instruction(:fma) => :zero,
174203
Instruction(:vmuladd) => :zero,
175204
Instruction(:vfma) => :zero,
176205
Instruction(:vfmadd) => :zero,
177206
Instruction(:vfmsub) => :zero,
178207
Instruction(:vfnmadd) => :zero,
179-
Instruction(:vfnmsub) => :zero
208+
Instruction(:vfnmsub) => :zero,
209+
Instruction(:vfmadd_fast) => :zero,
210+
Instruction(:vfmsub_fast) => :zero,
211+
Instruction(:vfnmadd_fast) => :zero,
212+
Instruction(:vfnmsub_fast) => :zero
180213
)
181214

182215
lv(x) = GlobalRef(LoopVectorization, x)
@@ -197,7 +230,15 @@ const REDUCTION_SCALAR_COMBINE = Dict{Instruction,GlobalRef}(
197230
Instruction(:vfmadd) => lv(:reduced_add),
198231
Instruction(:vfmsub) => lv(:reduced_add),
199232
Instruction(:vfnmadd) => lv(:reduced_add),
200-
Instruction(:vfnmsub) => lv(:reduced_add)
233+
Instruction(:vfnmsub) => lv(:reduced_add),
234+
Instruction(:vfmadd_fast) => lv(:reduced_add),
235+
Instruction(:vfmsub_fast) => lv(:reduced_add),
236+
Instruction(:vfnmadd_fast) => lv(:reduced_add),
237+
Instruction(:vfnmsub_fast) => lv(:reduced_add)
238+
)
239+
const REDUCTION_COMBINETO = Dict{Symbol,Symbol}(
240+
:reduced_add => :reduce_to_add,
241+
:reduced_prod => :reduce_to_prod
201242
)
202243

203244
const FUNCTIONSYMBOLS = Dict{Type{<:Function},Instruction}(
@@ -230,6 +271,10 @@ const FUNCTIONSYMBOLS = Dict{Type{<:Function},Instruction}(
230271
typeof(SIMDPirates.vfmsub) => :vfmsub,
231272
typeof(SIMDPirates.vfnmadd) => :vfnmadd,
232273
typeof(SIMDPirates.vfnmsub) => :vfnmsub,
274+
typeof(SIMDPirates.vfmadd_fast) => :vfmadd_fast,
275+
typeof(SIMDPirates.vfmsub_fast) => :vfmsub_fast,
276+
typeof(SIMDPirates.vfnmadd_fast) => :vfnmadd_fast,
277+
typeof(SIMDPirates.vfnmsub_fast) => :vfnmsub_fast,
233278
typeof(sqrt) => :sqrt,
234279
typeof(Base.FastMath.sqrt_fast) => :sqrt,
235280
typeof(SIMDPirates.vsqrt) => :sqrt,

src/determinestrategy.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,13 @@ function parentsnotreduction(op::Operation)
110110
end
111111
return true
112112
end
113+
function roundpow2(i::Integer)
114+
u = VectorizationBase.nextpow2(i)
115+
l = u >>> 1
116+
ud = u - i
117+
ld = i - l
118+
ud > ld ? l : u
119+
end
113120
function unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
114121
innermost = last(order)
115122
compute_rt = 0.0
@@ -125,7 +132,7 @@ function unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
125132
end
126133
# heuristic guess
127134
# @show compute_rt, load_rt
128-
min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt))
135+
roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
129136
end
130137
function determine_unroll_factor(
131138
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, vectorized::Symbol = first(order)
@@ -171,7 +178,7 @@ function determine_unroll_factor(
171178
load_recip_throughput,
172179
store_recip_throughput
173180
)
174-
max(1, round(Int, latency / (recip_throughput * num_reductions) ) )
181+
roundpow2(max(1, round(Int, latency / (recip_throughput * num_reductions) ) ))
175182
end
176183

177184
function tile_cost(X, U, T)
@@ -434,6 +441,7 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
434441
end
435442
function choose_tile(ls::LoopSet)
436443
lo = LoopOrders(ls)
444+
# @show lo.syms ls.loop_order.bestorder
437445
best_order = copyto!(ls.loop_order.bestorder, lo.syms)
438446
best_vec = first(best_order) # filler
439447
new_order, state = iterate(lo) # right now, new_order === best_order

0 commit comments

Comments
 (0)