Skip to content

Commit 8da0974

Browse files
committed
Add support for broadcasting arrays where one of the sizes other than the first is 1 (will still crash if the first dimension is 1; use a LowDimArray to work around this).
1 parent d7259dd commit 8da0974

15 files changed

+320
-110
lines changed

Manifest.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6161

6262
[[SIMDPirates]]
6363
deps = ["MacroTools", "VectorizationBase"]
64-
git-tree-sha1 = "5d8212a28fd747bb5f77fe8b8f8d21b4024548d3"
64+
git-tree-sha1 = "bdb86180981859208e759adab67e2bdd8de55c64"
6565
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
66-
version = "0.3.2"
66+
version = "0.3.3"
6767

6868
[[SLEEFPirates]]
6969
deps = ["SIMDPirates", "VectorizationBase"]
70-
git-tree-sha1 = "740604e0e5bb739488da2c2c25b5f3d71517a490"
70+
git-tree-sha1 = "b7c597af915f4425adf1ec64248a49ac23605fa5"
7171
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
72-
version = "0.3.1"
72+
version = "0.3.2"
7373

7474
[[Serialization]]
7575
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -83,6 +83,6 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8383

8484
[[VectorizationBase]]
8585
deps = ["CpuId", "LinearAlgebra"]
86-
git-tree-sha1 = "4c587c04a6daa03370023e8ea4c431fcae5f7371"
86+
git-tree-sha1 = "07785cdd42d94d3b5b13e1ba67d3f8feee4e670c"
8787
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
88-
version = "0.2.4"
88+
version = "0.2.5"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.6.3"
4+
version = "0.6.4"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

benchmark/driver.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ AtmulB_bench = fetch(AtmulB_future)
4747
AmulBt_bench = fetch(AmulBt_future)
4848
Atmulvb_bench = fetch(Atmulvb_future)
4949

