Skip to content

Commit 4074301

Browse files
committed
More fixes, rewrite lower_load_for_optranslation to call _lower_load so it can take advantage of those optimizations.
1 parent 814d403 commit 4074301

File tree

13 files changed

+161
-130
lines changed

13 files changed

+161
-130
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using VectorizationBase: register_size, register_count, cache_linesize, has_opma
99
maybestaticfirst, maybestaticlast, scalar_less, scalar_greaterequal, gep, gesp, pointerforcomparison, NativeTypes,
1010
vfmadd, vfmsub, vfnmadd, vfnmsub, vfmadd_fast, vfmsub_fast, vfnmadd_fast, vfnmsub_fast, vfmadd231, vfmsub231, vfnmadd231, vfnmsub231,
1111
vfma_fast, vmuladd_fast, vdiv_fast, vadd_fast, vsub_fast, vmul_fast,
12-
relu, stridedpointer, StridedPointer, StridedBitPointer, AbstractStridedPointer,
12+
relu, stridedpointer, StridedPointer, StridedBitPointer, AbstractStridedPointer, _vload, _vstore!,
1313
reduced_add, reduced_prod, reduce_to_add, reduce_to_prod, reduced_max, reduced_min, reduce_to_max, reduce_to_min,
1414
vsum, vprod, vmaximum, vminimum, unwrap, Unroll, VecUnroll,
1515
preserve_buffer, zero_vecunroll, vbroadcast_vecunroll, _vzero, _vbroadcast,

