Skip to content

Commit 0bc4916

Browse files
committed
Keep track of parent module of functions, and test @_avx on a function not defined in LoopVectorization. Fixes #31.
1 parent 9c82ae9 commit 0bc4916

File tree

9 files changed

+168
-155
lines changed

9 files changed

+168
-155
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ using MacroTools: prewalk, postwalk
1313

1414

1515
export LowDimArray, stridedpointer, vectorizable,
16-
@avx, *ˡ, ,
16+
@avx, @_avx, *ˡ, _avx_!,
1717
vmap, vmap!
1818

1919

@@ -39,8 +39,6 @@ include("condense_loopset.jl")
3939
include("reconstruct_loopset.jl")
4040
include("constructors.jl")
4141

42-
export @_avx, _avx, @_avx_, avx_!
43-
4442
# include("precompile.jl")
4543
# _precompile_()
4644

src/add_compute.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -86,7 +86,7 @@ function add_reduction_update_parent!(
8686
)
8787
parent = getop(ls, var, elementbytes)
8888
isouterreduction = parent.instruction === LOOPCONSTANT
89-
Instr = Instruction(instr)
89+
Instr = instruction(ls, instr)
9090
instrclass = reduction_instruction_class(Instr) # key allows for faster lookups
9191
# if parent is not an outer reduction...
9292
if !isouterreduction
@@ -164,7 +164,7 @@ function add_compute!(
164164
if reduction || search_tree(parents, var)
165165
add_reduction_update_parent!(parents, deps, reduceddeps, ls, var, instr, reduction, elementbytes)
166166
else
167-
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
167+
op = Operation(length(operations(ls)), var, elementbytes, instruction(ls,instr), compute, deps, reduceddeps, parents)
168168
pushop!(ls, op, var)
169169
end
170170
end

src/broadcast.jl

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -180,12 +180,12 @@ end
180180
# size of dest determines loops
181181
# function vmaterialize!(
182182
@generated function vmaterialize!(
183-
dest::AbstractArray{T,N}, bc::BC
184-
) where {T <: Union{Float32,Float64}, N, BC <: Broadcasted}
183+
dest::AbstractArray{T,N}, bc::BC, ::Val{Mod}
184+
) where {T <: Union{Float32,Float64}, N, BC <: Broadcasted, Mod}
185185
# we have an N dimensional loop.
186186
# need to construct the LoopSet
187187
loopsyms = [gensym(:n) for n ∈ 1:N]
188-
ls = LoopSet()
188+
ls = LoopSet(Mod)
189189
sizes = Expr(:tuple)
190190
for (n,itersym) ∈ enumerate(loopsyms)
191191
Nsym = gensym(:N)
@@ -204,12 +204,12 @@ end
204204
# ls
205205
end
206206
@generated function vmaterialize!(
207-
dest′::Union{Adjoint{T,A},Transpose{T,A}}, bc::BC
208-
) where {T <: Union{Float32,Float64}, N, A <: AbstractArray{T,N}, BC <: Broadcasted}
207+
dest′::Union{Adjoint{T,A},Transpose{T,A}}, bc::BC, ::Val{Mod}
208+
) where {T <: Union{Float32,Float64}, N, A <: AbstractArray{T,N}, BC <: Broadcasted, Mod}
209209
# we have an N dimensional loop.
210210
# need to construct the LoopSet
211211
loopsyms = [gensym(:n) for n ∈ 1:N]
212-
ls = LoopSet()
212+
ls = LoopSet(Mod)
213213
pushpreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
214214
sizes = Expr(:tuple)
215215
for (n,itersym) ∈ enumerate(loopsyms)
@@ -229,7 +229,7 @@ end
229229
# ls
230230
end
231231

232-
@inline function vmaterialize(bc::Broadcasted)
232+
@inline function vmaterialize(bc::Broadcasted, ::Val{Mod}) where {Mod}
233233
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
234-
vmaterialize!(similar(bc, ElType), bc)
234+
vmaterialize!(similar(bc, ElType), bc, Val{Mod}())
235235
end

src/constructors.jl

Lines changed: 14 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@ function Base.copyto!(ls::LoopSet, q::Expr)
66
add_loop!(ls, q)
77
end
88

