Skip to content

Commit becbf17

Browse files
committed
Some fiddling; work on updating lowering of vectorization information.
1 parent 887d0b0 commit becbf17

File tree

3 files changed

+46
-25
lines changed

3 files changed

+46
-25
lines changed

src/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ function add_broadcast!(
172172
reduceddeps = Symbol[]
173173
for (i,arg) ∈ enumerate(args)
174174
argname = gensym(:arg)
175-
pushpreamble!(ls, Expr(:(=), argname, Expr(:ref, bcargs, i)))
175+
pushpreamble!(ls, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,@__FILE__), Expr(:(=), argname, Expr(:ref, bcargs, i))))
176176
# dynamic dispatch
177177
parent = add_broadcast!(ls, gensym(:temp), argname, loopsyms, arg, elementbytes)::Operation
178178
pushparent!(parents, deps, reduceddeps, parent)

src/lowering.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ variable_name(op::Operation, ::Nothing) = mangledvar(op)
33
variable_name(op::Operation, suffix) = Symbol(mangledvar(op), suffix, :_)
44

55
struct TileDescription{T}
6-
u::Int
6+
u::Int32
77
unrolled::Symbol
88
tiled::Symbol
99
suffix::T
@@ -16,18 +16,18 @@ function parentind(ind::Symbol, op::Operation)
1616
end
1717
function symbolind(ind::Symbol, op::Operation, td::TileDescription)
1818
id = parentind(ind, op)
19-
id == -1 && return Expr(:call, :-, ind, 1)
19+
id == -1 && return Expr(:call, :-, ind, one(Int32))
2020
@unpack u, unrolled, tiled, suffix = td
2121
parent = parents(op)[id]
22-
pvar = if loopdependencies(parent) tiled
22+
pvar = if tiled ∈ loopdependencies(parent)
2323
variable_name(parent, suffix)
2424
else
2525
mangledvar(parent)
2626
end
27-
if loopdependencies(parent) unrolled
27+
if unrolled ∈ loopdependencies(parent)
2828
pvar = Symbol(pvar, u)
2929
end
30-
Expr(:call, :-, pvar, 1)
30+
Expr(:call, :-, pvar, one(Int32))
3131
end
3232
function mem_offset(op::Operation, td::TileDescription)
3333
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
@@ -110,7 +110,7 @@ end
110110
# Expr(:call, :+, q, incr)
111111
# end
112112
# end
113-
function varassignname(var::Symbol, u::Int, isunrolled::Bool)
113+
function varassignname(var::Symbol, u::Int32, isunrolled::Bool)
114114
isunrolled ? Symbol(var, u) : var
115115
end
116116
# name_mo only gets called when vectorized
@@ -158,7 +158,7 @@ function lower_load_scalar!(
158158
ptr = refname(op)
159159
isunrolled = unrolled ∈ loopdeps
160160
U = isunrolled ? U : 1
161-
for u ∈ 0:U-1
161+
for u ∈ zero(Int32):Base.unsafe_trunc(Int32,U-1)
162162
varname = varassignname(var, u, isunrolled)
163163
td = TileDescription(u, unrolled, tiled, suffix)
164164
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:load), ptr, mem_offset_u(op, td))))
@@ -172,17 +172,17 @@ function lower_load_vectorized!(
172172
loopdeps = loopdependencies(op)
173173
@assert vectorized ∈ loopdeps
174174
if unrolled ∈ loopdeps
175-
umin = 0
175+
umin = zero(Int32)
176176
U = U
177177
else
178-
umin = -1
178+
umin = -one(Int32)
179179
U = 0
180180
end
181181
# Urange = unrolled ∈ loopdeps ? 0:U-1 : 0
182182
var = variable_name(op, suffix)
183183
vecnotunrolled = vectorized !== unrolled
184184
if first(getindices(op)) === vectorized # vload
185-
for u ∈ umin:U-1
185+
for u ∈ umin:Base.unsafe_trunc(Int32,U-1)
186186
td = TileDescription(u, unrolled, tiled, suffix)
187187
pushvectorload!(q, op, var, td, U, W, mask, vecnotunrolled)
188188
end
@@ -191,7 +191,7 @@ function lower_load_vectorized!(
191191
ustrides = Expr(:call, lv(:vmul), Expr(:call, :stride, refname(op), sn), Expr(:call, lv(:vrange), W))
192192
ustride = gensym(:ustride)
193193
push!(q.args, Expr(:(=), ustride, ustrides))
194-
for u ∈ umin:U-1
194+
for u ∈ umin:Base.unsafe_trunc(Int32,U-1)
195195
td = TileDescription(u, unrolled, tiled, suffix)
196196
pushvectorgather!(q, op, var, td, U, W, mask, ustride, vecnotunrolled)
197197
end
@@ -273,7 +273,7 @@ function lower_store_reduction!(
273273
# need to find out reduction type
274274
instr = first(parents(op)).instruction
275275
reduct_instruct = CORRESPONDING_REDUCTION[instr]
276-
for u ∈ 0:U-1
276+
for u ∈ zero(Int32):Base.unsafe_trunc(Int32,U-1)
277277
reducedname = varassignname(var, u, isunrolled)
278278
storevar = Expr(reduct_instruct, reducedname)
279279
td = TileDescription(u, unrolled, tiled, suffix)
@@ -287,7 +287,7 @@ function lower_store_scalar!(
287287
)
288288
var = pvariable_name(op, suffix)
289289
ptr = refname(op)
290-
for u ∈ 0:U-1
290+
for u ∈ zero(Int32):Base.unsafe_trunc(Int32,U-1)
291291
varname = varassignname(var, u, isunrolled)
292292
td = TileDescription(u, unrolled, tiled, suffix)
293293
push!(q.args, Expr(:call, lv(:store!), ptr, varname, mem_offset_u(op, td)))
@@ -302,16 +302,16 @@ function lower_store_vectorized!(
302302
@assert unrolled loopdeps
303303
var = pvariable_name(op, suffix)
304304
if isunrolled
305-
umin = 0
305+
umin = zero(Int32)
306306
U = U
307307
else
308-
umin = -1
308+
umin = -one(Int32)
309309
U = 0
310310
end
311311
ptr = refname(op)
312312
vecnotunrolled = vectorized !== unrolled
313313
if first(loopdependencies(op)) === vectorized # vstore!
314-
for u ∈ 0:U-1
314+
for u ∈ zero(Int32):Base.unsafe_trunc(Int32,U-1)
315315
td = TileDescription(u, unrolled, tiled, suffix)
316316
name, mo = name_mo(var, op, td, W, vecnotunrolled)
317317
instrcall = Expr(:call,lv(:vstore!), ptr, name, mo)
@@ -323,7 +323,7 @@ function lower_store_vectorized!(
323323
else
324324
sn = findfirst(x -> x === unrolled, loopdependencies(op))::Int
325325
ustrides = Expr(:call, lv(:vmul), Expr(:call, :stride, ptr, sn), Expr(:call, lv(:vrange), W))
326-
for u ∈ 0:U-1
326+
for u ∈ zero(Int32):Base.unsafe_trunc(Int32,U-1)
327327
td = TileDescription(u, unrolled, tiled, suffix)
328328
name, mo = name_mo(var, op, td, W, vecnotunrolled)
329329
instrcall = Expr(:call, lv(:scatter!), ptr, mo, ustrides, name)

test/runtests.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -556,12 +556,6 @@ using LinearAlgebra
556556
end
557557
end
558558

559-
maxdeg = 20
560-
nbasis = 10_000
561-
dim = 10
562-
basis = rand(1:(maxdeg+1), (dim, nbasis))
563-
coeffs = rand(nbasis)
564-
P = rand(dim, maxdeg+1)
565559

566560
function mvp(P, basis, coeffs::Vector{T}) where {T}
567561
len_c = length(coeffs)
@@ -576,7 +570,29 @@ using LinearAlgebra
576570
end
577571
return p
578572
end
579-
573+
function mvpavx(P, basis, coeffs::Vector{T}) where {T}
574+
len_c = length(coeffs)
575+
len_P = size(P, 1)
576+
p = zero(T)
577+
@avx for n = 1:len_c
578+
pn = coeffs[n]
579+
for a = 1:len_P
580+
pn *= P[a, basis[a, n]]
581+
end
582+
p += pn
583+
end
584+
return p
585+
end
586+
maxdeg = 20; nbasis = 10_000; dim = 10;
587+
bq = :(for n = 1:len_c
588+
pn = coeffs[n]
589+
for a = 1:len_P
590+
pn *= P[a, basis[a, n]]
591+
end
592+
p += pn
593+
end)
594+
lsb = LoopVectorization.LoopSet(bq)
595+
580596
for T ∈ (Float32, Float64)
581597
@show T, @__LINE__
582598
A = randn(T, 199, 498);
@@ -601,6 +617,11 @@ using LinearAlgebra
601617
x = rand(T, M); A = rand(T, M, N); y = rand(T, N);
602618
@test dot3avx(x, A, y) ≈ dot3(x, A, y)
603619

620+
r = T == Float32 ? (Int32(1):Int32(maxdeg+1)) : (1:maxdeg+1)
621+
basis = rand(r, (dim, nbasis));
622+
coeffs = rand(T, nbasis);
623+
P = rand(T, dim, maxdeg+1);
624+
@test mvp(P, basis, coeffs) ≈ mvpavx(P, basis, coeffs)
604625
end
605626
end
606627

0 commit comments

Comments
 (0)