src/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ function add_broadcast!(
111111
pushprepreamble!(ls, Expr(:(=), Klen, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:ref, Expr(:call, :size, mB), 1))))
112112
pushpreamble!(ls, Expr(:(=), Krange, Expr(:call, :(:), staticexpr(1), Klen)))
113113
k = gensym!(ls, "k")
114-
add_loop!(ls, Loop(k, 1, Klen, Krange, Klen), k)
114+
add_loop!(ls, Loop(k, 1, Klen, 1, Krange, Klen), k)
115115
m = loopsyms[1];
116116
if numdims(B) == 1
117117
bloopsyms = Symbol[k]
@@ -345,7 +345,7 @@ function add_broadcast_loops!(ls::LoopSet, loopsyms::Vector{Symbol}, destsym::Sy
345345
Nlower = gensym!(ls, "N")
346346
Nupper = gensym!(ls, "N")
347347
Nlen = gensym!(ls, "N")
348-
add_loop!(ls, Loop(itersym, Nlower, Nupper, Nrange, Nlen), itersym)
348+
add_loop!(ls, Loop(itersym, Nlower, Nupper, 1, Nrange, Nlen), itersym)
349349
push!(axes_tuple.args, Nrange)
350350
pushpreamble!(ls, Expr(:(=), Nlower, Expr(:call, lv(:maybestaticfirst), Nrange)))
351351
pushpreamble!(ls, Expr(:(=), Nupper, Expr(:call, lv(:maybestaticlast), Nrange)))

src/codegen/loopstartstopmanager.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
137137
# if ind === names(ls)[us.vloopnum]
138138
# push!(offsetprecalc_descript.args, 0)
139139
# elseif (ind === names(ls)[us.u₁loopnum]) & (us.u₁ > 3)
140-
# use_offsetprecalc = true
140+
use_offsetprecalc = true
141141
# push!(offsetprecalc_descript.args, us.u₁)
142142
# elseif (ind === names(ls)[us.u₂loopnum]) & (us.u₂ > 3)
143143
# use_offsetprecalc = true

src/codegen/lower_load.jl

Lines changed: 90 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,9 @@ function add_prefetches!(q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, pr
102102
end
103103
nothing
104104
end
105+
# Build the name under which the broadcasted copy of `mvar` is stored:
# `mvar` followed by the literal suffix "##broadcasted##", returned as a `Symbol`.
function broadcastedname(mvar)
    return Symbol(mvar, "##broadcasted##")
end
105106
function pushbroadcast!(q::Expr, mvar::Symbol)
106-
push!(q.args, Expr(:(=), Symbol(mvar, "##broadcasted##"), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, mvar)))
107+
push!(q.args, Expr(:(=), broadcastedname(mvar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, mvar)))
107108
end
108109
function lower_load_no_optranslation!(
109110
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}
@@ -119,7 +120,7 @@ function lower_load_no_optranslation!(
119120

120121
if all(op.ref.loopedindex)
121122
inds = unrolledindex(op, td, mask, inds_calc_by_ptr_offset)
122-
loadexpr = Expr(:call, lv(:vload), vptr(op), inds)
123+
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
123124
add_memory_mask!(loadexpr, op, td, mask)
124125
push!(loadexpr.args, falseexpr, rs) # unaligned load
125126
push!(q.args, Expr(:(=), mvar, loadexpr))
@@ -128,7 +129,7 @@ function lower_load_no_optranslation!(
128129
# for u ∈ 1:u₁
129130
let t = u₁, t = q
130131
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, u-1)
131-
loadexpr = Expr(:call, lv(:vload), vptr(op), inds)
132+
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
132133
add_memory_mask!(loadexpr, op, td, mask & ((u == u₁) | isvectorized(op)))
133134
push!(loadexpr.args, falseexpr, rs)
134135
# push!(t.args, loadexpr)
@@ -137,7 +138,7 @@ function lower_load_no_optranslation!(
137138
# push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
138139
else
139140
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, 0)
140-
loadexpr = Expr(:call, lv(:vload), vptr(op), inds)
141+
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
141142
add_memory_mask!(loadexpr, op, td, mask)
142143
push!(loadexpr.args, falseexpr, rs)
143144
push!(q.args, Expr(:(=), mvar, loadexpr))
@@ -183,82 +184,95 @@ function indisvectorized(ls::LoopSet, ind::Symbol)
183184
end
184185
# Return the first element stored in the `data` field of a `VecUnroll`.
# The trailing `false` passed to `getfield` turns off the bounds check for index 1.
@inline function firstunroll(vu::VecUnroll)
    unrolled = getfield(vu, :data)
    return getfield(unrolled, 1, false)
end
185186
# Fallback method: an argument that is not a `VecUnroll` is returned unchanged.
@inline function firstunroll(x)
    return x
end
187+
# Return the last element stored in the `data` field of a `VecUnroll`.
@inline function lastunroll(vu::VecUnroll)
    unrolled = getfield(vu, :data)
    return last(unrolled)
end
188+
# Fallback method: an argument that is not a `VecUnroll` is returned unchanged.
@inline function lastunroll(x)
    return x
end
186189
function lower_load_for_optranslation!(
187190
q::Expr, op::Operation, ls::LoopSet, td::UnrollArgs, mask::Bool, translationind::Int
188191
)
189-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = td
192+
@unpack u₁loop, u₂loop, vloop, u₁, u₂max, suffix = td
193+
194+
# @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = td
190195
iszero(suffix) || return
191196

197+
total_unroll = u₁ + u₂max - 1
198+
199+
200+
mref = op.ref
201+
inds_by_ptroff = indices_calculated_by_pointer_offsets(ls, mref)
192202
# initial offset pointer
193-
gespinds = mem_offset(op, UnrollArgs(td, u₁), indices_calculated_by_pointer_offsets(ls, op.ref), false)
203+
204+
# Unroll directions can be + or -
205+
# we want to start at minimum position.
206+
step₁ = gethint(step(u₁loop))
207+
step₂ = gethint(step(u₂loop))
208+
209+
210+
# abs of steps are equal
211+
#
212+
equal_steps = step₁ == step₂
213+
_td = UnrollArgs(u₁loop, u₂loop, vloop, total_unroll, u₂max, Core.ifelse(equal_steps, 0, u₂max - 1))
214+
gespinds = mem_offset(op, _td, inds_by_ptroff, false)
194215
ptr = vptr(op)
195216
gptr = Symbol(ptr, "##GESPED##")
196217
for i ∈ eachindex(gespinds.args)
197218
if i == translationind
198-
gespinds.args[i] = Expr(:call, lv(:firstunroll), gespinds.args[i])
219+
gespinds.args[i] = Expr(:call, lv(Core.ifelse(equal_steps, :firstunroll, :lastunroll)), gespinds.args[i])
199220
else
200221
gespinds.args[i] = Expr(:call, lv(:data), gespinds.args[i])
201222
end
202223
end
203224
push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))
204225

205-
shouldbroadcast = (!isvectorized(op)) && any(isvectorized, children(op))
226+
fill!(inds_by_ptroff, true)
206227

