Skip to content

Commit ad6c988

Browse files
committed
Set compat to require SIMDPirates 0.7.10 (fixes #89), ignore :inbounds meta expr (fixes #88), and check whether functions in lowered broadcast expressions are SSA values (fixes #87).
1 parent d30fb6e commit ad6c988

File tree

9 files changed

+57
-21
lines changed

9 files changed

+57
-21
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.6.29"
4+
version = "0.6.30"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -13,7 +13,7 @@ VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1313

1414
[compat]
1515
OffsetArrays = "1"
16-
SIMDPirates = "0.7.8"
16+
SIMDPirates = "0.7.10"
1717
SLEEFPirates = "0.4.4"
1818
UnPack = "0"
1919
VectorizationBase = "0.10"

src/broadcast.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,13 @@ function add_broadcast!(
169169
) where {T<:Number}
170170
add_constant!(ls, bcname, elementbytes) # or replace elementbytes with sizeof(T) ? u
171171
end
172+
function add_broadcast!(
173+
ls::LoopSet, ::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{Base.RefValue{T}}, elementbytes::Int
174+
) where {T}
175+
refextract = gensym(bcname)
176+
pushpreamble!(ls, Expr(:(=), refextract, Expr(:ref, bcname)))
177+
add_constant!(ls, refextract, elementbytes) # or replace elementbytes with sizeof(T) ? u
178+
end
172179
function add_broadcast!(
173180
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
174181
::Type{SubArray{T,N,A,S,B}}, elementbytes::Int

src/constructors.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@ function Base.copyto!(ls::LoopSet, q::Expr)
77
end
88

99
function add_ci_call!(q::Expr, f, args, syms, i, mod = nothing)
10-
call = Expr(:call, f)
10+
call = if f isa Core.SSAValue
11+
Expr(:call, syms[f.id])
12+
else
13+
Expr(:call, f)
14+
end
1115
for arg ∈ @view(args[2:end])
1216
if arg isa Core.SSAValue
1317
push!(call.args, syms[arg.id])

src/graphs.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ end
244244

245245
includesarray(ls::LoopSet, array::Symbol) = array ∈ ls.includedarrays
246246

247-
function LoopSet(mod::Symbol, W = Symbol("##Wvecwidth##"), T = Symbol("Tloopeltype"))# = :LoopVectorization)
247+
function LoopSet(mod::Symbol, W = Symbol("##Wvecwidth##"), T = Symbol("##Tloopeltype##"))# = :LoopVectorization)
248248
LoopSet(
249249
Symbol[], [0], Loop[],
250250
Dict{Symbol,Operation}(),
@@ -332,6 +332,7 @@ end
332332
function add_block!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
333333
for x ∈ ex.args
334334
x isa Expr || continue # be that general?
335+
x.head === :inbounds && continue
335336
push!(ls, x, elementbytes, position)
336337
end
337338
end
@@ -546,6 +547,18 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
546547
add_andblock!(ls, ex, elementbytes, position)
547548
elseif ex.head === :||
548549
add_orblock!(ls, ex, elementbytes, position)
550+
elseif ex.head === :local # Handle locals introduced by `@inbounds`; using `local` with `@avx` is not recommended (nor is `@inbounds`, which applies automatically regardless)
551+
@assert length(ex.args) == 1 # TODO replace assert + first with "only" once support for Julia < 1.4 is dropped
552+
localbody = first(ex.args)
553+
@assert localbody.head === :(=)
554+
@assert length(localbody.args) == 2
555+
LHS = (localbody.args[1])::Symbol
556+
RHS = push!(ls, (localbody.args[2]), elementbytes, position)
557+
if isstore(RHS)
558+
RHS
559+
else
560+
add_compute!(ls, LHS, :identity, [RHS], elementbytes)
561+
end
549562
else
550563
throw("Don't know how to handle expression:\n$ex")
551564
end

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ function _avx_loopset(OPSsv, ARFsv, AMsv, LPSYMsv, LBsv, vargs)
413413
)
414414
end
415415
@generated function _avx_!(::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, ::Type{LPSYM}, lb::LB, vargs...) where {UT, OPS, ARF, AM, LPSYM, LB}
416-
1 + 1
416+
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
417417
ls = _avx_loopset(OPS.parameters, ARF.parameters, AM.parameters, LPSYM.parameters, LB.parameters, vargs)
418418
avx_body(ls, UT)
419419
end

src/split_loops.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ function lower_and_split_loops(ls::LoopSet)
7070
remaining_ops[1:ind-1] .= @view(split_candidates[1:ind-1]); remaining_ops[ind:end] .= @view(split_candidates[ind+1:end])
7171
ls_2 = split_loopset(ls, remaining_ops)
7272
order_2, unrolled_2, tiled_2, vectorized_2, U_2, T_2, cost_2 = choose_order_cost(ls_2)
73-
# U_1 = T_1 = U_2 = T_2 = 2
74-
if cost_1 + cost_2 < cost_fused
73+
U_1 = T_1 = U_2 = T_2 = 2
74+
if cost_1 + cost_2 ≤ cost_fused
7575
ls_2_lowered = if length(remaining_ops) > 1
7676
lower_and_split_loops(ls_2)
7777
else

test/copy.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ using LoopVectorization, Test
3737
end
3838
function offset_copy_avx1!(A, B)
3939
@_avx for i=1:size(A,1), j=1:size(B,2)
40-
A[i,j+2] = B[i,j]
40+
@inbounds A[i,j+2] = B[i,j]
4141
end
4242
end
4343
function offset_copyavx2!(A, B)
@@ -64,7 +64,7 @@ using LoopVectorization, Test
6464
end
6565
function make23avx!(x)
6666
@avx for i ∈ eachindex(x)
67-
x[i] = 23
67+
@inbounds x[i] = 23
6868
end
6969
end
7070
function make23_avx!(x)
@@ -82,6 +82,7 @@ using LoopVectorization, Test
8282
x[i] = a
8383
end
8484
end
85+
8586

8687
for T ∈ (Float32, Float64, Int32, Int64)
8788
@show T, @__LINE__
@@ -129,6 +130,11 @@ using LoopVectorization, Test
129130
myfillavx!(x, a);
130131
fill!(q2, a);
131132
@test x == q2
133+
q2 .= 23;
134+
fill!(q1, -99999); make23_avx!(q1);
135+
@test q2 == q1
136+
fill!(q1, -99999); make23avx!(q1);
137+
@test q2 == q1
132138
if T <: Union{Float32,Float64}
133139
make2point3avx!(x)
134140
fill!(q2, 2.3)

test/gemm.jl

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -521,11 +521,12 @@
521521
end
522522
end
523523

524-
function twogemms!(Ab, Bb, Cb, A, B)
524+
function threegemms!(Ab, Bb, Cb, A, B, C)
525525
M, N = size(Cb); K = size(B,1)
526526
@avx for m in 1:M, k in 1:K, n in 1:N
527-
Ab[m,k] += Cb[m,n] * B[k,n]
528-
Bb[k,n] += A[m,k] * Cb[m,n]
527+
Ab[m,k] += C[m,n] * B[k,n]
528+
Bb[k,n] += A[m,k] * C[m,n]
529+
Cb[m,n] += A[m,k] * B[k,n]
529530
end
530531
end
531532
# M = 77;
@@ -639,10 +640,11 @@
639640
Bbit = B .> 0.5
640641
fill!(C, 9999.999); AmulBavx1!(C, A, Bbit)
641642
@test C ≈ A * Bbit
642-
Ab = zero(A); Bb = zero(B);
643-
twogemms!(Ab, Bb, C, A, B)
643+
Ab = zero(A); Bb = zero(B); Cb = zero(C);
644+
threegemms!(Ab, Bb, Cb, A, B, C)
644645
@test Ab ≈ C * B'
645646
@test Bb ≈ A' * C
647+
@test Cb ≈ A * B
646648
end
647649
@time @testset "_avx $T dynamic gemm" begin
648650
AmulB_avx1!(C, A, B)

test/special.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,13 @@
163163

164164
return log1p(s-1) + u
165165
end
166-
feq = :(for i = 1:n
167-
tmp = exp(x[i] - u)
168-
r[i] = tmp
169-
s += tmp
170-
end)
171-
lsfeq = LoopVectorization.LoopSet(feq);
172-
# lsfeq.operations
166+
# feq = :(for i = 1:n
167+
# tmp = exp(x[i] - u)
168+
# r[i] = tmp
169+
# s += tmp
170+
# end)
171+
# lsfeq = LoopVectorization.LoopSet(feq);
172+
# # lsfeq.operations
173173

174174
function vpow0!(y, x)
175175
@avx for i ∈ eachindex(y, x)
@@ -225,6 +225,10 @@
225225
y[i] = x[i] ^ -5
226226
end);
227227
ls = LoopVectorization.LoopSet(q);
228+
q2 = :(for i ∈ eachindex(y, x)
229+
y[i] = x[i] ^ 5
230+
end);
231+
ls2 = LoopVectorization.LoopSet(q2);
228232

229233
function vpow5!(y, x)
230234
@avx for i ∈ eachindex(y, x)

0 commit comments

Comments
 (0)