Skip to content

Commit bb75698

Browse files
committed
Series of updates to better support a few more complicated loops.
1 parent d0627d3 commit bb75698

File tree

3 files changed

+35
-28
lines changed

3 files changed

+35
-28
lines changed

src/graphs.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ end
428428
function add_load!(
429429
ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int = 8
430430
)
431+
length(mpref.loopdependencies) == 0 && return add_constant!(ls, var, mpref, elementbytes)
431432
ref = mpref.mref.ref
432433
# try to CSE
433434
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
@@ -518,6 +519,12 @@ function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
518519
pushpreamble!(ls, Expr(:(=), mangledvar(op), var))
519520
pushop!(ls, op, sym)
520521
end
522+
function add_constant!(ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int)
523+
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS, mpref.mref)
524+
add_vptr!(ls, op)
525+
pushpreamble!(ls, Expr(:(=), mangledvar(op), Expr(:call, lv(:load), mpref.mref.ptr, mem_offset(op, TileDescription(zero(Int32), Symbol(""), Symbol(""), nothing)))))
526+
pushop!(ls, op, var)
527+
end
521528
# This version has loop dependencies. var gets assigned to sym when lowering.
522529
function add_constant!(ls::LoopSet, var::Symbol, deps::Vector{Symbol}, sym::Symbol = gensym(:constant), f::Symbol = Symbol(""), elementbytes::Int = 8)
523530
# length(deps) == 0 && push!(ls.preamble.args, Expr(:(=), sym, var))
@@ -533,7 +540,7 @@ end
533540
function pushparent!(parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, parent::Operation)
534541
push!(parents, parent)
535542
mergesetdiffv!(deps, loopdependencies(parent), reduceddependencies(parent))
536-
if !(isload(parent) || isconstant(parent))
543+
if !(isload(parent) || isconstant(parent)) && parent.instruction.instr (:reduced_add, :reduced_prod, :reduce_to_add, :reduce_to_prod)
537544
mergesetv!(reduceddeps, reduceddependencies(parent))
538545
end
539546
nothing
@@ -585,8 +592,12 @@ function add_reduction_update_parent!(
585592
reductcombine = Symbol("")
586593
end
587594
# mergesetv!(reduceddeps, deps)
588-
setdiffv!(reduceddeps, deps, loopdependencies(reductinit))
589-
mergesetv!(reduceddependencies(reductinit), reduceddeps)
595+
if length(reduceddependencies(reductinit)) == 0
596+
setdiffv!(reduceddeps, deps, loopdependencies(reductinit))
597+
else
598+
setdiffv!(reduceddeps, deps, loopdependencies(reductinit))
599+
end
600+
# mergesetv!(reduceddependencies(reductinit), reduceddeps)
590601
pushparent!(parents, deps, reduceddeps, reductinit)#parent) # deps and reduced deps will not be disjoint
591602
op = Operation(length(operations(ls)), reductsym, elementbytes, instr, compute, deps, reduceddeps, parents)
592603
parent.instruction === LOOPCONSTANT && push!(ls.outer_reductions, identifier(op))

src/lowering.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ function symbolind(ind::Symbol, op::Operation, td::TileDescription)
3030
Expr(:call, :-, pvar, one(Int32))
3131
end
3232
function mem_offset(op::Operation, td::TileDescription)
33-
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
33+
# @assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
3434
ret = Expr(:tuple)
3535
indices = getindices(op)
3636
loopedindex = op.ref.loopedindex
@@ -146,7 +146,7 @@ function lower_load_scalar!(
146146
for u zero(Int32):Base.unsafe_trunc(Int32,U-1)
147147
varname = varassignname(var, u, isunrolled)
148148
td = TileDescription(u, unrolled, tiled, suffix)
149-
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:load), ptr, mem_offset_u(op, td))))
149+
push!(q.args, Expr(:(=), varname, Expr(:call, lv(:load), ptr, mem_offset_u(op, td))))
150150
end
151151
nothing
152152
end
@@ -512,11 +512,7 @@ function lower_nest(
512512
blockq = Expr(:block)
513513
if n > 1
514514
looptoadd = order[n-1]
515-
if looptoadd === vectorized
516-
push!(blockq.args, Expr(:(=), looptoadd, Expr(:call, lv(:_MM), W, loopstart)))
517-
else
518-
push!(blockq.args, Expr(:(=), looptoadd, loopstart))
519-
end
515+
push!(blockq.args, startloop(ls.loops[looptoadd], looptoadd === vectorized, W, looptoadd))
520516
end
521517
loopq = if exprtype === :block
522518
blockq
@@ -875,11 +871,7 @@ function lower_unrolled(ls::LoopSet, vectorized::Symbol, U::Int)
875871
W = ls.W
876872
typeT = ls.T
877873
setup_Wmask!(ls, W, typeT, vectorized, unrolled, last(order), U)
878-
initunrolledcounter = if unrolled === vectorized
879-
Expr(:(=), unrolled, Expr(:call, lv(:_MM), W, 0))
880-
else
881-
Expr(:(=), unrolled, 0)
882-
end
874+
initunrolledcounter = startloop(ls.loops[unrolled], unrolled === vectorized, W, unrolled)
883875
q = lower_unrolled!(Expr(:block, initunrolledcounter), ls, vectorized, U, -1, W, typeT, ls.loops[unrolled])
884876
lsexpr(ls, q)
885877
end

test/runtests.jl

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ using LinearAlgebra
4040
@test logsumexp!(r, x) 102.35216846104409
4141

4242
@testset "GEMM" begin
43-
using LoopVectorization, Test
43+
using LoopVectorization, Test; T = Float64
4444
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
4545
AmulBq1 = :(for m 1:size(A,1), n 1:size(B,2)
4646
C[m,n] = zeroB
@@ -209,7 +209,7 @@ using LinearAlgebra
209209
C12 += A[k,m] * B[k,n1]
210210
C22 += A[k,m1] * B[k,n1]
211211
end)
212-
lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
212+
# lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
213213

214214
function toy1!(G, B,κ)
215215
d = size(G,1)
@@ -220,15 +220,6 @@ lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
220220
end
221221
end
222222
end
223-
# function toy4!(G, B,κ)
224-
# d = size(G,1)
225-
# @avx for d1=1:d
226-
# G[d1,κ] = B[1,d1]*B[1,κ]
227-
# for d2=2:d
228-
# G[d1,κ] += B[d2,d1]*B[d2,κ]
229-
# end
230-
# end
231-
# end
232223
# tq1 = :(for d1=1:d
233224
# G[d1,κ] = B[1,d1]*B[1,κ]
234225
# for d2=2:d
@@ -271,6 +262,15 @@ lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
271262
G[d1,κ] = z
272263
end);
273264
lst3 = LoopVectorization.LoopSet(tq3)
265+
function toy4!(G, B,κ)
266+
d = size(G,1)
267+
@avx for d1=1:d
268+
G[d1,κ] = B[1,d1]*B[1,κ]
269+
for d2=2:d
270+
G[d1,κ] += B[d2,d1]*B[d2,κ]
271+
end
272+
end
273+
end
274274