207-
inds = Expr(:tuple)
208-
indices = getindicesonly(op)
209-
for (i,ind) ∈ enumerate(indices)
210-
if i == translationind # vectorized ind cannot be the translation ind
211-
push!(inds.args, Expr(:call, Expr(:curly, lv(:Static), 0)))
212-
elseif (ind === vloopsym) || indisvectorized(ls, ind)
213-
push!(inds.args, _MMind(Expr(:call, lv(:Zero))))
214-
else
215-
push!(inds.args, Expr(:call, lv(:Zero)))
216-
end
217-
end
218-
variable_name0 = variable_name(op, 0)
219-
varbase = Symbol("##var##", variable_name0)
220-
loadcall = Expr(:call, lv(:vload), gptr, copy(inds))
221-
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls));
222-
mask && push!(loadcall.args, MASKSYMBOL)
223-
push!(loadcall.args, falseexpr, rs)
228+
@unpack ref, loopedindex = mref
229+
indices = getindicesonly(ref)
230+
old_translation_index = indices[translationind]
231+
indices[translationind] = u₁loop.itersymbol
232+
# getindicesonly returns a view of `getindices`
233+
dummyref = ArrayReference(ref.array, getindices(ref), zero(getoffsets(ref)), getstrides(ref))
234+
loopedindex[translationind] = true
235+
dummymref = ArrayReferenceMeta(dummyref, loopedindex, gptr)
224236

