Skip to content

Commit 2768d72

Browse files
committed
Fix for tiled-only operations, adding broadcast support.
1 parent 97daeca commit 2768d72

File tree

10 files changed

+307
-38
lines changed

10 files changed

+307
-38
lines changed

Manifest.toml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,13 @@ uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
6161

6262
[[SIMDPirates]]
6363
deps = ["MacroTools", "VectorizationBase"]
64-
git-tree-sha1 = "3e45c76dfcc349ff208a955e1ce6e92b1be6d15e"
65-
repo-rev = "master"
66-
repo-url = "https://github.com/chriselrod/SIMDPirates.jl"
64+
git-tree-sha1 = "296cae2ccd6e4766aad669e748c1248fb99ab69c"
6765
uuid = "21efa798-c60a-11e8-04d3-e1a92915a26a"
6866
version = "0.1.0"
6967

7068
[[SLEEFPirates]]
7169
deps = ["SIMDPirates", "VectorizationBase"]
72-
git-tree-sha1 = "ba032bbcc7038853867119f4cac383a0051b62a8"
70+
git-tree-sha1 = "01ff5ddb2fe743e93a6d80b072a02cceb90592bf"
7371
repo-rev = "master"
7472
repo-url = "https://github.com/chriselrod/SLEEFPirates.jl"
7573
uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa"
@@ -87,8 +85,6 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
8785

8886
[[VectorizationBase]]
8987
deps = ["CpuId", "LinearAlgebra"]
90-
git-tree-sha1 = "1cc48a9bce5c18f2f70fa16cc5b2b39b39332a9e"
91-
repo-rev = "master"
92-
repo-url = "https://github.com/chriselrod/VectorizationBase.jl"
88+
git-tree-sha1 = "30dd7fd08829bfa0fa6c57bf84a7daeac2e9462b"
9389
uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
94-
version = "0.1.0"
90+
version = "0.1.1"

