Skip to content

Commit 6928df5

Browse files
committed
Added more tests, and made sure to define outer reductions after defining W.
1 parent 2ab11e0 commit 6928df5

File tree

5 files changed

+111
-45
lines changed

5 files changed

+111
-45
lines changed

src/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ end
130130
function add_broadcast!(
131131
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int = 8
132132
) where {T<:Union{Integer,Float32,Float64}}
133-
pushpreamble!(ls, Expr(:(=), destname, bcname))
133+
pushpreamble!(ls, Expr(:(=), Symbol("##", destname), bcname))
134134
add_constant!(ls, destname, elementbytes) # or replace elementbytes with sizeof(T) ?
135135
end
136136
function add_broadcast!(

src/determinestrategy.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,22 @@ function parentsnotreduction(op::Operation)
110110
end
111111
return true
112112
end
113+
"""
    unroll_no_reductions(ls, order, vectorized, Wshift, size_T)

Heuristically pick an unroll factor for a loop nest that has no reductions.

Only operations of `ls` that depend on the innermost loop (`last(order)`) are
considered. The first component of `cost(op, vectorized, Wshift, size_T)` is summed
separately for compute ops and for load ops, and the result is
`round(Int, (compute + load + 1) / compute)`.

If no compute cost is accumulated (for example, the innermost loop only loads and
stores), the ratio is undefined, so `1` (no unrolling) is returned as a conservative
fallback instead of throwing.
"""
function unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
    innermost = last(order)
    compute_rt = 0.0
    load_rt = 0.0
    # latency not a concern, because no depchains
    for op ∈ operations(ls)
        dependson(op, innermost) || continue
        if iscompute(op)
            compute_rt += first(cost(op, vectorized, Wshift, size_T))
        elseif isload(op)
            load_rt += first(cost(op, vectorized, Wshift, size_T))
        end
    end
    # With zero compute cost the ratio below would be Inf (or NaN), and
    # `round(Int, Inf)` throws an InexactError; fall back to not unrolling.
    compute_rt > 0.0 || return 1
    # heuristic guess
    return round(Int, (compute_rt + load_rt + 1) / compute_rt)
end
113129
function determine_unroll_factor(
114130
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, vectorized::Symbol = first(order)
115131
)
@@ -124,9 +140,10 @@ function determine_unroll_factor(
124140
num_reductions += 1
125141
end
126142
end
127-
# @show num_reductions
128-
if iszero(num_reductions) # the 4 is a hack, based on the idea that there is some cost to moving through columns
129-
return length(order) == 1 ? 1 : 4
143+
if iszero(num_reductions)
144+
# if only 1 loop, no need to unroll
145+
# if more than 1 loop, there is some cost; estimate the unroll factor heuristically via unroll_no_reductions.
146+
return length(order) == 1 ? 1 : unroll_no_reductions(ls, order, vectorized, Wshift, size_T)
130147
end
131148
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
132149
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
@@ -410,6 +427,7 @@ function choose_tile(ls::LoopSet)
410427
new_order, state = iter
411428
end
412429
end
430+
# Last in order is the inner most loop
413431
function choose_order(ls::LoopSet)
414432
if num_loops(ls) > 1
415433
torder, tvec, tU, tT, tc = choose_tile(ls)

src/graphs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ function add_constant!(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
367367
end
368368
function add_constant!(ls::LoopSet, var, elementbytes::Int = 8)
369369
sym = gensym(:temp)
370-
pushpreamble!(ls, Expr(:(=), sym, var))
370+
pushpreamble!(ls, Expr(:(=), Symbol("##", sym), var))
371371
pushop!(ls, Operation(length(operations(ls)), sym, elementbytes, Symbol("##CONSTANT##"), constant, NODEPENDENCY, Symbol[], NOPARENTS), sym)
372372
end
373373
# This version has loop dependencies. var gets assigned to sym when lowering.

src/lowering.jl

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -828,14 +828,10 @@ function definemask(loop::Loop, W::Symbol, allon::Bool)
828828
maskexpr(W, loop.rangesym, allon)
829829
end
830830
end
831-
function setup_mainblock(ls::LoopSet, W::Symbol, typeT::Symbol, vectorized::Symbol, unrolled::Symbol, U::Int, q::Expr)
832-
preambleW = Expr(
833-
:block,
834-
Expr(:(=), typeT, determine_eltype(ls)),
835-
Expr(:(=), W, determine_width(ls, typeT, unrolled)),
836-
definemask(ls.loops[vectorized], W, U > 1 && unrolled === vectorized)
837-
)
838-
Expr(:block, ls.preamble, preambleW, q)
831+
"""
    setup_Wmask!(ls, W, typeT, vectorized, unrolled, U)

Append three expressions to the preamble of `ls`, in this order:

1. `typeT = determine_eltype(ls)`
2. `W = determine_width(ls, typeT, unrolled)`
3. the expression returned by `definemask` for the vectorized loop, with its
   `allon` argument set when `U > 1` and the unrolled loop is the vectorized loop.
"""
function setup_Wmask!(ls::LoopSet, W::Symbol, typeT::Symbol, vectorized::Symbol, unrolled::Symbol, U::Int)
    eltype_def = Expr(:(=), typeT, determine_eltype(ls))
    pushpreamble!(ls, eltype_def)
    width_def = Expr(:(=), W, determine_width(ls, typeT, unrolled))
    pushpreamble!(ls, width_def)
    vloop = ls.loops[vectorized]
    allon = U > 1 && unrolled === vectorized
    pushpreamble!(ls, definemask(vloop, W, allon))
end
840836
function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
841837
order = ls.loop_order.loopnames
@@ -844,6 +840,7 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
844840
mangledtiled = tiledsym(tiled)
845841
W = gensym(:W)
846842
typeT = gensym(:T)
843+
setup_Wmask!(ls, W, typeT, vectorized, unrolled, U)
847844
# W = VectorizationBase.pick_vector_width(ls, unrolled)
848845
tiledloop = ls.loops[tiled]
849846
static_tile = tiledloop.hintexact
@@ -900,7 +897,7 @@ function lower_tiled(ls::LoopSet, vectorized::Symbol, U::Int, T::Int)
900897
end
901898
q = gc_preserve(ls, q)
902899
reduce_expr!(q, ls, U)
903-
setup_mainblock(ls, W, typeT, vectorized, unrolled, U, q)
900+
Expr(:block, ls.preamble, q)
904901
end
905902
function lower_unrolled(ls::LoopSet, vectorized::Symbol, U::Int)
906903
order = ls.loop_order.loopnames
@@ -909,8 +906,9 @@ function lower_unrolled(ls::LoopSet, vectorized::Symbol, U::Int)
909906
# W = VectorizationBase.pick_vector_width(ls, unrolled)
910907
W = gensym(:W)
911908
typeT = gensym(:T)
909+
setup_Wmask!(ls, W, typeT, vectorized, unrolled, U)
912910
q = lower_unrolled!(Expr(:block, Expr(:(=), unrolled, 0)), ls, vectorized, U, -1, W, typeT, ls.loops[unrolled])
913-
setup_mainblock(ls, W, typeT, vectorized, unrolled, U, q)
911+
Expr(:block, ls.preamble, q)
914912
end
915913

916914

test/runtests.jl

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -36,28 +36,27 @@ using LinearAlgebra
3636
@test logsumexp!(r, x) 102.35216846104409
3737

3838
@testset "GEMM" begin
39-
gemmq = :(for i 1:size(A,1), j 1:size(B,2)
39+
AmulBq = :(for i 1:size(A,1), j 1:size(B,2)
4040
Cᵢⱼ = zero(eltype(C))
4141
for k 1:size(A,2)
4242
Cᵢⱼ += A[i,k] * B[k,j]
4343
end
4444
C[i,j] = Cᵢⱼ
4545
end)
4646

47-
lsgemm = LoopVectorization.LoopSet(gemmq);
47+
lsAmulB = LoopVectorization.LoopSet(AmulBq);
4848
U, T = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3,4) : (6, 4)
49-
@test LoopVectorization.choose_order(lsgemm) == (Symbol[:j,:i,:k], :i, U, T)
49+
@test LoopVectorization.choose_order(lsAmulB) == (Symbol[:j,:i,:k], :i, U, T)
5050

51-
function mygemm!(C, A, B)
52-
@inbounds for i 1:size(A,1), j 1:size(B,2)
53-
Cᵢⱼ = zero(eltype(C))
54-
@simd ivdep for k 1:size(A,2)
55-
Cᵢⱼ += A[i,k] * B[k,j]
51+
function AmulB!(C, A, B)
52+
C .= 0
53+
for k 1:size(A,2), j 1:size(B,2)
54+
@simd ivdep for i 1:size(A,1)
55+
@inbounds C[i,j] += A[i,k] * B[k,j]
5656
end
57-
C[i,j] = Cᵢⱼ
5857
end
5958
end
60-
function mygemmavx!(C, A, B)
59+
function AmulBavx!(C, A, B)
6160
@avx for i 1:size(A,1), j 1:size(B,2)
6261
Cᵢⱼ = zero(eltype(C))
6362
for k 1:size(A,2)
@@ -67,13 +66,46 @@ using LinearAlgebra
6766
end
6867
end
6968

70-
69+
# function AtmulB!(C, A, B)
70+
# for j ∈ 1:size(C,2), i ∈ 1:size(C,1)
71+
# Cᵢⱼ = zero(eltype(C))
72+
# @simd ivdep for k ∈ 1:size(A,1)
73+
# @inbounds Cᵢⱼ += A[k,i] * B[k,j]
74+
# end
75+
# C[i,j] = Cᵢⱼ
76+
# end
77+
# end
78+
AtmulBq = :(for j 1:size(C,2), i 1:size(C,1)
79+
Cᵢⱼ = zero(eltype(C))
80+
for k 1:size(A,1)
81+
Cᵢⱼ += A[k,i] * B[k,j]
82+
end
83+
C[i,j] = Cᵢⱼ
84+
end)
85+
lsAtmulB = LoopVectorization.LoopSet(AtmulBq);
86+
# LoopVectorization.choose_order(lsAtmulB)
87+
@test LoopVectorization.choose_order(lsAtmulB) == (Symbol[:j,:i,:k], :k, U, T)
88+
89+
# Test kernel: writes C[i,j] = sum over k of A[k,i] * B[k,j] using `@avx` loops,
# i.e. the first matrix is indexed as its transpose.
function AtmulBavx!(C, A, B)
    @avx for j ∈ 1:size(C,2), i ∈ 1:size(C,1)
        Cᵢⱼ = zero(eltype(C))
        for k ∈ 1:size(A,1)
            Cᵢⱼ += A[k,i] * B[k,j]
        end
        C[i,j] = Cᵢⱼ
    end
end
98+
7199
for T (Float32, Float64)
72100
M, K, N = 72, 75, 71;
73101
C = Matrix{T}(undef, M, N); A = randn(T, M, K); B = randn(T, K, N);
74102
C2 = similar(C);
75-
mygemmavx!(C, A, B)
76-
mygemm!(C2, A, B)
103+
AmulBavx!(C, A, B)
104+
AmulB!(C2, A, B)
105+
@test C ≈ C2
106+
At = copy(A');
107+
fill!(C, 9999.999);
108+
AtmulBavx!(C, At, B)
77109
@test C ≈ C2
78110
end
79111
end
@@ -178,6 +210,7 @@ using LinearAlgebra
178210
myvexpavx!(b2, a)
179211
@test b1 b2
180212
@test myvexp(a) myvexpavx(a)
213+
@test b1 ≈ @avx exp.(a)
181214
end
182215
end
183216

@@ -225,6 +258,23 @@ using LinearAlgebra
225258

226259

227260
@testset "Miscellaneous" begin
261+
262+
dot3q = :(for m 1:M, n 1:N
263+
s += x[m] * A[m,n] * y[n]
264+
end)
265+
lsdot3 = LoopVectorization.LoopSet(dot3q);
266+
LoopVectorization.choose_order(lsdot3)
267+
268+
dot3(x, A, y) = dot(x, A * y)
269+
# Test kernel: accumulates x[m] * A[m,n] * y[n] over all (m, n) with an `@avx`
# loop nest and returns the sum. The accumulator starts at zero of the promoted
# element type of the three inputs.
function dot3avx(x, A, y)
    M, N = size(A)
    s = zero(promote_type(eltype(x), eltype(A), eltype(y)))
    @avx for m ∈ 1:M, n ∈ 1:N
        s += x[m] * A[m,n] * y[n]
    end
    return s
end
277+
228278
subcolq = :(for i 1:size(A,2), j eachindex(x)
229279
B[j,i] = A[j,i] - x[j]
230280
end)
@@ -246,25 +296,25 @@ using LinearAlgebra
246296
end
247297

248298
colsumq = :(for i 1:size(A,2), j eachindex(x)
249-
x[j] += A[j,i]
299+
x[j] += A[j,i] - 0.25
250300
end)
251301
lscolsum = LoopVectorization.LoopSet(colsumq);
252-
@test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, 4, -1)
302+
@test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, 8, -1)
253303

304+
# my colsum is wrong (by 0.25), but slightly more interesting
254305
function mycolsum!(x, A)
255306
@. x = 0
256307
@inbounds for i 1:size(A,2)
257308
@simd for j eachindex(x)
258-
x[j] += A[j,i]
309+
x[j] += A[j,i] - 0.25
259310
end
260311
end
261312
end
262-
263313
function mycolsumavx!(x, A)
264314
@avx for j eachindex(x)
265315
xⱼ = zero(eltype(x))
266316
for i 1:size(A,2)
267-
xⱼ += A[j,i]
317+
xⱼ += A[j,i] - 0.25
268318
end
269319
x[j] = xⱼ
270320
end
@@ -300,8 +350,8 @@ using LinearAlgebra
300350
end
301351

302352
for T (Float32, Float64)
303-
A = randn(T, 199, 498)
304-
x = randn(T, size(A,1))
353+
A = randn(T, 199, 498);
354+
x = randn(T, size(A,1));
305355
B1 = similar(A); B2 = similar(A);
306356

307357
mysubcol!(B1, A, x)
@@ -311,13 +361,17 @@ using LinearAlgebra
311361
x1 = similar(x); x2 = similar(x);
312362
mycolsum!(x1, A)
313363
mycolsumavx!(x2, A)
314-
315364
@test x1 x2
316365

317366
x̄ = x1 ./ size(A,2);
318367
myvar!(x1, A, x̄)
319368
myvaravx!(x2, A, x̄)
320369
@test x1 x2
370+
371+
M, N = 47, 73;
372+
x = rand(T, M); A = rand(T, M, N); y = rand(T, N);
373+
@test dot3avx(x, A, y) ≈ dot3(x, A, y)
374+
321375
end
322376
end
323377

@@ -351,26 +405,17 @@ end
351405
@avx @. d4 = a + B c;
352406
@test d3 d4
353407

354-
# T = Float64
355-
# T = Float32
356408
M, K, N = 77, 83, 57;
357409
A = rand(T,M,K); B = rand(T,K,N); C = rand(T,M,N);
358410

359411
D1 = C .+ A * B;
360412
D2 = @avx C .+ A B;
361413
@test D1 D2
362414

363-
B = rand(T,K,N);
364415
D3 = exp.(B');
365416
D4 = @avx exp.(B');
366417
@test D3 D4
367418

368-
d4q = @avx exp.(B')
369-
370-
D3c = copy(D3); D4c = copy(D4);
371-
D3[1:21,1:10]
372-
D4[1:21,1:10]
373-
374419
fill!(D3, -1e3); fill!(D4, 9e9);
375420
Bt = Transpose(B);
376421
@. D3 = exp(Bt);
@@ -390,6 +435,11 @@ end
390435
lset = @avx @. D2' = exp(Bt);
391436

392437
@test D1 D2
438+
439+
a = rand(137);
440+
b1 = @avx @. 3*a + sin(a) + sqrt(a);
441+
b2 = @. 3*a + sin(a) + sqrt(a);
442+
@test b1 ≈ b2
393443
end
394444
end
395445

0 commit comments

Comments
 (0)