Skip to content

Commit 89676e0

Browse files
committed
Fix bug in loop splitting, and add literal ^ support (fixes #85).
1 parent 6d299e0 commit 89676e0

File tree

9 files changed

+177
-8
lines changed

9 files changed

+177
-8
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@ env:
2626
- COVERALLS_PARALLEL=true
2727
notifications:
2828
webhooks: "https://coveralls.io/webhook"
29+

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.6.28"
4+
version = "0.6.29"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/LoopVectorization.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ using SIMDPirates: VECTOR_SYMBOLS, evadd, evsub, evmul, evfdiv, vrange, reduced_
1212
sizeequivalentfloat, sizeequivalentint, vadd!, vsub!, vmul!, vfdiv!, vfmadd!, vfnmadd!, vfmsub!, vfnmsub!,
1313
vfmadd231, vfmsub231, vfnmadd231, vfnmsub231, sizeequivalentfloat, sizeequivalentint, #prefetch,
1414
vmullog2, vmullog10, vdivlog2, vdivlog10, vmullog2add!, vmullog10add!, vdivlog2add!, vdivlog10add!, vfmaddaddone
15+
using SLEEFPirates: pow
1516
using Base.Broadcast: Broadcasted, DefaultArrayStyle
1617
using LinearAlgebra: Adjoint, Transpose
1718
using Base.Meta: isexpr
1819

20+
1921
const SUPPORTED_TYPES = Union{Float16,Float32,Float64,Integer}
2022

2123
export LowDimArray, stridedpointer, vectorizable,

src/add_compute.jl

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,16 @@ function add_reduction_update_parent!(
180180
pushop!(ls, child, name(parent))
181181
opout
182182
end
183+
184+
183185
function add_compute!(
184186
ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int, position::Int,
185187
mpref::Union{Nothing,ArrayReferenceMetaPosition} = nothing
186188
)
187189
@assert ex.head === :call
188190
instr = instruction(first(ex.args))::Symbol
189191
args = @view(ex.args[2:end])
192+
(instr === :(^) && length(args) == 2 && (args[2] isa Number)) && return add_pow!(ls, var, args[1], args[2], elementbytes, position)
190193
parents = Operation[]
191194
deps = Symbol[]
192195
reduceddeps = Symbol[]
@@ -215,12 +218,11 @@ function add_compute!(
215218
end
216219
end
217220
reduction = reduction_ind > 0
221+
loopnestview = view(ls.loopsymbols, 1:position)
218222
if iszero(length(deps)) && reduction
219-
loopnestview = view(ls.loopsymbols, 1:position)
220223
append!(deps, loopnestview)
221224
append!(reduceddeps, loopnestview)
222225
else
223-
loopnestview = view(ls.loopsymbols, 1:position)
224226
newloopdeps = Symbol[]; newreduceddeps = Symbol[];
225227
setdiffv!(newloopdeps, newreduceddeps, deps, loopnestview)
226228
mergesetv!(newreduceddeps, reduceddeps)
@@ -252,3 +254,60 @@ function add_compute!(
252254
pushop!(ls, op, LHS)
253255
end
254256

257+
# adds x ^ (p::Real)
258+
function add_pow!(
259+
ls::LoopSet, var::Symbol, x, p::Real, elementbytes::Int, position::Int
260+
)
261+
xop = if x isa Expr
262+
add_operation!(ls, gensym(:xpow), x, elementbytes, position)
263+
elseif x isa Symbol
264+
xo = get(ls.opdict, x, nothing)
265+
if isnothing(xo)
266+
pushpreamble!(ls, Expr(:(=), var, Expr(:call, :(^), x, p)))
267+
return add_constant!(ls, var, elementbytes)
268+
end
269+
xo
270+
elseif x isa Number
271+
pushpreamble!(ls, Expr(:(=), var, x ^ p))
272+
return add_constant!(ls, var, elementbytes)
273+
end
274+
pint = round(Int, p)
275+
if p != pint
276+
pop = add_constant!(ls, p, elementbytes)
277+
return add_compute!(ls, var, :^, [xop, pop], elementbytes)
278+
end
279+
if pint == -1
280+
return add_compute!(ls, var, :vinv, [xop], elementbytes)
281+
elseif pint < 0
282+
xop = add_compute!(ls, gensym(:inverse), :vinv, [xop], elementbytes)
283+
pint = - pint
284+
end
285+
if pint == 0
286+
op = Operation(length(operations(ls)), var, elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
287+
push!(ls.preamble_ones, (identifier(op),IntOrFloat))
288+
return pushop!(ls, op)
289+
elseif pint == 1
290+
return add_compute!(ls, var, :identity, [xop], elementbytes)
291+
elseif pint == 2
292+
return add_compute!(ls, var, :vabs2, [xop], elementbytes)
293+
end
294+
295+
# Implementation from https://github.com/JuliaLang/julia/blob/a965580ba7fd0e8314001521df254e30d686afbf/base/intfuncs.jl#L216
296+
t = trailing_zeros(pint) + 1
297+
pint >>= t
298+
while (t -= 1) > 0
299+
varname = (iszero(pint) && isone(t)) ? var : gensym(:pbs)
300+
xop = add_compute!(ls, varname, :vabs2, [xop], elementbytes)
301+
end
302+
yop = xop
303+
while pint > 0
304+
t = trailing_zeros(pint) + 1
305+
pint >>= t
306+
while (t -= 1) >= 0
307+
xop = add_compute!(ls, gensym(:pbs), :vabs2, [xop], elementbytes)
308+
end
309+
yop = add_compute!(ls, iszero(pint) ? var : gensym(:pbs), :vmul, [xop, yop], elementbytes)
310+
end
311+
yop
312+
end
313+

src/graphs.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ end
244244

245245
includesarray(ls::LoopSet, array::Symbol) = array ls.includedarrays
246246

247-
function LoopSet(mod::Symbol)# = :LoopVectorization)
247+
function LoopSet(mod::Symbol, W = Symbol("##Wvecwidth##"), T = Symbol("Tloopeltype"))# = :LoopVectorization)
248248
LoopSet(
249249
Symbol[], [0], Loop[],
250250
Dict{Symbol,Operation}(),
@@ -261,7 +261,7 @@ function LoopSet(mod::Symbol)# = :LoopVectorization)
261261
Matrix{Float64}(undef, 4, 2),
262262
Matrix{Float64}(undef, 4, 2),
263263
Bool[], Bool[],
264-
gensym(:W), gensym(:T), mod
264+
W, T, mod
265265
)
266266
end
267267

src/reconstruct_loopset.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,10 @@ function _avx_loopset(OPSsv, ARFsv, AMsv, LPSYMsv, LBsv, vargs)
413413
)
414414
end
415415
@generated function _avx_!(::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB, vargs...) where {UT, OPS, ARF, AM, LPSYM, LB}
416+
1 + 1
416417
ls = _avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, vargs)
417418
avx_body(ls, UT)
418419
end
419420

421+
422+

src/split_loops.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,16 +70,17 @@ function lower_and_split_loops(ls::LoopSet)
7070
remaining_ops[1:ind-1] .= @view(split_candidates[1:ind-1]); remaining_ops[ind:end] .= @view(split_candidates[ind+1:end])
7171
ls_2 = split_loopset(ls, remaining_ops)
7272
order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order_cost(ls_2)
73+
# U_1 = T_1 = U_2 = T_2 = 2
7374
if cost_1 + cost_2 < cost_fused
7475
ls_2_lowered = if length(remaining_ops) > 1
7576
lower_and_split_loops(ls_2)
7677
else
77-
lower(ls_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2)
78+
lower(ls_2, order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2)
7879
end
7980
return Expr(
8081
:block,
8182
ls.preamble,
82-
lower(ls_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1),
83+
lower(ls_1, order_1, unrolled_1, tiled_1, vectorized_1, U_1, T_1),
8384
ls_2_lowered
8485
)
8586
end

test/gemm.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,13 @@
521521
end
522522
end
523523

524+
function twogemms!(Ab, Bb, Cb, A, B)
525+
M, N = size(C); K = size(B,1)
526+
@avx for m in 1:M, k in 1:K, n in 1:N
527+
Ab[m,k] += Cb[m,n] * B[k,n]
528+
Bb[k,n] += A[m,k] * Cb[m,n]
529+
end
530+
end
524531
# M = 77;
525532
# A = rand(M,M); B = rand(M,M); C = similar(A);
526533
# mulCAtB_2x2block_avx!(C,A,B)
@@ -632,6 +639,10 @@
632639
Bbit = B .> 0.5
633640
fill!(C, 9999.999); AmulBavx1!(C, A, Bbit)
634641
@test C A * Bbit
642+
Ab = zero(A); Bb = zero(B);
643+
twogemms!(Ab, Bb, C, A, B)
644+
@test Ab C * B'
645+
@test Bb A' * C
635646
end
636647
@time @testset "_avx $T dynamic gemm" begin
637648
AmulB_avx1!(C, A, B)

test/special.jl

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,82 @@
170170
end)
171171
lsfeq = LoopVectorization.LoopSet(feq);
172172
# lsfeq.operations
173+
174+
function vpow0!(y, x)
175+
@avx for i eachindex(y, x)
176+
y[i] = x[i] ^ 0
177+
end; y
178+
end
179+
function vpown1!(y, x)
180+
@avx for i eachindex(y, x)
181+
y[i] = x[i] ^ -1
182+
end; y
183+
end
184+
function vpow1!(y, x)
185+
@avx for i eachindex(y, x)
186+
y[i] = x[i] ^ 1
187+
end; y
188+
end
189+
function vpown2!(y, x)
190+
@avx for i eachindex(y, x)
191+
y[i] = x[i] ^ -2
192+
end; y
193+
end
194+
function vpow2!(y, x)
195+
@avx for i eachindex(y, x)
196+
y[i] = x[i] ^ 2
197+
end; y
198+
end
199+
function vpown3!(y, x)
200+
@avx for i eachindex(y, x)
201+
y[i] = x[i] ^ -3
202+
end; y
203+
end
204+
function vpow3!(y, x)
205+
@avx for i eachindex(y, x)
206+
y[i] = x[i] ^ 3
207+
end; y
208+
end
209+
function vpown4!(y, x)
210+
@avx for i eachindex(y, x)
211+
y[i] = x[i] ^ -4
212+
end; y
213+
end
214+
function vpow4!(y, x)
215+
@avx for i eachindex(y, x)
216+
y[i] = x[i] ^ 4
217+
end; y
218+
end
219+
function vpown5!(y, x)
220+
@avx for i eachindex(y, x)
221+
y[i] = x[i] ^ -5
222+
end; y
223+
end
224+
q = :(for i eachindex(y, x)
225+
y[i] = x[i] ^ -5
226+
end);
227+
ls = LoopVectorization.LoopSet(q);
228+
229+
function vpow5!(y, x)
230+
@avx for i eachindex(y, x)
231+
y[i] = x[i] ^ 5
232+
end; y
233+
end
234+
function vpowf!(y, x)
235+
@avx for i eachindex(y, x)
236+
y[i] = x[i] ^ 2.3
237+
end; y
238+
end
239+
function vpowf!(y, x, p::Number)
240+
@avx for i eachindex(y, x)
241+
y[i] = x[i] ^ p
242+
end; y
243+
end
244+
function vpowf!(y, x, p::AbstractArray)
245+
@avx for i eachindex(y, x)
246+
y[i] = x[i] ^ p[i]
247+
end; y
248+
end
173249

174250

175251
for T (Float32, Float64)
@@ -193,7 +269,7 @@
193269
@test ld trianglelogdetavx(A)
194270
@test ld trianglelogdet_avx(A)
195271

196-
x = rand(T, 1000);
272+
x = rand(T, 999);
197273
r1 = similar(x);
198274
r2 = similar(x);
199275
lse = logsumexp!(r1, x);
@@ -218,5 +294,21 @@
218294
@test A1 A2
219295
fill!(A2, 0); offset_exp_avx!(A2, B)
220296
@test A1 A2
297+
298+
@test all(isone, vpow0!(r1, x))
299+
@test vpown1!(r1, x) map!(inv, r2, x)
300+
@test vpow1!(r1, x) == x
301+
@test vpown2!(r1, x) map!(abs2 inv, r2, x)
302+
@test vpow2!(r1, x) map!(abs2, r2, x)
303+
@test vpown3!(r1, x) (r2 .= x .^ -3)
304+
@test vpow3!(r1, x) (r2 .= x .^ 3)
305+
@test vpown4!(r1, x) (r2 .= x .^ -4)
306+
@test vpow4!(r1, x) (r2 .= x .^ 4)
307+
@test vpown5!(r1, x) (r2 .= x .^ -5)
308+
@test vpow5!(r1, x) (r2 .= x .^ 5)
309+
@test vpowf!(r1, x) (r2 .= x .^ 2.3)
310+
@test vpowf!(r1, x, -1.7) (r2 .= x .^ -1.7)
311+
p = randn(length(x));
312+
@test vpowf!(r1, x, x) (r2 .= x .^ x)
221313
end
222314
end

0 commit comments

Comments
 (0)