Project.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,21 @@ authors = ["Chris Elrod <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
89
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
910
SIMDPirates = "21efa798-c60a-11e8-04d3-e1a92915a26a"
1011
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
1112
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"
1213

14+
[compat]
15+
MacroTools = "0.5"
16+
Parameters = "0.12.0"
17+
SIMDPirates = "0.1.0"
18+
SLEEFPirates = "0.1.0"
19+
VectorizationBase = "0.1.2"
20+
julia = "1.0.0"
21+
1322
[extras]
1423
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1524

src/LoopVectorization.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ module LoopVectorization
33
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
44
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr, mask
55
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul, vrange, reduced_add, reduced_prod
6+
using Base.Broadcast: Broadcasted, DefaultArrayStyle
7+
using LinearAlgebra: Adjoint
68
using MacroTools: prewalk, postwalk
79

8-
export vectorizable, @vectorize, @vvectorize, @avx
10+
11+
# Public names of the module. The list must not end in a comma: a trailing comma makes
# the parser keep reading names and trip over the `function` keyword that follows.
export LowDimArray, stridedpointer, vectorizable,
    @vectorize, @vvectorize, @avx
913

1014
function isdense end #
1115

@@ -896,6 +900,7 @@ end
896900
include("costs.jl")
897901
include("operations.jl")
898902
include("graphs.jl")
903+
include("broadcast.jl")
899904
include("determinestrategy.jl")
900905
include("lowering.jl")
901906
include("constructors.jl")

src/broadcast.jl

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
"""
    Product{A,B}

Lazy pairing of two operands: `a` is the left operand and `b` the right one. Nothing is
computed on construction; the operands are only stored.
"""
struct Product{A,B}
    a::A
    b::B
end
5+
6+
# NOTE(review): the operator symbol was missing from this listing. `∗` (U+2217, typed
# `\ast`, distinct from `*`) is presumed here -- confirm against the repository.
# Infix constructor: `a ∗ b` stores both operands in a lazy `Product`.
@inline ∗(a::A, b::B) where {A,B} = Product{A,B}(a, b)
# Broadcasting over `∗` yields the same lazy `Product`.
# NOTE(review): this adds a method to the `Broadcasted` constructor rather than to
# `Base.Broadcast.broadcasted`; confirm that dot-call lowering reaches it.
@inline Base.Broadcast.Broadcasted(::typeof(∗), a::A, b::B) where {A,B} = Product{A,B}(a, b)
8+
# TODO: Need to make this handle A or B being (1 or 2)-D broadcast objects.
9+
"""
    add_broadcast!(ls, mC, bcname, loopsyms, ::Type{Product{A,B}}, elementbytes = 8)

Add to `ls` the operations that evaluate a lazy `Product` into `mC`.

The preamble unpacks `bcname.a` and `bcname.b` and reads `size(b, 1)` as the extent of a
new loop `k`. Both operands are added recursively through `add_broadcast!`, indexed by
`(m, k)` and `(k, n)`, where `m` and `n` are the first two entries of `loopsyms`. A zero
constant is registered for the accumulator, and a `:vmuladd` compute operation reduced
over `k` is pushed. Returns the result of `pushop!`.
"""
function add_broadcast!(
    ls::LoopSet, mC::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
    ::Type{Product{A,B}}, elementbytes::Int = 8
) where {A,B}
    K = gensym(:K)
    mA = gensym(:Aₘₖ)
    mB = gensym(:Bₖₙ)
    # Preamble: unpack both operands and read the length of the contracted dimension.
    pushpreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
    pushpreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
    pushpreamble!(ls, Expr(:(=), K, Expr(:call, :size, mB, 1)))

    # The contraction index gets its own loop on top of the destination loops.
    k = gensym(:k)
    ls.loops[k] = Loop(k, K)
    m = loopsyms[1]
    n = loopsyms[2]
    # load A
    loadA = add_broadcast!(ls, gensym(:A), mA, [m, k], A, elementbytes)
    # load B
    loadB = add_broadcast!(ls, gensym(:B), mB, [k, n], B, elementbytes)
    # set Cₘₙ = 0; the accumulator is indexed by the output indices (m, n), not by k
    setC = add_constant!(ls, 0.0, Symbol[m, n], mC, elementbytes)
    # compute Cₘₙ += Aₘₖ * Bₖₙ
    reductop = Operation(
        ls, mC, elementbytes, :vmuladd, compute, Symbol[m, k, n], Symbol[k],
        Operation[loadA, loadB, setC]
    )
    return pushop!(ls, reductop, mC)
end
36+
37+
"""
    LowDimArray{D}(data)

Wrap the dense array `data` and tag it with the type parameter `D`. The element type `T`,
rank `N` and storage type `A` are taken from `data`. `D` is indexed per dimension
(`D[n]`) by the `add_broadcast!` method for this type to select which loop symbols index
the array.
"""
struct LowDimArray{D,T,N,A<:DenseArray{T,N}} <: DenseArray{T,N}
    data::A
end
# Forward to the wrapped storage. Calling `pointer(A)` on the wrapper itself would
# dispatch back to this method and never terminate.
@inline Base.pointer(A::LowDimArray) = pointer(A.data)
# The bound matches the struct's own `A<:DenseArray{T,N}` constraint, so unsupported
# inputs fail at dispatch instead of inside type construction.
function LowDimArray{D}(data::A) where {D,T,N,A<:DenseArray{T,N}}
    return LowDimArray{D,T,N,A}(data)
end
44+
"""
    add_broadcast!(ls, destname, bcname, loopsyms, ::Type{<:LowDimArray{D,T,N}}, elementbytes = 8)

Add a load of the `LowDimArray` bound to `bcname`. Only the loop symbols `loopsyms[n]`
for which `D[n]` is true are used as indices. Returns the load `Operation`.
"""
function add_broadcast!(
    ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
    ::Type{<:LowDimArray{D,T,N}}, elementbytes::Int = 8
) where {D,T,N}
    # Keep the loop symbol of every dimension flagged in D; drop the others.
    fulldims = Union{Symbol,Int}[loopsyms[n] for n in 1:N if D[n]]
    ref = ArrayReference(bcname, fulldims, Ref{Bool}(false))
    return add_load!(ls, destname, ref, elementbytes)::Operation
end
52+
"""
    add_broadcast!(ls, destname, bcname, loopsyms, ::Type{Adjoint{T,A}}, elementbytes = 8)

Add a load of an `Adjoint`-wrapped `N`-dimensional array. The first `N` loop symbols are
used in reverse order as indices. Returns the load `Operation`.
"""
function add_broadcast!(
    ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
    ::Type{Adjoint{T,A}}, elementbytes::Int = 8
) where {T,N,A<:AbstractArray{T,N}}
    # Reverse the index order: position n is indexed by loop symbol N + 1 - n.
    reversed = Union{Symbol,Int}[loopsyms[N + 1 - n] for n in 1:N]
    ref = ArrayReference(bcname, reversed, Ref{Bool}(false))
    return add_load!(ls, destname, ref, elementbytes)::Operation
end
59+
"""
    add_broadcast!(ls, destname, bcname, loopsyms, ::Type{Adjoint{T,V}}, elementbytes = 8)

Add a load of an `Adjoint`-wrapped vector. The reference is indexed by the second loop
symbol only. Returns the result of `add_load!`.
"""
function add_broadcast!(
    ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
    ::Type{Adjoint{T,V}}, elementbytes::Int = 8
) where {T,V<:AbstractVector{T}}
    # Only loopsyms[2] is used as an index for the adjoint vector.
    indices = Union{Symbol,Int}[loopsyms[2]]
    ref = ArrayReference(bcname, indices, Ref{Bool}(false))
    return add_load!(ls, destname, ref, elementbytes)
end
66+
"""
    add_broadcast!(ls, destname, bcname, loopsyms, ::Type{<:AbstractArray{T,N}}, elementbytes = 8)

Generic array case: add a load of `bcname` indexed by the first `N` loop symbols.
Returns the result of `add_load!`.
"""
function add_broadcast!(
    ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
    ::Type{<:AbstractArray{T,N}}, elementbytes::Int = 8
) where {T,N}
    # An N-dimensional argument uses only the leading N loop symbols (as a view, no copy).
    ref = ArrayReference(bcname, @view(loopsyms[1:N]), Ref{Bool}(false))
    return add_load!(ls, destname, ref, elementbytes)
end
72+
"""
    add_broadcast!(ls, destname, bcname, loopsyms, ::Type{Broadcasted{...,F,A}}, elementbytes = 8)

Add a `Broadcasted` node to `ls`.

The instruction is looked up in `FUNCTIONSYMBOLS` by the function type `F`. When `F` is
not in the table, `bcname.f` is bound to a fresh symbol in the preamble and that symbol
is used instead. Each argument type in `A.parameters` is bound to a fresh name in the
preamble and added recursively; the resulting operations become the parents of the
compute operation, which is pushed under `destname`. Returns the result of `pushop!`.
"""
function add_broadcast!(
    ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
    ::Type{Broadcasted{DefaultArrayStyle{N},Nothing,F,A}},
    elementbytes::Int = 8
) where {N,F,A}
    instr = get(FUNCTIONSYMBOLS, F) do
        # Unknown function: call whatever is stored in the broadcast object.
        f = gensym(:f)
        pushpreamble!(ls, Expr(:(=), f, Expr(:(.), bcname, QuoteNode(:f))))
        f
    end
    args = A.parameters
    bcargs = Expr(:(.), bcname, QuoteNode(:args))
    parents = Operation[]
    deps = Symbol[]
    reduceddeps = Symbol[]
    for (i, arg) in enumerate(args)
        # this is the var name in the loop
        argname = gensym(:arg)
        pushpreamble!(ls, Expr(:(=), argname, Expr(:ref, bcargs, i)))
        # dynamic dispatch on the argument type; forward elementbytes so nested
        # arguments do not silently fall back to the default of 8
        parent = add_broadcast!(
            ls, gensym(:temp), argname, loopsyms, arg, elementbytes
        )::Operation
        pushparent!(parents, deps, reduceddeps, parent)
    end
    op = Operation(
        length(operations(ls)), destname, elementbytes, instr, compute,
        deps, reduceddeps, parents
    )
    return pushop!(ls, op, destname)
end
101+
102+
# size of dest determines loops
103+
# @generated
104+
"""
    vmaterialize!(dest, bc)

Build the `LoopSet` describing the broadcast `bc` written into the `N`-dimensional
`dest`.

One loop is created per dimension of `dest`, with extents bound in the preamble to
`size(dest)`. The broadcast tree is added through `add_broadcast!`, followed by a store
into `dest`.

NOTE: lowering is not applied yet. The function returns the `LoopSet`, not `dest`.
"""
function vmaterialize!(
    dest::AbstractArray{T,N}, bc::BC
) where {N,T,BC<:Broadcasted}
    # we have an N dimensional loop.
    # need to construct the LoopSet
    loopsyms = [gensym(:n) for _ in 1:N]
    ls = LoopSet()
    sizes = Expr(:tuple)
    for itersym in loopsyms
        Nsym = gensym(:N)
        ls.loops[itersym] = Loop(itersym, Nsym)
        push!(sizes.args, Nsym)
    end
    # (N₁, N₂, ...) = size(dest)
    pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest)))
    add_broadcast!(ls, :dest, :bc, loopsyms, BC)
    add_store!(ls, :dest, ArrayReference(:dest, loopsyms, Ref{Bool}(false)))
    resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
    return ls
