
Commit f98363b

Improved cost modeling of loops, basing cost on getindex instead of setindex!
1 parent f99621d commit f98363b

7 files changed: +162, -18 lines changed

README.md

Lines changed: 127 additions & 3 deletions
@@ -18,8 +18,132 @@ Pkg.add(PackageSpec(url="https://github.com/chriselrod/LoopVectorization.jl"))
## Usage

-The current version of LoopVectorization provides a simple, dumb, transform on a single loop.
-What I mean by this is that it will not check the transformations for validity. To be safe, I would stick to straight loops that transform arrays or calculate reductions.

This library provides the `@avx` macro, which may be used to prefix a `for` loop or broadcast statement.
It then tries to vectorize the loop to improve runtime performance.

The macro assumes that loop iterations can be reordered. It currently supports simple nested loops, where the loop bounds of inner loops are constant across iterations of the outer loop, and where there is only a single loop at each level of the loop nest. These limitations should be removed in a future version.
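For instance, here is a loop whose iterations cannot be reordered, and to which `@avx` therefore must not be applied (a hedged illustration of the caveat above, not an example from the package):
```julia
# In-place cumulative sum: iteration i reads x[i-1], which iteration
# i-1 just wrote, so the iterations cannot be reordered.
function cumsum_inplace!(x)
    for i ∈ 2:length(x)
        x[i] += x[i-1]
    end
    x
end
```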
A simple example with a single loop is the dot product:
```julia
using LoopVectorization, BenchmarkTools
function mydot(a, b)
    s = 0.0
    @inbounds @simd for i ∈ eachindex(a, b)
        s += a[i] * b[i]
    end
    s
end
function mydotavx(a, b)
    s = 0.0
    @avx for i ∈ eachindex(a, b)
        s += a[i] * b[i]
    end
    s
end
a = rand(256); b = rand(256);
@btime mydot($a, $b)
@btime mydotavx($a, $b)
a = rand(43); b = rand(43);
@btime mydot($a, $b)
@btime mydotavx($a, $b)
```
On most recent CPUs, the performance of the dot product is bounded by the speed at which it can load data; most recent x86_64 CPUs can perform two aligned loads and two fused multiply-adds (`fma`) per clock cycle. However, the dot product requires two loads per `fma`.
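Concretely, under the assumed two-loads-plus-two-`fma`s-per-cycle budget, two loads per `fma` cap the dot product at one `fma` per cycle, half of the machine's `fma` peak. A sketch of that arithmetic:
```julia
# Back-of-the-envelope throughput bound (assumes the figures above).
loads_per_cycle = 2   # aligned loads the core can issue per cycle
fmas_per_cycle  = 2   # fused multiply-adds the core can issue per cycle
loads_per_fma   = 2   # dot product needs a[i] and b[i] per fma
min(fmas_per_cycle, loads_per_cycle ÷ loads_per_fma)  # == 1 fma/cycle
```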
A self-dot function, on the other hand, requires only one load per `fma`:
```julia
function myselfdot(a)
    s = 0.0
    @inbounds @simd for i ∈ eachindex(a)
        s += a[i] * a[i]
    end
    s
end
function myselfdotavx(a)
    s = 0.0
    @avx for i ∈ eachindex(a)
        s += a[i] * a[i]
    end
    s
end
a = rand(256);
@btime myselfdotavx($a)
@btime myselfdot($a)
@btime myselfdotavx($b)
@btime myselfdot($b)
```
For this reason, the `@avx` self-dot is roughly twice as fast as the `@avx` dot product. The `@inbounds @simd` version, however, is not, because it runs into the problem of loop-carried dependencies: to add `a[i]*b[i]` to `s`, we must first have finished the previous update `s = s + a[i-1]*b[i-1]`, yet while two `fma` instructions can be initiated per cycle, each takes several clock cycles to complete. To hide that latency, we need to unroll the operation and run several independent accumulators concurrently. The `@avx` macro models this cost to try to pick an optimal unroll factor.
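To illustrate the idea, here is a hand-unrolled dot product with four independent accumulators; this is a sketch of the technique only, not the code `@avx` actually emits, and the factor of four is an arbitrary choice:
```julia
function mydot_unrolled(a, b)
    s1 = s2 = s3 = s4 = 0.0  # four independent dependency chains
    i = 1
    @inbounds while i + 3 <= length(a)
        s1 += a[i]   * b[i]
        s2 += a[i+1] * b[i+1]
        s3 += a[i+2] * b[i+2]
        s4 += a[i+3] * b[i+3]
        i += 4
    end
    @inbounds while i <= length(a)  # remainder loop
        s1 += a[i] * b[i]
        i += 1
    end
    (s1 + s2) + (s3 + s4)
end
```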
Note that 14 nm and 12 nm Ryzen chips can only perform one full-width `fma` per clock cycle (along with two loads), so they should see similar performance on the dot product and the self-dot. I haven't verified this, but would like to hear from anyone who can.
We can also vectorize fancier loops. A likely familiar example to dive into:
```julia
function mygemm!(C, A, B)
    @inbounds for i ∈ 1:size(A, 1), j ∈ 1:size(B, 2)
        Cᵢⱼ = 0.0
        @fastmath for k ∈ 1:size(A, 2)
            Cᵢⱼ += A[i, k] * B[k, j]
        end
        C[i, j] = Cᵢⱼ
    end
end
function mygemmavx!(C, A, B)
    @avx for i ∈ 1:size(A, 1), j ∈ 1:size(B, 2)
        Cᵢⱼ = 0.0
        for k ∈ 1:size(A, 2)
            Cᵢⱼ += A[i, k] * B[k, j]
        end
        C[i, j] = Cᵢⱼ
    end
end
M, K, N = 72, 75, 71;
C1 = Matrix{Float64}(undef, M, N); A = randn(M, K); B = randn(K, N);
C2 = similar(C1); C3 = similar(C1);
@btime mygemmavx!($C1, $A, $B)
@btime mygemm!($C2, $A, $B)
using LinearAlgebra, Test
@test all(C1 .≈ C2)
BLAS.set_num_threads(1); BLAS.vendor()
@btime mul!($C3, $A, $B)
@test all(C1 .≈ C3)
```
The `@avx` macro can produce a decent macro kernel. In the future, I would like it to also model the cost of memory movement in the L1 and L2 caches, and to use that model to generate loops around the macro kernel, following the work of [Low, et al. (2016)](http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).

Until then, performance will degrade rapidly relative to BLAS as the size of the matrices increases. The advantage of the `@avx` macro, however, is that it is general: not every operation is supported by BLAS.
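For a rough idea of what those generated loops could look like, here is a hedged sketch of BLIS-style blocking; the tile sizes `Mc`, `Kc`, `Nc` and the `macrokernel!` helper (assumed to accumulate `C[ir, jr] .+= A[ir, pr] * B[pr, jr]`) are hypothetical placeholders, not part of this package:
```julia
# Sketch only: loop over cache-sized tiles, delegating each block-panel
# product to a macro kernel (e.g. an @avx loop nest like mygemmavx!,
# modified to accumulate into C instead of overwriting it).
function blocked_gemm!(C, A, B; Mc = 96, Kc = 256, Nc = 512)
    fill!(C, zero(eltype(C)))
    M, K, N = size(A, 1), size(A, 2), size(B, 2)
    for jc ∈ 1:Nc:N, pc ∈ 1:Kc:K, ic ∈ 1:Mc:M
        ir = ic:min(ic + Mc - 1, M)
        jr = jc:min(jc + Nc - 1, N)
        pr = pc:min(pc + Kc - 1, K)
        macrokernel!(view(C, ir, jr), view(A, ir, pr), view(B, pr, jr))
    end
    C
end
```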
For example, what if `A` were the outer product of two vectors?
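The commit left this example's code block empty; as a hedged guess at what was intended, if `A = u * v'`, the outer product can be fused directly into the multiply so that `A` is never materialized (`mygemm_outer_avx!` is an illustrative name, not from the package):
```julia
# C[i,j] = sum over k of u[i]*v[k]*B[k,j], i.e. (u*v')*B without forming u*v'.
function mygemm_outer_avx!(C, u, v, B)
    @avx for i ∈ 1:length(u), j ∈ 1:size(B, 2)
        Cᵢⱼ = 0.0
        for k ∈ 1:size(B, 1)
            Cᵢⱼ += u[i] * v[k] * B[k, j]
        end
        C[i, j] = Cᵢⱼ
    end
end
```
This mirrors `mygemmavx!` with the load `A[i, k]` replaced by `u[i] * v[k]`; BLAS has no primitive for this fused operation.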
Another example, a straightforward operation expressed well via broadcasting:
```julia
a = rand(37); B = rand(37, 47); c = rand(47); c′ = c';

d1 = @. a + B * c′;
d2 = @avx @. a + B * c′;

@test all(d1 .≈ d2)

@btime @. $d1 = $a + $B * $c′;
@btime @avx @. $d2 = $a + $B * $c′;
@test all(d1 .≈ d2)
```
This can be optimized in a manner similar to BLAS, albeit to a much smaller degree, because the naive broadcast already benefits from vectorization (unlike the naive `mygemm!` loops).
Originally, LoopVectorization only provided a simple, dumb transform of a single loop via the `@vectorize` macro. This transformation took element-type and unroll-factor arguments and performed no analysis of the loop, simply applying the specified arguments. For backwards compatibility, this macro is still supported, but it may eventually be deprecated.

For example,

@@ -33,7 +157,7 @@ end
 using LoopVectorization, BenchmarkTools
 function sum_loopvec(x::AbstractVector{Float64})
     s = 0.0
-    @vvectorize 4 for i ∈ eachindex(x)
+    @vectorize 4 for i ∈ eachindex(x)
         s += x[i]
     end
     s

src/costs.jl

Lines changed: 2 additions & 2 deletions
@@ -63,8 +63,8 @@ const OPAQUE_INSTRUCTION = InstructionCost(50, 50.0, -1.0, VectorizationBase.REG
 # consolidated into a single register. The number of LICM-ed setindex!, on the other
 # hand, should indicate how many registers we're keeping live for the sake of eventually storing.
 const COST = Dict{Symbol,InstructionCost}(
-    :getindex => InstructionCost(-3.0,0.5,3,0),
-    :setindex! => InstructionCost(-3.0,1.0,3,1),
+    :getindex => InstructionCost(-3.0,0.5,3,1),
+    :setindex! => InstructionCost(-3.0,1.0,3,0),
     :zero => InstructionCost(1,0.5),
     :one => InstructionCost(3,0.5),
     :(+) => InstructionCost(4,0.5),

src/determinestrategy.jl

Lines changed: 14 additions & 13 deletions
@@ -4,7 +4,7 @@
 unitstride(op, s) = first(loopdependencies(op)) === s

 function cost(op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int = op.elementbytes)
-    isconstant(op) && return 0.0, 0, 0
+    isconstant(op) && return 0.0, 0, 1
     # Wshift == dependson(op, unrolled) ? Wshift : 0
     # c = first(cost(instruction(op), Wshift, size_T))::Int
     instr = instruction(op)
@@ -71,7 +71,6 @@ function evaluate_cost_unroll(
     for (id,op) ∈ enumerate(operations(ls))
         # won't define if already defined...
         # id = identifier(op)
-        isconstant(op) && continue
         included_vars[id] && continue
         # it must also be a subset of defined symbols
         loopdependencies(op) ⊆ nested_loop_syms || continue
@@ -193,18 +192,20 @@ function solve_tilesize(X, R)
             U, T = Ulow, Thigh
         end
     end
-    # @show Uhigh*Tlow*R[1] + Uhigh*R[2]
-    if RR ≥ Uhigh*Tlow*R[1] + Uhigh*R[2]
-        tcost_temp = tile_cost(X, Uhigh, Tlow)
-        if tcost_temp < tcost
-            tcost = tcost_temp
-            U, T = Uhigh, Tlow
-        end
+    # The RR + 1 is a hack to get it to favor Uhigh in more scenarios
+    Tl = Tlow
+    while RR < Uhigh*Tl*R[1] + Uhigh*R[2]
+        Tl -= 1
+    end
+    tcost_temp = tile_cost(X, Uhigh, Tl)
+    if tcost_temp < tcost
+        tcost = tcost_temp
+        U, T = Uhigh, Tl
     end
     if RR > Uhigh*Thigh*R[1] + Uhigh*R[2]
         throw("Something went wrong when solving for Tfloat and Ufloat.")
     end
-    U, T, tcost
+    min(U,RR), min(T,RR), tcost
 end
 function solve_tilesize_constU(X, R, U)
     floor(Int, (VectorizationBase.REGISTER_COUNT - R[3] - R[4] - U*R[2]) / (U * R[1]))
@@ -258,8 +259,8 @@ function evaluate_cost_tile(
     # @show order
     cost_vec = zeros(Float64, 4)
     reg_pressure = zeros(Int, 4)
-    @inbounds reg_pressure[2] = 1
-    @inbounds reg_pressure[3] = 1
+    # @inbounds reg_pressure[2] = 1
+    # @inbounds reg_pressure[3] = 1
     for n ∈ 1:N
         itersym = order[n]
         # Add to set of defined symbols
@@ -271,7 +272,7 @@ function evaluate_cost_tile(
     end
     # check which vars we can define at this level of loop nest
     for (id, op) ∈ enumerate(operations(ls))
-        isconstant(op) && continue
+        # isconstant(op) && continue
         # @assert id == identifier(op)+1 # testing, for now
         # won't define if already defined...
         included_vars[id] && continue

src/graphs.jl

Lines changed: 1 addition & 0 deletions
@@ -346,6 +346,7 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
         @view(expr.args[2+offset:end]),
         Ref(false)
     )::ArrayReference
+    # @show ref.ref
     id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
     if id === nothing
         add_load!( ls, gensym(:temporary), ref, elementbytes )

src/lowering.jl

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ function lower_store_unrolled!(
             push!(q.args, instrcall)
         end
     else
+        sn = findfirst(x -> x === unrolled, loopdependencies(op))::Int
         ustrides = Expr(:call, lv(:vmul), Expr(:call, :stride, ptr, sn), Expr(:call, lv(:vrange), Expr(:call, Expr(:curly, :Val, W))))
         for u ∈ 0:U-1
             instrcall = Expr(:call, lv(:scatter!), ptr, Symbol("##",var,:_,u), Expr(:call,lv(:vadd),mem_offset(op,u*W,unrolled),ustrides))

src/operations.jl

Lines changed: 10 additions & 0 deletions
@@ -2,6 +2,16 @@ struct ArrayReference
     array::Symbol
     ref::Vector{Union{Symbol,Int}}
     loaded::Base.RefValue{Bool}
+    function ArrayReference(
+        array, refsin, loadedin = Ref{Bool}(false)
+    )
+        ref = Vector{Union{Symbol,Int}}(undef, length(refsin))
+        for i ∈ eachindex(ref)
+            refᵢ = (refsin[i])::Union{Symbol,Int}
+            ref[i] = refᵢ isa Int ? refᵢ - 1 : refᵢ
+        end
+        new(array, ref, loadedin)
+    end
 end
 function ArrayReference(
     array::Symbol,

test/runtests.jl

Lines changed: 7 additions & 0 deletions
@@ -330,4 +330,11 @@ fill!(d4, 91000.0);
 @avx @. d4 = a + B *ˡ c;
 @test all(d3 .≈ d4)

+M, K, N = 77, 83, 57;
+A = rand(M,K); B = rand(K,N); C = rand(M,N);
+
+D1 = C .+ A * B;
+D2 = @avx C .+ A *ˡ B;
+@test all(D1 .≈ D2)
+
 end

0 commit comments