275275
for T (Float32, Float64, Int32, Int64)
276276
@show T, @__LINE__
@@ -312,6 +312,8 @@ lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
312312
@test G1 G2
313313
fill!(G2, TC(NaN)), toy3!(G2,B,1);
314314
@test G1 G2
315+
fill!(G2, TC(NaN)), toy4!(G2,B,1);
316+
@test G1 G2
315317
# fill!(G2, TC(NaN)), toy4!(G2,B,1);
316318
# @test G1 ≈ G2
317319
end
@@ -448,7 +450,7 @@ lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
448450
y[i] = yᵢ
449451
end)
450452
lsgemv = LoopVectorization.LoopSet(gemvq);
451-
@test LoopVectorization.choose_order(lsgemv) == (Symbol[:i, :j], :i, 8, -1)
453+
@test LoopVectorization.choose_order(lsgemv) == (Symbol[:i, :j], :i, 4, -1)
452454

453455
function mygemv!(y, A, x)
454456
@inbounds for i eachindex(y)
@@ -645,7 +647,9 @@ lsmul2x2q = LoopVectorization.LoopSet(mul2x2q)
645647
basis = rand(r, (dim, nbasis));
646648
coeffs = rand(T, nbasis);
647649
P = rand(T, dim, maxdeg+1);
648-
@test_broken mvp(P, basis, coeffs) mvpavx(P, basis, coeffs)
650+
mvp(P, basis, coeffs)
651+
mvpavx(P, basis, coeffs)
652+
@test mvp(P, basis, coeffs) mvpavx(P, basis, coeffs)
649653
end
650654
end
651655

0 commit comments

Comments
 (0)