end
125+
126+
"""
    vmaterialize(bc::Broadcasted)

Allocate a destination for `bc` with `similar`, using the element type computed by
`Base.Broadcast.combine_eltypes`, and pass it to `vmaterialize!`.
"""
function vmaterialize(bc::Broadcasted)
    ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
    dest = similar(bc, ElType)
    return vmaterialize!(dest, bc)
end

src/constructors.jl

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,40 @@ function Base.copyto!(ls::LoopSet, q::Expr)
66
add_loop!(ls, q)
77
end
88

9+
"""
    add_ci_call!(q, f, args, syms, i)

Append `syms[i] = f(args[2:end]...)` to `q.args`. Every `Core.SSAValue` among the
arguments is replaced by the symbol `syms[id]`. `args[1]` is skipped; `f` is used as the
callee instead. Returns the result of `push!`.
"""
function add_ci_call!(q::Expr, f, args, syms, i)
    call = Expr(:call, f)
    for arg in @view(args[2:end])
        # SSA references become the symbols that were assigned for those statements.
        push!(call.args, arg isa Core.SSAValue ? syms[arg.id] : arg)
    end
    return push!(q.args, Expr(:(=), syms[i], call))
end
20+
21+
"""
    substitute_broadcast(q::Expr)

Lower `q` with `Meta.lower` and rebuild it as a block of plain assignments, one per
lowered statement. Calls to `Base.materialize!` and `Base.materialize` are redirected to
`vmaterialize!` and `vmaterialize` (referenced through `lv`). The last lowered statement
is not translated.
"""
function substitute_broadcast(q::Expr)
    ci = first(Meta.lower(LoopVectorization, q).args).code
    # NOTE(review): the final statement is skipped -- presumably the `return`; confirm.
    nargs = length(ci) - 1
    ex = Expr(:block)
    # One fresh symbol per lowered statement, standing in for its SSA value.
    syms = [gensym() for _ in 1:nargs]
    for n in 1:nargs
        ciₙ = ci[n]
        ciₙargs = ciₙ.args
        f = first(ciₙargs)
        if ciₙ.head === :(=)
            # Assignment to a user variable: bind it to the symbol of the SSA value.
            push!(ex.args, Expr(:(=), f, syms[((ciₙargs[2])::Core.SSAValue).id]))
        elseif f === GlobalRef(Base, :materialize!)
            add_ci_call!(ex, lv(:vmaterialize!), ciₙargs, syms, n)
        elseif f === GlobalRef(Base, :materialize)
            add_ci_call!(ex, lv(:vmaterialize), ciₙargs, syms, n)
        else
            add_ci_call!(ex, f, ciₙargs, syms, n)
        end
    end
    return ex
