Skip to content

Commit c4b7d1e

Browse files
committed
spmd -> brrr, and sometimes init reductions with loads rather than zeros when that load would otherwise be added at the end
1 parent e8b2d0d commit c4b7d1e

12 files changed

+216
-145
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ using Requires
5151

5252

5353
export LowDimArray, stridedpointer, indices,
54-
@avx, @avxt, @spmd, @spmdt, *ˡ, _avx_!,
54+
@avx, @avxt, @brrr, @tbrrr, *ˡ, _avx_!,
5555
vmap, vmap!, vmapt, vmapt!, vmapnt, vmapnt!, vmapntt, vmapntt!,
5656
tanh_fast, sigmoid_fast,
5757
vfilter, vfilter!, vmapreduce, vreduce

src/codegen/lower_compute.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ function lower_compute!(
525525
# elseif parents_u₂syms[n] & (!u₂unrolledsym)
526526
#&& (isouterreduction(ls, opp) != -1)
527527
# this checks if the parent is u₂ unrolled but this operation is not, in which case we need to reduce it.
528-
reduced_u₂ = reduce_expr_u₂(mangledvar(opp), instruction(opp), ureduct(ls))
528+
reduced_u₂ = reduce_expr_u₂(mangledvar(opp), instruction(opp), u₂max, Symbol("__", u₁))#ureduct(ls))
529529
reducedparentname = gensym!(ls, "reducedop")
530530
push!(q.args, Expr(:(=), reducedparentname, reduced_u₂))
531531
reduced_u₂ = reduce_parent!(q, ls, op, opp, reducedparentname)

src/codegen/lower_constant.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,11 @@ function typeof_sym(ls::LoopSet, op::Operation, zerotyp::NumberType)
4242
ELTYPESYMBOL
4343
end
4444
end
45-
# function in_reduced_children(op::Operation, s::Symbol)
46-
# end
4745

4846
function lower_zero!(
4947
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs, zerotyp::NumberType = zerotype(ls, op)
5048
)
51-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
49+
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, vloop, u₂max, suffix = ua
5250
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
5351
!opu₂ && suffix > 0 && return
5452
# TODO: for u₁, needs to consider if reducedchildren are u₁-unrolled

src/codegen/lower_store.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,17 @@ function storeinstr_preprend(op::Operation, vloopsym::Symbol)
1919
reduction_to_scalar(reduction_instruction_class(instruction(opp)))
2020
end
2121

22-
function reduce_expr_u₂(toreduct::Symbol, instr::Instruction, u₂::Int)
22+
function reduce_expr_u₂(toreduct::Symbol, instr::Instruction, u₂::Int, suffix::Symbol)
2323
t = Expr(:tuple)
2424
for u ∈ 0:u₂-1
25-
push!(t.args, Symbol(toreduct, u))
25+
push!(t.args, Symbol(toreduct, u, suffix))
2626
end
2727
Expr(:call, lv(:reduce_tup), reduce_to_onevecunroll(instr), t)
2828
end
2929
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, u₁::Int, u₂::Int, isu₁unrolled::Bool, isu₂unrolled::Bool)
3030
if isu₂unrolled# u₂ != -1
3131
_toreduct = Symbol(toreduct, 0)
32-
push!(q.args, Expr(:(=), _toreduct, reduce_expr_u₂(toreduct, instr, u₂)))
32+
push!(q.args, Expr(:(=), _toreduct, reduce_expr_u₂(toreduct, instr, u₂, Symbol(""))))
3333
else#if u₂ == -1
3434
_toreduct = Symbol(toreduct, '_', u₁)
3535
# else

src/codegen/operation_evaluation_order.jl