9-
function add_ci_call!(q::Expr, f, args, syms, i)
9+
function add_ci_call!(q::Expr, f, args, syms, i, mod = nothing)
1010
call = Expr(:call, f)
1111
for arg ∈ @view(args[2:end])
1212
if arg isa Core.SSAValue
@@ -15,10 +15,11 @@ function add_ci_call!(q::Expr, f, args, syms, i)
1515
push!(call.args, arg)
1616
end
1717
end
18+
mod === nothing || push!(call.args, Expr(:call, Expr(:curly, :Val, QuoteNode(mod))))
1819
push!(q.args, Expr(:(=), syms[i], call))
1920
end
2021

21-
function substitute_broadcast(q::Expr)
22+
function substitute_broadcast(q::Expr, mod::Symbol)
2223
ci = first(Meta.lower(LoopVectorization, q).args).code
2324
nargs = length(ci)-1
2425
ex = Expr(:block,)
@@ -30,32 +31,24 @@ function substitute_broadcast(q::Expr)
3031
if ciₙ.head === :(=)
3132
push!(ex.args, Expr(:(=), f, syms[((ciₙargs[2])::Core.SSAValue).id]))
3233
elseif f === GlobalRef(Base, :materialize!)
33-
add_ci_call!(ex, lv(:vmaterialize!), ciₙargs, syms, n)
34+
add_ci_call!(ex, lv(:vmaterialize!), ciₙargs, syms, n, mod)
3435
elseif f === GlobalRef(Base, :materialize)
35-
add_ci_call!(ex, lv(:vmaterialize), ciₙargs, syms, n)
36+
add_ci_call!(ex, lv(:vmaterialize), ciₙargs, syms, n, mod)
3637
else
3738
add_ci_call!(ex, f, ciₙargs, syms, n)
3839
end
3940
end
4041
ex
4142
end
4243

43-
function LoopSet(q::Expr)
44+
function LoopSet(q::Expr, mod::Symbol = :LoopVectorization)
4445
q = SIMDPirates.contract_pass(q)
45-
ls = LoopSet()
46+
ls = LoopSet(mod)
4647
copyto!(ls, q)
4748
resize!(ls.loop_order, num_loops(ls))
4849
ls
4950
end
5051

51-
function LoopSet(q::Expr, types::Dict{Symbol,DataType})
52-
q = SIMDPirates.contract_pass(q)
53-
ls = LoopSet()
54-
copyto!(ls, q, types)
55-
resize!(ls.loop_order, num_loops(ls))
56-
ls
57-
end
58-
5952

6053
"""
6154
@avx
@@ -91,10 +84,11 @@ true
9184
9285
"""
9386
macro avx(q)
87+
mod = Symbol(__module__)
9488
q2 = if q.head === :for
95-
setup_call(LoopSet(q))
89+
setup_call(LoopSet(q, mod))
9690
else# assume broadcast
97-
substitute_broadcast(q)
91+
substitute_broadcast(q, mod)
9892
end
9993
esc(q2)
10094
end
@@ -136,24 +130,24 @@ macro avx(arg, q)
136130
@assert q.head === :for
137131
@assert arg.head === :(=)
138132
inline, U, T = check_macro_kwarg(arg)
139-
esc(setup_call(LoopSet(q), inline, U, T))
133+
esc(setup_call(LoopSet(q, Symbol(__module__)), inline, U, T))
140134
end
141135
macro avx(arg1, arg2, q)
142136
@assert q.head === :for
143137
inline, U, T = check_macro_kwarg(arg1)
144138
inline, U, T = check_macro_kwarg(arg2, inline, U, T)
145-
esc(setup_call(LoopSet(q), inline, U, T))
139+
esc(setup_call(LoopSet(q, Symbol(__module__)), inline, U, T))
146140
end
147141

148142

149143

150144
macro _avx(q)
151-
esc(lower(LoopSet(q)))
145+
esc(lower(LoopSet(q, Symbol(__module__))))
152146
end
153147
macro _avx(arg, q)
154148
@assert q.head === :for
155149
inline, U, T = check_macro_kwarg(arg)
156-
esc(lower(LoopSet(q), U, T))
150+
esc(lower(LoopSet(q, Symbol(__module__)), U, T))
157151
end
158152

159153

0 commit comments

Comments
 (0)