end
42+
943
function LoopSet(q::Expr)
1044
q = SIMDPirates.contract_pass(q)
1145
ls = LoopSet()
@@ -15,16 +49,13 @@ function LoopSet(q::Expr)
1549
end
1650

1751
"""
    @avx expr

For a `for` loop, build a `LoopSet` from `expr` and lower it. Any other expression is
assumed to be a broadcast and is rewritten by `substitute_broadcast`. The result is
escaped into the caller's scope.
"""
macro avx(q)
    # assume broadcast whenever the expression is not a `for` loop
    q2 = q.head === :for ? lower(LoopSet(q)) : substitute_broadcast(q)
    return esc(q2)
end
2059

21-
#=
22-
@generated function vmaterialize(
23-
dest::AbstractArray{T,N}, bc::BC
24-
) where {T,N,BC <: Base.Broadcast.Broadcasted}
25-
# we have an N dimensional loop.
26-
# need to construct the LoopSet
2760

28-
end
29-
=#
3061

src/costs.jl

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,10 @@ const COST = Dict{Symbol,InstructionCost}(
7171
:(-) => InstructionCost(4,0.5),
7272
:(*) => InstructionCost(4,0.5),
7373
:(/) => InstructionCost(13,4.0,-2.0),
74+
:vadd => InstructionCost(4,0.5),
75+
:vsub => InstructionCost(4,0.5),
76+
:vmul => InstructionCost(4,0.5),
77+
:vdiv => InstructionCost(13,4.0,-2.0),
7478
:(==) => InstructionCost(1, 0.5),
7579
:isequal => InstructionCost(1, 0.5),
7680
:(&) => InstructionCost(1, 0.5),
@@ -104,6 +108,9 @@ const CORRESPONDING_REDUCTION = Dict{Symbol,Symbol}(
104108
:(+) => :vsum,
105109
:(-) => :vsum,
106110
:(*) => :vprod,
111+
:vadd => :vsum,
112+
:vsub => :vsum,
113+
:vmul => :vprod,
107114
:(&) => :vall,
108115
:(|) => :vany,
109116
:muladd => :vsum,
@@ -194,3 +201,56 @@ function callfun(f::Symbol)
194201
end
195202

196203

204+
205+
"""
    FUNCTIONSYMBOLS

Map from a function's type to the instruction symbol used when that function appears in
a broadcast. The Base function, its `Base.FastMath` variant and the SIMDPirates /
SLEEFPirates variants of one operation all map to the same symbol.
"""
const FUNCTIONSYMBOLS = Dict{Type{<:Function},Symbol}(
    # arithmetic
    typeof(+) => :(+),
    typeof(SIMDPirates.vadd) => :(+),
    typeof(Base.FastMath.add_fast) => :(+),
    typeof(-) => :(-),
    typeof(SIMDPirates.vsub) => :(-),
    typeof(Base.FastMath.sub_fast) => :(-),
    typeof(*) => :(*),
    typeof(SIMDPirates.vmul) => :(*),
    typeof(Base.FastMath.mul_fast) => :(*),
    typeof(/) => :(/),
    typeof(SIMDPirates.vdiv) => :(/),
    typeof(Base.FastMath.div_fast) => :(/),
    # comparisons and bitwise operators
    typeof(==) => :(==),
    typeof(isequal) => :isequal,
    typeof(&) => :(&),
    typeof(|) => :(|),
    typeof(>) => :(>),
    typeof(<) => :(<),
    typeof(>=) => :(>=),
    typeof(<=) => :(<=),
    # inverse and fused multiply-add family
    typeof(inv) => :inv,
    typeof(muladd) => :muladd,
    typeof(fma) => :fma,
    typeof(SIMDPirates.vmuladd) => :vmuladd,
    typeof(SIMDPirates.vfma) => :vfma,
    typeof(SIMDPirates.vfmadd) => :vfmadd,
    typeof(SIMDPirates.vfmsub) => :vfmsub,
    typeof(SIMDPirates.vfnmadd) => :vfnmadd,
    typeof(SIMDPirates.vfnmsub) => :vfnmsub,
    # elementary functions
    typeof(sqrt) => :sqrt,
    typeof(Base.FastMath.sqrt_fast) => :sqrt,
    typeof(SIMDPirates.vsqrt) => :sqrt,
    typeof(log) => :log,
    typeof(Base.FastMath.log_fast) => :log,
    typeof(SIMDPirates.vlog) => :log,
    typeof(SLEEFPirates.log) => :log,
    typeof(exp) => :exp,
    typeof(Base.FastMath.exp_fast) => :exp,
    typeof(SIMDPirates.vexp) => :exp,
    typeof(SLEEFPirates.exp) => :exp,
    typeof(sin) => :sin,
    typeof(Base.FastMath.sin_fast) => :sin,
    typeof(SLEEFPirates.sin) => :sin,
    typeof(cos) => :cos,
    typeof(Base.FastMath.cos_fast) => :cos,
    typeof(SLEEFPirates.cos) => :cos,
    typeof(sincos) => :sincos,
    typeof(Base.FastMath.sincos_fast) => :sincos,
    typeof(SLEEFPirates.sincos) => :sincos
)
256+

0 commit comments

Comments
 (0)