Lines changed: 80 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,40 +56,94 @@ function set_upstream_family!(adal::Vector{T}, op::Operation, val::T, ld::Vector
5656
set_upstream_family!(adal, opp, val, ld, id)
5757
end
5858
end
59-
59+
function search_for_reductinit!(op::Operation, opswap::Operation, var::Symbol, loopdeps::Vector{Symbol})
60+
for (i,opp) ∈ enumerate(parents(op))
61+
if (name(opp) === var) && (length(reduceddependencies(opp)) == 0) && (length(loopdependencies(opp)) == length(loopdeps)) && (length(children(opp)) == 1)
62+
if all(in(loopdeps), loopdependencies(opp))
63+
parents(op)[i] = opswap
64+
return opp
65+
end
66+
end
67+
opcheck = search_for_reductinit!(opp, opswap, var, loopdeps)
68+
opcheck === opp || return opcheck
69+
end
70+
return op
71+
end
6072
function addoptoorder!(
6173
ls::LoopSet, included_vars::Vector{Bool}, place_after_loop::Vector{Bool}, op::Operation,
6274
loopsym::Symbol, _n::Int, u₁loop::Symbol, u₂loop::Symbol, vectorized::Symbol, u₂max::Int
6375
)
64-
lo = ls.loop_order
65-
id = identifier(op)
66-
included_vars[id] || return nothing
67-
loopsym ∈ loopdependencies(op) || return nothing
68-
for opp ∈ parents(op) # ensure parents are added first
69-
addoptoorder!(ls, included_vars, place_after_loop, opp, loopsym, _n, u₁loop, u₂loop, vectorized, u₂max)
70-
end
71-
included_vars[id] || return nothing
72-
included_vars[id] = false
73-
isunrolled = (isu₁unrolled(op)) + 1
74-
istiled = isu₂unrolled(op) + 1
75-
# optype = Int(op.node_type) + 1
76-
after_loop = place_after_loop[id] + 1
77-
if !isloopvalue(op)
78-
isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max) || push!(lo[isunrolled,istiled,after_loop,_n], op)
79-
# if istiled
80-
# isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, u₂max) || push!(lo[isunrolled,2,after_loop,_n], op)
81-
# else
82-
# isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, nothing) || push!(lo[isunrolled,1,after_loop,_n], op)
83-
# end
76+
lo = ls.loop_order
77+
id = identifier(op)
78+
included_vars[id] || return nothing
79+
loopsym ∈ loopdependencies(op) || return nothing
80+
for opp ∈ parents(op) # ensure parents are added first
81+
addoptoorder!(ls, included_vars, place_after_loop, opp, loopsym, _n, u₁loop, u₂loop, vectorized, u₂max)
82+
end
83+
included_vars[id] || return nothing
84+
included_vars[id] = false
85+
isunrolled = (isu₁unrolled(op)) + 1
86+
istiled = isu₂unrolled(op) + 1
87+
# optype = Int(op.node_type) + 1
88+
after_loop = place_after_loop[id] + 1
89+
if !isloopvalue(op)
90+
isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max) || push!(lo[isunrolled,istiled,after_loop,_n], op)
91+
# if istiled
92+
# isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, u₂max) || push!(lo[isunrolled,2,after_loop,_n], op)
93+
# else
94+
# isnopidentity(ls, op, u₁loop, u₂loop, vectorized, u₂max, nothing) || push!(lo[isunrolled,1,after_loop,_n], op)
95+
# end
96+
end
97+
# @show op, after_loop
98+
# isloopvalue(op) || push!(lo[isunrolled,istiled,after_loop,_n], op)
99+
# all(opp -> iszero(length(reduceddependencies(opp))), parents(op)) &&
100+
set_upstream_family!(place_after_loop, op, false, loopdependencies(op), identifier(op)) # parents that have already been included are not moved, so no need to check included_vars to filter
101+
nothing
102+
end
103+
function replace_reduct_init!(ls::LoopSet, op::Operation, opsub::Operation, opcheck::Operation)
104+
deleteat!(parents(op), 2)
105+
op.variable = opcheck.variable
106+
opsub.variable = opcheck.variable
107+
op.mangledvariable = opcheck.mangledvariable
108+
opsub.mangledvariable = opcheck.mangledvariable
109+
op.instruction = instruction(:identity)
110+
fill_children!(ls)
111+
end
112+
function nounrollreduction(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vectorized::Symbol)
113+
reduceddeps = reduceddependencies(op)
114+
(vectorized ∉ reduceddeps) &&
115+
(u₁loop ∉ reduceddeps) &&
116+
(u₂loop ∉ reduceddeps)
117+
end
118+
function load_short_static_reduction_first!(ls::LoopSet, u₁loop::Symbol, u₂loop::Symbol, vectorized::Symbol)
119+
for op ∈ operations(ls)
120+
iscompute(op) || continue
121+
length(reduceddependencies(op)) == 0 && continue
122+
if (instruction(op).instr === :reduced_add)
123+
vecloop = getloop(ls, vectorized)
124+
if isstaticloop(vecloop) && (length(vecloop) ≤ 16) && nounrollreduction(op, u₁loop, u₂loop, vectorized)
125+
opsub = parents(op)[2]
126+
length(children(opsub)) == 1 || continue
127+
opsearch = parents(op)[1]
128+
opcheck = search_for_reductinit!(opsearch, opsub, name(opsearch), loopdependencies(op))
129+
opcheck === opsearch || replace_reduct_init!(ls, op, opsub, opcheck)
130+
131+
end
132+
elseif (instruction(op).instr === :add_fast) && (instruction(first(parents(op))).instr === :identity)
133+
vecloop = getloop(ls, vectorized)
134+
if isstaticloop(vecloop) && (length(vecloop) ≤ 16) && nounrollreduction(op, u₁loop, u₂loop, vectorized)
135+
opsub = parents(op)[2]
136+
((length(reduceddependencies(opsub)) == 0) & (length(children(opsub)) == 1)) || continue
137+
opsearch = parents(op)[1]
138+
opcheck = search_for_reductinit!(opsearch, opsub, name(opsearch), loopdependencies(op))
139+
opcheck === opsearch || replace_reduct_init!(ls, op, opsub, opcheck)
140+
end
84141
end
85-
# @show op, after_loop
86-
# isloopvalue(op) || push!(lo[isunrolled,istiled,after_loop,_n], op)
87-
# all(opp -> iszero(length(reduceddependencies(opp))), parents(op)) &&
88-
set_upstream_family!(place_after_loop, op, false, loopdependencies(op), identifier(op)) # parents that have already been included are not moved, so no need to check included_vars to filter
89-
nothing
142+
end
90143
end
91144