50-
v = 4
50+
v = 6
5151
const PICTURES = "/home/chriselrod/Pictures"
5252
save(joinpath(PICTURES, "bench_gemm_v$v.png"), plot(gemm_bench));
5353
save(joinpath(PICTURES, "bench_AtmulB_v$v.png"), plot(AtmulB_bench));
@@ -63,6 +63,9 @@ save(joinpath(PICTURES, "bench_random_access_v$v.png"), plot(randomaccess_bench)
6363
save(joinpath(PICTURES, "bench_AmulBt_v$v.png"), plot(AmulBt_bench));
6464
save(joinpath(PICTURES, "bench_Atmulvb_v$v.png"), plot(Atmulvb_bench));
6565

66+
67+
68+
6669
plot(gemm_bench)
6770
plot(AtmulB_bench)
6871
plot(dot_bench)

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module LoopVectorization
33
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
44
using VectorizationBase: REGISTER_SIZE, REGISTER_COUNT, extract_data, num_vector_load_expr,
55
mask, masktable, pick_vector_width_val, valmul, valrem, valmuladd, valadd, valsub, _MM,
6-
maybestaticlength, maybestaticsize, staticm1, subsetview, vzero,
6+
maybestaticlength, maybestaticsize, staticm1, subsetview, vzero, stridedpointer_for_broadcast,
77
Static, StaticUnitRange, StaticLowerUnitRange, StaticUpperUnitRange,
88
PackedStridedPointer, SparseStridedPointer, RowMajorStridedPointer, StaticStridedPointer, StaticStridedStruct
99
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod, reduce_to_add, reduce_to_prod#,

src/add_compute.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ function update_reduction_status!(parentvec::Vector{Operation}, deps::Vector{Sym
8686
end
8787
end
8888
end
89+
# Register an already-constructed compute operation `op` with the loop set `ls`,
# filing it under its own name. Returns whatever `pushop!` returns.
function add_compute!(ls::LoopSet, op::Operation)
    @assert iscompute(op)
    # `op` is the only operation in scope here; push it (not an undefined `child`).
    return pushop!(ls, op, name(op))
end
93+
8994
function add_reduction_update_parent!(
9095
vparents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet,
9196
parent::Operation, instr::Symbol, directdependency::Bool, elementbytes::Int

src/add_loads.jl

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,18 @@
1+
# Register the load operation `op` with `ls`.
#
# If an equal array reference is already tracked in `ls.refs_aliasing_syms`, no new
# operation is added (CSE): the operation previously recorded for that reference is
# returned instead, or, when that operation is a store, its first parent.
# Otherwise the reference is recorded, `add_vptr!` is called for the array
# (forwarding `actualarray` and `broadcast`), and `op` is pushed onto `ls`.
function add_load!(ls::LoopSet, op::Operation, actualarray::Bool = true, broadcast::Bool = false)
    @assert isload(op)
    memref = op.ref
    aliasid = findfirst(r -> r == memref, ls.refs_aliasing_syms)
    if aliasid !== nothing
        # Direct `opdict` indexing is deliberate: throw an error if not found.
        prev = ls.opdict[ls.syms_aliasing_refs[aliasid]]
        return isstore(prev) ? getop(ls, first(parents(prev))) : prev
    end
    # First time this reference is seen: remember which symbol aliases it.
    push!(ls.syms_aliasing_refs, name(op))
    push!(ls.refs_aliasing_syms, memref)
    add_vptr!(ls, op.ref.ref.array, vptr(op.ref), actualarray, broadcast)
    return pushop!(ls, op, name(op))
end
116

217
function add_load!(
318
ls::LoopSet, var::Symbol, array::Symbol, rawindices, elementbytes::Int
@@ -10,24 +25,13 @@ function add_load!(
1025
)
1126
length(mpref.loopdependencies) == 0 && return add_constant!(ls, var, mpref, elementbytes)
1227
ref = mpref.mref
13-
# try to CSE
14-
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
15-
if id === nothing
16-
push!(ls.syms_aliasing_refs, var)
17-
push!(ls.refs_aliasing_syms, ref)
18-
else
19-
opp = getop(ls, ls.syms_aliasing_refs[id], elementbytes)
20-
return isstore(opp) ? getop(ls, first(parents(opp))) : opp
21-
end
22-
# else, don't
2328
op = Operation( ls, var, elementbytes, :getindex, memload, mpref )
24-
add_vptr!(ls, op)
25-
pushop!(ls, op, var)
29+
add_load!(ls, op, true, false)
2630
end
2731

2832
# for use with broadcasting
2933
function add_simple_load!(
30-
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int, actualarray::Bool = true
34+
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int, actualarray::Bool = true, broadcast::Bool = false
3135
)
3236
loopdeps = Symbol[s for s ∈ ref.indices]
3337
mref = ArrayReferenceMeta(
@@ -38,7 +42,7 @@ function add_simple_load!(
3842
:getindex, memload, loopdeps,
3943
NODEPENDENCY, NOPARENTS, mref
4044
)
41-
add_vptr!(ls, op.ref.ref.array, vptr(op.ref), actualarray)
45+
add_vptr!(ls, op.ref.ref.array, vptr(op.ref), actualarray, broadcast)
4246
pushop!(ls, op, var)
4347
end
4448
function add_load_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int)

src/add_stores.jl

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,14 @@ function cse_store!(ls::LoopSet, op::Operation)
88
ls.opdict[op.variable] = op
99
op
1010
end
11-
function add_store!(ls::LoopSet, op::Operation)
12-
nops = length(ls.operations)
11+
function add_store!(ls::LoopSet, op::Operation, add_pvar::Bool = name(first(parents(op))) ∉ ls.syms_aliasing_refs)
12+
@assert isstore(op)
13+
if add_pvar
14+
push!(ls.syms_aliasing_refs, name(first(parents(op))))
15+
push!(ls.refs_aliasing_syms, op.ref)
16+
end
1317
id = op.identifier
14-
id == nops ? add_unique_store!(ls, op) : cse_store!(ls, op)
18+
id == length(operations(ls)) ? add_unique_store!(ls, op) : cse_store!(ls, op)
1519
end
1620
function add_copystore!(
1721
ls::LoopSet, parent::Operation, mpref::ArrayReferenceMetaPosition, elementbytes::Int
@@ -20,6 +24,7 @@ function add_copystore!(
2024
add_store!(ls, name(op), mpref, elementbytes, op)
2125
end
2226

27+
2328
function add_store!(
2429
ls::LoopSet, var::Symbol, mpref::ArrayReferenceMetaPosition, elementbytes::Int, parent = getop(ls, var, mpref.loopdependencies, elementbytes)
2530
)
@@ -29,10 +34,7 @@ function add_store!(
2934
reduceddeps = mpref.reduceddeps
3035
pvar = name(parent)
3136
id = length(ls.operations)
32-
if pvar ∉ ls.syms_aliasing_refs
33-
push!(ls.syms_aliasing_refs, pvar)
34-
push!(ls.refs_aliasing_syms, mpref.mref)
35-
else
37+
if pvar ∈ ls.syms_aliasing_refs
3638
# try to cse store, by replacing the previous one
3739
ref = mpref.mref.ref
3840
for opp ∈ operations(ls)
@@ -44,10 +46,13 @@ function add_store!(
4446
# @show ref opp.ref.ref
4547
end
4648
end
49+
add_pvar = false
50+
else
51+
add_pvar = true
4752
end
4853
pushparent!(parents, ldref, reduceddeps, parent)
4954
op = Operation( id, name(mpref), elementbytes, :setindex!, memstore, mpref )#loopdependencies, reduceddeps, parents, mpref.mref )
50-
add_store!(ls, op)
55+
add_store!(ls, op, add_pvar)
5156
end
5257

5358
function add_store!(

src/broadcast.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,21 @@ function add_broadcast!(
101101
) where {D,T,N}
102102
fulldims = Symbol[loopsyms[n] for n ∈ 1:N if D[n]]
103103
ref = ArrayReference(bcname, fulldims)
104-
add_simple_load!(ls, destname, ref, elementbytes )::Operation
104+
add_simple_load!(ls, destname, ref, elementbytes, true, false )::Operation
105105
end
106106
function add_broadcast_adjoint_array!(
107107
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{A}, elementbytes::Int
108108
) where {T,N,A<:AbstractArray{T,N}}
109109
parent = gensym(:parent)
110110
pushpreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
111111
ref = ArrayReference(parent, Symbol[loopsyms[N + 1 - n] for n ∈ 1:N])
112-
add_simple_load!( ls, destname, ref, elementbytes )::Operation
112+
add_simple_load!( ls, destname, ref, elementbytes, true, true )::Operation
113113
end
114114
function add_broadcast_adjoint_array!(
115115
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{<:AbstractVector}, elementbytes::Int
116116
)
117117
ref = ArrayReference(bcname, Symbol[loopsyms[2]])
118-
add_simple_load!( ls, destname, ref, elementbytes )
118+
add_simple_load!( ls, destname, ref, elementbytes, true, true )
119119
end
120120
function add_broadcast!(
121121
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
@@ -133,7 +133,7 @@ function add_broadcast!(
133133
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
134134
::Type{<:AbstractArray{T,N}}, elementbytes::Int
135135
) where {T,N}
136-
add_simple_load!(ls, destname, ArrayReference(bcname, @view(loopsyms[1:N])), elementbytes)
136+
add_simple_load!(ls, destname, ArrayReference(bcname, @view(loopsyms[1:N])), elementbytes, true, true)
137137
end
138138
function add_broadcast!(
139139
ls::LoopSet, ::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{T}, elementbytes::Int
@@ -147,7 +147,7 @@ function add_broadcast!(
147147
inds = Vector{Symbol}(undef, N+1)
148148
inds[1] = Symbol("##DISCONTIGUOUSSUBARRAY##")
149149
inds[2:end] .= @view(loopsyms[1:N])
150-
add_simple_load!(ls, destname, ArrayReference(bcname, inds), elementbytes)
150+
add_simple_load!(ls, destname, ArrayReference(bcname, inds), elementbytes, true, true)
151151
end
152152
function add_broadcast!(
153153
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},

src/condense_loopset.jl

Lines changed: 109 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,48 @@ end
185185
@inline array_wrapper(A::Adjoint) = Adjoint
186186
@inline array_wrapper(A::SubArray) = A.indices
187187

188+
189+
190+
# Generated wrapper around `_avx_!`; `generate_call` emits a call to it when not inlining.
#
# It builds a block that first assigns one value per entry of `AR` and then calls
# `_avx_!` with the pointer symbols followed by the remaining `vargs`:
#   * `D[i] === nothing` and `AR[i] == -1` -> `LoopValue()`
#   * `D[i] === nothing` otherwise         -> `stridedpointer(vargs[AR[i]], arraydescript[AR[i]])`
#   * `D[i]` an `Int`                      -> `subsetview(<name assigned for entry AR[i]>, Val{D[i]}(), subsetvals[j])`,
#                                            where `j` counts the subset views emitted so far.
# `IND[i]` selects which pointer symbol the value is bound to; `nothing` binds it
# to a fresh temporary that is not passed to `_avx_!` directly.
@generated function __avx__!(
    ::Val{UT}, ::Type{OPS}, ::Type{ARF}, ::Type{AM}, lb::LB,
    ::Val{AR}, ::Val{D}, ::Val{IND}, subsetvals, arraydescript, vargs::Vararg{<:Any,N}
) where {UT, OPS, ARF, AM, LB, N, AR, D, IND}
    nptrs = length(ARF.parameters)::Int
    ptrsyms = [gensym(:vptr) for _ in 1:nptrs]
    # Leading arguments are forwarded as-is, then one pointer symbol per array reference.
    avxcall = Expr(:call, lv(:_avx_!), Val{UT}(), OPS, ARF, AM, :lb)
    append!(avxcall.args, ptrsyms)
    body = Expr(:block)
    lhsnames = Vector{Symbol}(undef, length(AR))
    nsubset = 0 # number of entries of `subsetvals` consumed so far
    narrays = 0 # number of `stridedpointer` calls emitted so far
    for i in eachindex(AR)
        src = (AR[i])::Int
        slot = (IND[i])::Union{Nothing,Int}
        lhs = slot === nothing ? gensym() : ptrsyms[slot]
        lhsnames[i] = lhs
        dim = (D[i])::Union{Nothing,Int}
        rhs = if dim !== nothing # subsetview
            nsubset += 1
            Expr(:call, :subsetview, lhsnames[src], Expr(:call, Expr(:curly, :Val, dim)), Expr(:ref, :subsetvals, nsubset))
        elseif src == -1
            Expr(:call, :LoopValue)
        else # stridedpointer
            narrays += 1
            Expr(:call, lv(:stridedpointer), Expr(:ref, :vargs, src), Expr(:ref, :arraydescript, src))
        end
        push!(body.args, Expr(:(=), lhs, rhs))
    end
    # Whatever follows the first `narrays` entries of `vargs` is forwarded unchanged.
    for n in (narrays + 1):N
        push!(avxcall.args, Expr(:ref, :vargs, n))
    end
    push!(body.args, avxcall)
    return Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__, @__FILE__), body)
end
229+
188230
# Try to condense in type stable manner
189231
function generate_call(ls::LoopSet, IUT)
190232
operation_descriptions = Expr(:curly, :Tuple)
@@ -200,20 +242,77 @@ function generate_call(ls::LoopSet, IUT)
200242
foreach(ref -> push!(arrayref_descriptions.args, ArrayRefStruct(ls, ref, arraysymbolinds)), ls.refs_aliasing_syms)
201243
argmeta = argmeta_and_consts_description(ls, arraysymbolinds)
202244
loop_bounds = loop_boundaries(ls)
203-
204-
q = Expr(:call, lv(:_avx_!), Expr(:call, Expr(:curly, :Val, IUT)), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
205-
foreach(ref -> push!(q.args, vptr(ref)), ls.refs_aliasing_syms)
245+
inline, U, T = IUT
246+
if inline
247+
q = Expr(:call, lv(:_avx_!), Expr(:call, Expr(:curly, :Val, (U,T))), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds)
248+
foreach(ref -> push!(q.args, vptr(ref)), ls.refs_aliasing_syms)
249+
else
250+
arraydescript = Expr(:tuple)
251+
q = Expr(:call, lv(:__avx__!), Expr(:call, Expr(:curly, :Val, (U,T))), operation_descriptions, arrayref_descriptions, argmeta, loop_bounds, arraydescript)
252+
for array ∈ ls.includedactualarrays
253+
push!(q.args, Expr(:call, lv(:unwrap_array), array))
254+
push!(arraydescript.args, Expr(:call, lv(:array_wrapper), array))
255+
end
256+
end
206257
foreach(is -> push!(q.args, last(is)), ls.preamble_symsym)
207258
append!(q.args, arraysymbolinds)
208259
add_reassigned_syms!(q, ls)
209260
add_external_functions!(q, ls)
210261
q
211262
end
212263

213-
function setup_call_noinline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
214-
call = generate_call(ls, (inline,U,T))
264+
function setup_call_noinline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
265+
call = generate_call(ls, (false,U,T))
215266
hasouterreductions = length(ls.outer_reductions) > 0
216-
q = ls.preamble
267+
q = Expr(:block)
268+
vptrarrays = Expr(:tuple)
269+
vptrsubsetvals = Expr(:tuple)
270+
vptrsubsetdims = Expr(:tuple)
271+
vptrindices = Expr(:tuple)
272+
stridedpointerLHS = Symbol[]
273+
loopvalueLHS = Symbol[]
274+
for ex ∈ ls.preamble.args
275+
# vptrcalls = Expr(:tuple)
276+
if ex isa Expr && ex.head === :(=) && length(ex.args) == 2
277+
if ex.args[2] isa Expr && ex.args[2].head === :call
278+
gr = first(ex.args[2].args)
279+
if gr == lv(:stridedpointer)
280+
array = ex.args[2].args[2]
281+
arrayid = findfirst(a -> a === array, ls.includedactualarrays)
282+
if arrayid isa Int
283+
push!(vptrarrays.args, arrayid)
284+
else
285+
@assert array ∈ loopvalueLHS
286+
push!(vptrarrays.args, -1)
287+
end
288+
push!(vptrsubsetdims.args, nothing)
289+
vp = first(ex.args)::Symbol
290+
push!(stridedpointerLHS, vp)
291+
push!(vptrindices.args, findfirst(a -> vptr(a) == vp, ls.refs_aliasing_syms))
292+
elseif gr == lv(:subsetview)
293+
array = ex.args[2].args[2]
294+
vptrarrayid = findfirst(a -> a === array, stridedpointerLHS)#::Int
295+
if vptrarrayid === nothing
296+
@show array, stridedpointerLHS
297+
@assert vptrarrayid isa Int
298+
end
299+
push!(vptrarrays.args, vptrarrayid::Int)
300+
push!(vptrsubsetdims.args, ex.args[2].args[3].args[1].args[2])
301+
push!(vptrsubsetvals.args, ex.args[2].args[4])
302+
vp = first(ex.args)::Symbol
303+
push!(stridedpointerLHS, vp)
304+
push!(vptrindices.args, findfirst(a -> vptr(a) == vp, ls.refs_aliasing_syms))
305+
end
306+
elseif ex.args[2] == LoopValue()
307+
push!(loopvalueLHS, first(ex.args))
308+
end
309+
end
310+
push!(q.args, ex)
311+
end
312+
insert!(call.args, 7, Expr(:call, Expr(:curly, :Val, vptrarrays)))
313+
insert!(call.args, 8, Expr(:call, Expr(:curly, :Val, vptrsubsetdims)))
314+
insert!(call.args, 9, Expr(:call, Expr(:curly, :Val, vptrindices)))
315+
insert!(call.args, 10, vptrsubsetvals)
217316
if hasouterreductions
218317
outer_reducts = Expr(:local)
219318
for or ∈ ls.outer_reductions
@@ -227,8 +326,6 @@ function setup_call_noinline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T =
227326
retv = loopset_return_value(ls, Val(false))
228327
call = Expr(:(=), retv, call)
229328
push!(q.args, gc_preserve(ls, call))
230-
push!(q.args, Expr(:return, retv))
231-
q = Expr(:block, Expr(:(=), retv, Expr(:call, Expr(:(->), Expr(:tuple, ls.includedactualarrays...), q), ls.includedactualarrays...)))
232329
for or ∈ ls.outer_reductions
233330
op = ls.operations[or]
234331
var = name(op)
@@ -239,13 +336,11 @@ function setup_call_noinline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T =
239336
end
240337
else
241338
push!(q.args, gc_preserve(ls, call))
242-
push!(q.args, Expr(:return, :nothing))
243-
q = Expr(:call, Expr(:(->), Expr(:tuple, ls.includedactualarrays...), q), ls.includedactualarrays...)
244339
end
245340
q
246341
end
247-
function setup_call_inline(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8))
248-
call = generate_call(ls, (inline,U,T))
342+
function setup_call_inline(ls::LoopSet, U = zero(Int8), T = zero(Int8))
343+
call = generate_call(ls, (true,U,T))
249344
hasouterreductions = length(ls.outer_reductions) > 0
250345
if hasouterreductions
251346
retv = loopset_return_value(ls, Val(false))
@@ -273,9 +368,9 @@ function setup_call(ls::LoopSet, inline = Int8(2), U = zero(Int8), T = zero(Int8
273368
# Creating an anonymous function and calling it also achieves the outlining, while still
274369
# inlining the generated function into the loop preamble.
275370
if inline == Int8(2)
276-
setup_call_inline(ls, Int8(2), U, T)
371+
setup_call_inline(ls, U, T)
277372
else
278-
setup_call_noinline(ls, Int8(2), U, T)
373+
setup_call_noinline(ls, U, T)
279374
end
280375
end
281376

0 commit comments

Comments
 (0)