225-
varbase0 = Symbol(varbase, 0)
226-
t = Expr(:tuple, varbase0)
227-
push!(q.args, Expr(:(=), varbase0, loadcall))
228-
for u ∈ 1:u₁-1
229-
inds.args[translationind] = Expr(:call, Expr(:curly, lv(:Static), u))
230-
loadcall = Expr(:call, lv(:vload), gptr, copy(inds))
231-
mask && push!(loadcall.args, MASKSYMBOL)
232-
push!(loadcall.args, falseexpr, rs)
233-
varbaseu = Symbol(varbase, u)
234-
push!(q.args, Expr(:(=), varbaseu, loadcall))
235-
push!(t.args, varbaseu)
237+
_td = UnrollArgs(u₁loop, u₂loop, vloop, total_unroll, u₂max, -1)
238+
op.ref = dummymref
239+
_lower_load!(q, ls, op, _td, mask)
240+
# set old values
241+
op.ref = mref
242+
loopedindex[translationind] = false
243+
indices[translationind] = old_translation_index
244+
245+
shouldbroadcast = (!isvectorized(op)) && any(isvectorized, children(op))
246+
247+
# now we need to assign the `Vec`s from the `VecUnroll` to the correct name.
248+
variable_name_u = Symbol(variable_name(op, -1), '_', total_unroll)
249+
variable_name_data = Symbol(variable_name_u, "##data##")
250+
push!(q.args, :($variable_name_data = getfield($variable_name_u, 1)))
251+
if shouldbroadcast
252+
broadcasted_data = broadcastedname(variable_name_data)
253+
push!(q.args, :($broadcasted_data = getfield($(broadcastedname(variable_name_u)), 1)))
236254
end
237-
vecunroll_name = Symbol(variable_name0, '_', u₁)
238-
push!(q.args, Expr(:(=), vecunroll_name, Expr(:call, lv(:VecUnroll), t)))
239-
shouldbroadcast && pushbroadcast!(q, vecunroll_name)
240-
# this takes care of u₂ == 0
241-
offset = u₁
242-
for u₂ ∈ 1:u₂max-1
255+
for u₂ ∈ 0:u₂max-1
256+
variable_name_u₂ = Symbol(variable_name(op, u₂), '_', u₁)
243257
t = Expr(:tuple)
244-
varold = varbase
245-
varbase = variable_name(op, u₂)
246-
for u ∈ 0:u₁-2
247-
varbaseu = Symbol(varbase, u)
248-
push!(q.args, Expr(:(=), varbaseu, Symbol(varold, u + 1)))
249-
push!(t.args, varbaseu)
258+
if shouldbroadcast
259+
tb = Expr(:tuple)
260+
end
261+
for u ∈ 1:u₁
262+
uu = if equal_steps
263+
u + u₂
264+
else
265+
u - u₂ + u₂max - 1
266+
end
267+
push!(t.args, :(getfield($variable_name_data, $uu)))
268+
if shouldbroadcast
269+
push!(tb.args, :(getfield($broadcasted_data, $uu)))
270+
end
271+
end
272+
push!(q.args, :($variable_name_u₂ = VecUnroll($t)))
273+
if shouldbroadcast
274+
push!(q.args, :($(broadcastedname(variable_name_u₂)) = VecUnroll($tb)))
250275
end
251-
inds.args[translationind] = Expr(:call, Expr(:curly, lv(:Static), offset))
252-
loadcall = Expr(:call, lv(:vload), gptr, copy(inds))
253-
mask && push!(loadcall.args, MASKSYMBOL)
254-
push!(loadcall.args, falseexpr, rs)
255-
varload = Symbol(varbase, u₁ - 1)
256-
push!(q.args, Expr(:(=), varload, loadcall))
257-
push!(t.args, varload)
258-
offset += 1
259-
vecunroll_name = Symbol(variable_name(op, u₂), '_', u₁)
260-
push!(q.args, Expr(:(=), vecunroll_name, Expr(:call, lv(:VecUnroll), t)))
261-
shouldbroadcast && pushbroadcast!(q, vecunroll_name)
262276
end
263277
nothing
264278
end
@@ -268,40 +282,34 @@ end
268282
function lower_load!(
269283
q::Expr, op::Operation, ls::LoopSet, td::UnrollArgs, mask::Bool
270284
)
271-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, suffix = td
285+
@unpack u₁, u₂max, u₁loopsym, u₂loopsym, vloopsym, suffix = td
272286
if (suffix != -1) && ls.loadelimination[]
273-
istr, ispl = isoptranslation(ls, op, UnrollSymbols(u₁loopsym, u₂loopsym, vloopsym))
287+
if (u₁ > 1) & (u₂max > 1)
288+
istr, ispl = isoptranslation(ls, op, UnrollSymbols(u₁loopsym, u₂loopsym, vloopsym))
289+
else
290+
istr, ispl = 0, false
291+
end
274292
if !iszero(istr) & ispl
275293
return lower_load_for_optranslation!(q, op, ls, td, mask, istr)
276-
elseif suffix > 0
277-
if u₂loopsym !== vloopsym
278-
mno, id = maxnegativeoffset(ls, op, u₂loopsym)
279-
if -suffix < mno < 0 # already checked that `suffix != -1` above
280-
varnew = variable_name(op, suffix)
281-
varold = variable_name(operations(ls)[id], suffix + mno)
282-
opold = operations(ls)[id]
283-
u = isu₁unrolled(op) ? u₁ : 1
284-
push!(q.args, Expr(:(=), Symbol(varnew, '_', u), Symbol(varold, '_', u)))
285-
# if isu₁unrolled(op)
286-
# for u ∈ 0:u₁-1
287-
# push!(q.args, Expr(:(=), Symbol(varnew, u), Symbol(varold, u)))
288-
# end
289-
# else
290-
291-
# end
292-
return
293-
end
294+
elseif (suffix > 0) && (u₂loopsym !== vloopsym)
295+
mno, id = maxnegativeoffset(ls, op, u₂loopsym)
296+
if -suffix < mno < 0 # already checked that `suffix != -1` above
297+
varnew = variable_name(op, suffix)
298+
varold = variable_name(operations(ls)[id], suffix + mno)
299+
opold = operations(ls)[id]
300+
u = isu₁unrolled(op) ? u₁ : 1
301+
push!(q.args, Expr(:(=), Symbol(varnew, '_', u), Symbol(varold, '_', u)))
302+
return
294303
end
295304
end
296305
end
297306
_lower_load!(q, ls, op, td, mask)
298307
end
299308
function _lower_load!(
300-
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool
309+
q::Expr, ls::LoopSet, op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool} = indices_calculated_by_pointer_offsets(ls, op.ref)
301310
)
302311
omop = offsetloadcollection(ls)
303312
batchid, opind = omop.batchedcollectionmap[identifier(op)]
304-
inds_calc_by_ptr_offset = indices_calculated_by_pointer_offsets(ls, op.ref)
305313
# @show batchid == 0 (!isvectorized(op)) rejectinterleave(op, td.vloop, idsformap)
306314
if batchid == 0 || (!isvectorized(op)) || (rejectinterleave(op, td.vloop, omop.batchedcollections[batchid]))
307315
lower_load_no_optranslation!(q, ls, op, td, mask, inds_calc_by_ptr_offset)
@@ -349,7 +357,7 @@ function lower_load_collection!(
349357
end
350358
uinds = Expr(:call, unrollcurl₂, inds)
351359
vp = vptr(op)
352-
loadexpr = Expr(:call, lv(:vload), vp, uinds)
360+
loadexpr = Expr(:call, lv(:_vload), vp, uinds)
353361
# not using `add_memory_mask!(storeexpr, op, ua, mask)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
354362
(mask && isvectorized(op)) && push!(loadexpr.args, MASKSYMBOL)
355363
push!(loadexpr.args, falseexpr, rs)

src/codegen/lower_store.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function lower_store_collection!(
9090
end
9191
uinds = Expr(:call, unrollcurl₂, inds)
9292
vp = vptr(op)
93-
storeexpr = Expr(:call, lv(:vstore!), vp, Expr(:call, lv(:VecUnroll), t), uinds)
93+
storeexpr = Expr(:call, lv(:_vstore!), vp, Expr(:call, lv(:VecUnroll), t), uinds)
9494
# not using `add_memory_mask!(storeexpr, op, ua, mask)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
9595
mask && push!(storeexpr.args, MASKSYMBOL)
9696
push!(storeexpr.args, falseexpr, trueexpr, falseexpr, rs)
@@ -99,7 +99,7 @@ function lower_store_collection!(
9999
end
100100
function lower_store!(
101101
q::Expr, ls::LoopSet, op::Operation, ua::UnrollArgs, mask::Bool,
102-
reductfunc::Symbol = storeinstr_preprend(op, ua.vloopsym), inds_calc_by_ptr_offset = indices_calculated_by_pointer_offsets(ls, op.ref)
102+
reductfunc::Symbol = storeinstr_preprend(op, ua.vloop.itersymbol), inds_calc_by_ptr_offset = indices_calculated_by_pointer_offsets(ls, op.ref)
103103
)
104104
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, vloop, u₂max, suffix = ua
105105

@@ -124,9 +124,9 @@ function lower_store!(
124124
inds = unrolledindex(op, ua, mask, inds_calc_by_ptr_offset)
125125

126126
storeexpr = if reductfunc === Symbol("")
127-
Expr(:call, lv(:vstore!), vptr(op), mvar, inds)
127+
Expr(:call, lv(:_vstore!), vptr(op), mvar, inds)
128128
else
129-
Expr(:call, lv(:vstore!), lv(reductfunc), vptr(op), mvar, inds)
129+
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvar, inds)
130130
end
131131
add_memory_mask!(storeexpr, op, ua, mask)
132132
push!(storeexpr.args, falseexpr, trueexpr, falseexpr, rs)
@@ -138,9 +138,9 @@ function lower_store!(
138138
mvaru = :(getfield($mvard, $u, false))
139139
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, true, u-1)
140140
storeexpr = if reductfunc === Symbol("")
141-
Expr(:call, lv(:vstore!), vptr(op), mvaru, inds)
141+
Expr(:call, lv(:_vstore!), vptr(op), mvaru, inds)
142142
else
143-
Expr(:call, lv(:vstore!), lv(reductfunc), vptr(op), mvaru, inds)
143+
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvaru, inds)
144144
end
145145
add_memory_mask!(storeexpr, op, ua, mask & ((u == u₁) | isvectorized(op)))
146146
push!(storeexpr.args, falseexpr, trueexpr, falseexpr, rs)
@@ -149,9 +149,9 @@ function lower_store!(
149149
else
150150
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, true, 0)
151151
storeexpr = if reductfunc === Symbol("")
152-
Expr(:call, lv(:vstore!), vptr(op), mvar, inds)
152+
Expr(:call, lv(:_vstore!), vptr(op), mvar, inds)
153153
else
154-
Expr(:call, lv(:vstore!), lv(reductfunc), vptr(op), mvar, inds)
154+
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvar, inds)
155155
end
156156
add_memory_mask!(storeexpr, op, ua, mask)
157157
push!(storeexpr.args, falseexpr, trueexpr, falseexpr, rs)
@@ -221,7 +221,7 @@ function lower_tiled_store!(blockq::Expr, op::Operation, ls::LoopSet, ua::Unroll
221221
inds = Expr(:call, unrollcurl₁, inds)
222222
end
223223
uinds = Expr(:call, unrollcurl₂, inds)
224-
storeexpr = Expr(:call, lv(:vstore!), vptr(op), vut, uinds)
224+
storeexpr = Expr(:call, lv(:_vstore!), vptr(op), vut, uinds)
225225
if mask && isvectorized(op)
226226
# add_memory_mask!(storeexpr, op, ua, mask)
227227
# we checked for `isconditionalmemop` earlier, so we skip this check

0 commit comments

Comments
 (0)