92145
function fillorder!(ls::LoopSet, order::Vector{Symbol}, u₁loop::Symbol, u₂loop::Symbol, u₂max::Int, vectorized::Symbol)
146+
load_short_static_reduction_first!(ls, u₁loop, u₂loop, vectorized)
93147
lo = ls.loop_order
94148
resize!(lo, length(ls.loopsymbols))
95149
ro = lo.loopnames # reverse order; will have same order as lo

src/constructors.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function process_args(args; inline = false, check_empty = false, u₁ = zero(Int
111111
end
112112
inline, check_empty, u₁, u₂, threads
113113
end
114-
function spmd_macro(mod, src, q, args...)
114+
function brrr_macro(mod, src, q, args...)
115115
q = macroexpand(mod, q)
116116

117117
if q.head === :for
@@ -124,12 +124,12 @@ function spmd_macro(mod, src, q, args...)
124124
end
125125
end
126126
"""
127-
@spmd
127+
@brrr
128128
129129
Annotate a `for` loop, or a set of nested `for` loops whose bounds are constant across iterations, to optimize the computation. For example:
130130
131131
function AmulB!(C, A, B)
132-
@spmd for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
132+
@brrr for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
133133
Cₘₙ = zero(eltype(C))
134134
for k ∈ 1:size(A,2)
135135
Cₘₙ += A[m,k] * B[k,n]
@@ -148,23 +148,23 @@ julia> using LoopVectorization
148148
149149
julia> a = rand(100);
150150
151-
julia> b = @spmd exp.(2 .* a);
151+
julia> b = @brrr exp.(2 .* a);
152152
153153
julia> c = similar(b);
154154
155-
julia> @spmd @. c = exp(2a);
155+
julia> @brrr @. c = exp(2a);
156156
157157
julia> b ≈ c
158158
true
159159
```
160160
161161
# Extended help
162162
163-
Advanced users can customize the implementation of the `@spmd`-annotated block
163+
Advanced users can customize the implementation of the `@brrr`-annotated block
164164
using keyword arguments:
165165
166166
```
167-
@spmd inline=false unroll=2 body
167+
@brrr inline=false unroll=2 body
168168
```
169169
170170
where `body` is the code of the block (e.g., `for ... end`).
@@ -197,42 +197,42 @@ but it applies to the loop ordering and unrolling that will be chosen by LoopVec
197197
`uᵢ=0` (the default) indicates that LoopVectorization should pick its own value,
198198
and `uᵢ=-1` disables unrolling for the corresponding loop.
199199
200-
The `@spmd` macro also checks the array arguments using `LoopVectorization.check_args` to try and determine
200+
The `@brrr` macro also checks the array arguments using `LoopVectorization.check_args` to try and determine
201201
if they are compatible with the macro. If `check_args` returns false, a fallback loop annotated with `@inbounds`
202202
and `@fastmath` is generated. Note that `VectorizationBase` provides functions such as `vadd` and `vmul` that will
203-
ignore `@fastmath`, preserving IEEE semantics both within `@spmd` and `@fastmath`.
203+
ignore `@fastmath`, preserving IEEE semantics both within `@brrr` and `@fastmath`.
204204
`check_args` currently returns false for some wrapper types like `LinearAlgebra.UpperTriangular`, requiring you to
205205
use their `parent`. Triangular loops aren't yet supported.
206206
"""
207-
macro spmd(args...)
208-
spmd_macro(__module__, __source__, last(args), Base.front(args)...)
207+
macro brrr(args...)
208+
brrr_macro(__module__, __source__, last(args), Base.front(args)...)
209209
end
210210
"""
211-
Equivalent to `@spmd`, except it adds `thread=true` as the first keyword argument.
211+
Equivalent to `@brrr`, except it adds `thread=true` as the first keyword argument.
212212
Note that later arguments take precedence.
213213
214-
Meant for convenience, as `@spmdt` is shorter than `@spmd thread=true`.
214+
Meant for convenience, as `@tbrrr` is shorter than `@brrr thread=true`.
215215
"""
216-
macro spmdt(args...)
217-
spmd_macro(__module__, __source__, last(args), :(thread=true), Base.front(args)...)
216+
macro tbrrr(args...)
217+
brrr_macro(__module__, __source__, last(args), :(thread=true), Base.front(args)...)
218218
end
219219

220220
"""
221-
@_spmd
221+
@_brrr
222222
223-
This macro mostly exists for debugging/testing purposes. It does not support many of the use cases of [`@spmd`](@ref).
223+
This macro mostly exists for debugging/testing purposes. It does not support many of the use cases of [`@brrr`](@ref).
224224
It emits loops directly, rather than punting to an `@generated` function, meaning it doesn't have access to type
225225
information when generating code or analyzing the loops, often leading to bad performance.
226226
227-
This macro accepts the `inline` and `unroll` keyword arguments like `@spmd`, but ignores the `check_empty` argument.
227+
This macro accepts the `inline` and `unroll` keyword arguments like `@brrr`, but ignores the `check_empty` argument.
228228
"""
229-
macro _spmd(q)
229+
macro _brrr(q)
230230
q = macroexpand(__module__, q)
231231
ls = LoopSet(q, __module__)
232232
set_hw!(ls)
233233
esc(Expr(:block, ls.prepreamble, lower_and_split_loops(ls, -1)))
234234
end
235-
macro _spmd(arg, q)
235+
macro _brrr(arg, q)
236236
@assert q.head === :for
237237
q = macroexpand(__module__, q)
238238
inline, check_empty, u₁, u₂ = check_macro_kwarg(arg, false, false, zero(Int8), zero(Int8), 1)
@@ -241,14 +241,14 @@ macro _spmd(arg, q)
241241
esc(Expr(:block, ls.prepreamble, lower(ls, u₁ % Int, u₂ % Int, -1)))
242242
end
243243

244-
macro spmd_debug(q)
244+
macro brrr_debug(q)
245245
q = macroexpand(__module__, q)
246246
ls = LoopSet(q, __module__)
247247
esc(LoopVectorization.setup_call_debug(ls))
248248
end
249249

250250
# define aliases
251-
const var"@avx" = var"@spmd"
252-
const var"@avxt" = var"@spmdt"
253-
const var"@avx_debug" = var"@spmd_debug"
251+
const var"@avx" = var"@brrr"
252+
const var"@avxt" = var"@tbrrr"
253+
const var"@avx_debug" = var"@brrr_debug"
254254

0 commit comments

Comments
 (0)