Skip to content

Commit 3937c88

Browse files
committed
misc fixes
1 parent 609237f commit 3937c88

17 files changed

+236
-90
lines changed

src/codegen/loopstartstopmanager.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,9 @@ function pointermax_index(
250250
else
251251
staticexpr(stophint - sub)
252252
end
253-
stride = getstrides(ar)[i]
253+
stride = getstrides(ar)[j]
254254
if isknown(incr)
255-
stride *= gethint
255+
stride *= gethint(incr)
256256
else
257257
_ind = mulexpr(_ind, getsym(incr))
258258
end
@@ -353,7 +353,7 @@ function append_pointer_maxes!(
353353
dim = length(getindicesonly(ar))
354354
# OFFSETPRECALCDEF = true
355355
# if OFFSETPRECALCDEF
356-
strd = getstrides(ar)[dim]
356+
strd = getstrides(ar)[ind]
357357
for sub ∈ 0:submax-1
358358
ptrcmp = Expr(:call, lv(:gesp), pointercompbase, offsetindex(dim, ind, (submax - sub)*strd, isvectorized, incr))
359359
push!(loopstart.args, Expr(:(=), maxsym(vptr_ar, sub), ptrcmp))
@@ -373,6 +373,7 @@ function append_pointer_maxes!(
373373
end
374374
function append_pointer_maxes!(loopstart::Expr, ls::LoopSet, ar::ArrayReferenceMeta, n::Int, submax::Int, isvectorized::Bool)
375375
loop = getloop(ls, n)
376+
@assert loop.itersymbol == names(ls)[n]
376377
start = first(loop)
377378
stop = last(loop)
378379
incr = step(loop)

src/codegen/lower_compute.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ function ifelselastexpr(hasf::Bool, M::Int, vargtypes, K::Int, S::Int, maskearly
164164
push!(t.args, :(getfield($(vargs[K]), $m, false)))
165165
end
166166
# push!(q.args, :(VecUnroll($t)::VecUnroll{$N,$W,$T,$V}))
167+
# push!(q.args, Expr(:call, lv(:VecUnroll), t))
167168
push!(q.args, :(VecUnroll($t)))
168169
q
169170
end
@@ -361,7 +362,7 @@ function lower_compute!(
361362
newpname = Symbol(newparentname, '_', u₁)
362363
push!(q.args, Expr(:(=), newpname, Symbol(parentname, '_', u₁)))
363364
# @show newparentop op instruction(newparentop)
364-
reduce_expr!(q, newparentname, instruction(newparentop), u₁, -1)
365+
reduce_expr!(q, newparentname, instruction(newparentop), u₁, -1, true)
365366
push!(q.args, Expr(:(=), Symbol(newparentname, '_', 1), Symbol(newparentname, "##onevec##")))
366367
end
367368
end

src/codegen/lower_load.jl

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -124,17 +124,17 @@ function lower_load_no_optranslation!(
124124
push!(loadexpr.args, falseexpr, rs) # unaligned load
125125
push!(q.args, Expr(:(=), mvar, loadexpr))
126126
elseif u₁ > 1
127-
# t = Expr(:tuple)
128-
# for u ∈ 1:u₁
129-
let t = u₁, t = q
127+
t = Expr(:tuple)
128+
for u ∈ 1:u₁
130129
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, u-1)
131130
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
132-
add_memory_mask!(loadexpr, op, td, mask & ((u == u₁) | isvectorized(op)))
131+
domask = mask && (isvectorized(op) & ((u == u₁) | (vloopsym !== u₁loopsym)))
132+
add_memory_mask!(loadexpr, op, td, domask)
133133
push!(loadexpr.args, falseexpr, rs)
134-
# push!(t.args, loadexpr)
135-
push!(t.args, Expr(:(=), mvar, loadexpr))
134+
push!(t.args, loadexpr)
135+
# push!(q.args, Expr(:(=), mvar, loadexpr))
136136
end
137-
# push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
137+
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
138138
else
139139
inds = mem_offset_u(op, td, inds_calc_by_ptr_offset, true, 0)
140140
loadexpr = Expr(:call, lv(:_vload), vptr(op), inds)
@@ -185,6 +185,8 @@ end
185185
@inline firstunroll(x) = x
186186
@inline lastunroll(vu::VecUnroll) = last(getfield(vu,:data))
187187
@inline lastunroll(x) = x
188+
@inline unmm(x) = x
189+
@inline unmm(x::MM) = getfield(x, :i)
188190
function lower_load_for_optranslation!(
189191
q::Expr, op::Operation, posindicator::UInt8, ls::LoopSet, td::UnrollArgs, mask::Bool, translationind::Int
190192
)
@@ -217,7 +219,7 @@ function lower_load_for_optranslation!(
217219
if i == translationind
218220
gespinds.args[i] = Expr(:call, lv(Core.ifelse(equal_steps, :firstunroll, :lastunroll)), gespinds.args[i])
219221
# else
220-
# gespinds.args[i] = Expr(:call, lv(:data), gespinds.args[i])
222+
# gespinds.args[i] = Expr(:call, lv(:unmm), gespinds.args[i])
221223
end
222224
end
223225
push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))
@@ -267,6 +269,7 @@ function lower_load_for_optranslation!(
267269
broadcasted_data = broadcastedname(variable_name_data)
268270
push!(q.args, :($broadcasted_data = getfield($(broadcastedname(variable_name_u)), 1)))
269271
end
272+
gf = GlobalRef(Core,:getfield)
270273
for u₂ ∈ 0:u₂max-1
271274
variable_name_u₂ = Symbol(variable_name(op, u₂), '_', u₁)
272275
t = Expr(:tuple)
@@ -279,14 +282,14 @@ function lower_load_for_optranslation!(
279282
else
280283
u - u₂ + u₂max - 1
281284
end
282-
push!(t.args, :(getfield($variable_name_data, $uu)))
285+
push!(t.args, :($gf($variable_name_data, $uu)))
283286
if shouldbroadcast
284-
push!(tb.args, :(getfield($broadcasted_data, $uu)))
287+
push!(tb.args, :($gf($broadcasted_data, $uu)))
285288
end
286289
end
287-
push!(q.args, :($variable_name_u₂ = VecUnroll($t)))
290+
push!(q.args, Expr(:(=), variable_name_u₂, Expr(:call, lv(:VecUnroll), t)))
288291
if shouldbroadcast
289-
push!(q.args, :($(broadcastedname(variable_name_u₂)) = VecUnroll($tb)))
292+
push!(q.args, Expr(:(=), broadcastedname(variable_name_u₂), Expr(:call, lv(:VecUnroll), tb)))
290293
end
291294
end
292295
nothing
@@ -324,8 +327,8 @@ function _lower_load!(
324327
)
325328
omop = offsetloadcollection(ls)
326329
batchid, opind = omop.batchedcollectionmap[identifier(op)]
327-
# @show batchid == 0 (!isvectorized(op)) rejectinterleave(op, td.vloop, idsformap)
328-
if batchid == 0 || (!isvectorized(op)) || (rejectinterleave(op, td.vloop, omop.batchedcollections[batchid]))
330+
# @show batchid == 0 (!isvectorized(op)) rejectinterleave(ls, op, td.vloop, idsformap)
331+
if batchid == 0 || (!isvectorized(op)) || (rejectinterleave(ls, op, td.vloop, omop.batchedcollections[batchid]))
329332
lower_load_no_optranslation!(q, ls, op, td, mask, inds_calc_by_ptr_offset)
330333
elseif opind == 1# only lower loads once
331334
# I do not believe it is possible for `opind == 1` to be lowered after an operation depending on a different opind.
@@ -337,9 +340,26 @@ function _lower_load!(
337340
lower_load_collection!(q, ls, opidmap, idsformap, td, mask, inds_calc_by_ptr_offset)
338341
end
339342
end
340-
function rejectinterleave(op::Operation, vloop::Loop, idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true})
343+
function addive_loopinductvar_only(op::Operation)
344+
isloopvalue(op) && return true
345+
iscompute(op) || return false
346+
additive_instr = (:add_fast, :(+), :vadd, :identity, :sub_fast, :(-), :vsub)
347+
Base.sym_in(instruction(op).instr, additive_instr) || return false
348+
return all(addive_loopinductvar_only, parents(op))
349+
end
350+
351+
function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true})
341352
vloopsym = vloop.itersymbol; strd = step(vloop)
342353
isknown(strd) || return true
354+
# TODO: reject if there is a vectorized !loopedindex index
355+
for (li,ind) ∈ zip(op.ref.loopedindex,getindicesonly(op))
356+
li && continue
357+
for indop ∈ operations(ls)
358+
if (name(indop) === ind) && isvectorized(indop)
359+
addive_loopinductvar_only(op) || return true # so that it is `MM`
360+
end
361+
end
362+
end
343363
(first(getindices(op)) === vloopsym) && (length(idsformap) ≠ first(getstrides(op)) * gethint(strd))
344364
end
345365
function lower_load_collection!(
@@ -375,7 +395,7 @@ function lower_load_collection!(
375395
# not using `add_memory_mask!(storeexpr, op, ua, mask)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
376396
(mask && isvectorized(op)) && push!(loadexpr.args, MASKSYMBOL)
377397
push!(loadexpr.args, falseexpr, rs)
378-
collectionname = Symbol(vp, "##collection##number", first(first(idsformap)), "##size##", nouter, "##u₁##", u₁)
398+
collectionname = Symbol(vp, "##collection##number#", opidmap[first(first(idsformap))], "#", suffix, "##size##", nouter, "##u₁##", u₁)
379399
# getfield to extract data from `VecUnroll` object, so we have a tuple
380400
push!(q.args, Expr(:(=), collectionname, Expr(:call, :getfield, loadexpr, 1)))
381401
u = Core.ifelse(isu₁unrolled(op), u₁, 1)

src/codegen/lower_memory_common.jl

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ function addoffset!(ret::Expr, indvectorized::Bool, vloopstride, indexstride, in
8181
end
8282
end
8383

84-
function addoffset!(ret::Expr, indvectorized::Bool, unrolledsteps, vloopstride, indexstride, index, offset, calcbypointeroffset::Bool) # 8 -> 7 args
84+
function addvectoroffset!(ret::Expr, indvectorized::Bool, unrolledsteps, vloopstride, indexstride, index, offset, calcbypointeroffset::Bool) # 8 -> 7 args
8585
# if _iszero(unrolledsteps) # if no steps, pass through; should be unreachable
8686
# addoffset!(ret, indvectorized, vloopstride, indexstride, index, offset, calcbypointeroffset)
8787
# else
@@ -98,22 +98,22 @@ function addoffset!(ret::Expr, indvectorized::Bool, unrolledsteps, vloopstride,
9898
end
9999
end
100100
# unrolledloopstride is a stride multiple on `unrolledsteps`
101-
function addoffset!(
102-
ret::Expr, indvectorized::Bool, unrolledsteps::Int, unrolledloopstride, vloopstride, indexstride::Integer, index, offset::Integer, calcbypointeroffset::Bool
103-
) # 9 -> (7 or 8) args
101+
function addvectoroffset!(
102+
ret::Expr, mm::Bool, unrolledsteps::Int, unrolledloopstride, vloopstride, indexstride::Integer, index, offset::Integer, calcbypointeroffset::Bool, indvectorized::Bool
103+
) # 10 -> (7 or 8) args
104104
if unrolledsteps == 0 # neither unrolledloopstride or indexstride can be 0
105-
addoffset!(ret, indvectorized, vloopstride, indexstride, index, offset, calcbypointeroffset) # 7 arg
105+
addoffset!(ret, mm, vloopstride, indexstride, index, offset, calcbypointeroffset) # 7 arg
106106
elseif indvectorized
107107
unrolledsteps *= indexstride
108108
if isknown(unrolledloopstride)
109-
addoffset!(ret, indvectorized, gethint(unrolledloopstride)*unrolledsteps, vloopstride, indexstride, index, offset, calcbypointeroffset) # 8 arg
109+
addvectoroffset!(ret, mm, gethint(unrolledloopstride)*unrolledsteps, vloopstride, indexstride, index, offset, calcbypointeroffset) # 8 arg
110110
elseif unrolledsteps == 1
111-
addoffset!(ret, indvectorized, unrolledloopstride, vloopstride, indexstride, index, offset, calcbypointeroffset) # 8 arg
111+
addvectoroffset!(ret, mm, unrolledloopstride, vloopstride, indexstride, index, offset, calcbypointeroffset) # 8 arg
112112
else
113-
addoffset!(ret, indvectorized, mulexpr(unrolledloopstride,unrolledsteps), vloopstride, indexstride, index, offset, calcbypointeroffset) # 8 arg
113+
addvectoroffset!(ret, mm, mulexpr(unrolledloopstride,unrolledsteps), vloopstride, indexstride, index, offset, calcbypointeroffset) # 8 arg
114114
end
115115
else
116-
addoffset!(ret, indvectorized, vloopstride, indexstride, index, offset + unrolledsteps, calcbypointeroffset) # 7 arg
116+
addoffset!(ret, mm, vloopstride, indexstride, index, offset + unrolledsteps, calcbypointeroffset) # 7 arg
117117
end
118118
end
119119

@@ -137,15 +137,28 @@ function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vect
137137
stride = strides[n] % Int
138138
@unpack vstep = td
139139
if loopedindex[n]
140-
addoffset!(ret, indvectorized, vstep, stride, ind, offset, inds_calc_by_ptr_offset[n] | (ind === CONSTANTZEROINDEX))
140+
addoffset!(ret, indvectorized, vstep, stride, ind, offset, inds_calc_by_ptr_offset[n] | (ind === CONSTANTZEROINDEX)) # 7 arg
141141
else
142142
offset -= 1
143143
newname, parent = symbolind(ind, op, td)
144144
# _mmi = indvectorized && parent !== op && (!isvectorized(parent))
145145
# addoffset!(ret, newname, stride, offset, _mmi)
146146
_mmi = indvectorized && parent !== op && (!isvectorized(parent))
147147
@assert !_mmi "Please file an issue with an example of how you got this."
148-
addoffset!(ret, 0, newname, offset, false)
148+
if isu₁unrolled(parent) & (td.u₁ > 1)
149+
gf = GlobalRef(Core,:getfield)
150+
firstnew = Expr(:call, gf, Expr(:call, gf, newname, 1), 1, false)
151+
if isvectorized(parent) & (!_mm)
152+
firstnew = Expr(:call, lv(:unmm), firstnew)
153+
end
154+
addoffset!(ret, 0, firstnew, offset, false)
155+
else
156+
if isvectorized(parent) & (!_mm)
157+
addoffset!(ret, 0, Expr(:call, lv(:unmm), newname), offset, false)
158+
else
159+
addoffset!(ret, 0, newname, offset, false)
160+
end
161+
end
149162
end
150163
end
151164
ret
@@ -249,20 +262,42 @@ function mem_offset_u(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Ve
249262
ind_by_offset = inds_calc_by_ptr_offset[n] | (ind === CONSTANTZEROINDEX)
250263
offset = convert(Int, offsets[n])
251264
stride = convert(Int, strides[n])
252-
indvectorized = _mm & (ind === vloopsym)
265+
indvectorized = ind === vloopsym
266+
indvectorizedmm = _mm & indvectorized
253267
if ind === u₁loopsym
254-
addoffset!(ret, indvectorized, incr₁, u₁step, vstep, stride, ind, offset, ind_by_offset)
268+
addvectoroffset!(ret, indvectorizedmm, incr₁, u₁step, vstep, stride, ind, offset, ind_by_offset, indvectorized) # 9 arg
255269
elseif ind === u₂loopsym
256-
addoffset!(ret, indvectorized, incr₂, u₂step, vstep, stride, ind, offset, ind_by_offset)
270+
# if isstore(op)
271+
# @show indvectorized, ind === vloopsym, u₂loopsym, incr₂
272+
# end
273+
addvectoroffset!(ret, indvectorizedmm, incr₂, u₂step, vstep, stride, ind, offset, ind_by_offset, indvectorized) # 9 arg
257274
elseif loopedindex[n]
258-
addoffset!(ret, indvectorized, vstep, stride, ind, offset, ind_by_offset)
275+
addoffset!(ret, indvectorizedmm, vstep, stride, ind, offset, ind_by_offset) # 7 arg
259276
else
260277
offset -= 1
261278
newname, parent = symbolind(ind, op, td)
262-
_mmi = _mm && indvectorized && parent !== op && (!isvectorized(parent))
279+
_mmi = indvectorizedmm && parent !== op && (!isvectorized(parent))
263280
# addoffset!(ret, newname, 1, offset, _mmi)
264281
@assert !_mmi "Please file an issue with an example of how you got this."
265-
if stride == 1
282+
if isvectorized(parent) & (!_mm)
283+
if isu₁unrolled(parent) & (td.u₁ > 1)
284+
gf = GlobalRef(Core,:getfield)
285+
newname_unmm = Expr(:call, lv(:unmm), Expr(:call, gf, Expr(:call, gf, newname, 1), incr₁+1, false))
286+
else
287+
newname_unmm = Expr(:call, lv(:unmm), newname)
288+
end
289+
if stride ≠ 1
290+
newname_unmm = mulexpr(newname_unmm,stride)
291+
end
292+
addoffset!(ret, 0, newname_unmm, offset, false)
293+
elseif isu₁unrolled(parent) & (td.u₁ > 1)
294+
gf = GlobalRef(Core,:getfield)
295+
firstnew = Expr(:call, gf, Expr(:call, gf, newname, 1), incr₁+1, false)
296+
if stride ≠ 1
297+
firstnew = mulexpr(firstnew,stride)
298+
end
299+
addoffset!(ret, 0, firstnew, offset, false)
300+
elseif stride == 1
266301
addoffset!(ret, 0, newname, offset, false)
267302
else
268303
addoffset!(ret, 0, mulexpr(newname,stride), offset, false)
@@ -276,7 +311,7 @@ end
276311
@inline and_last(a, b) = a & b
277312
@generated function and_last(v::VecUnroll{N}, m) where {N}
278313
q = Expr(:block, Expr(:meta,:inline), :(vd = data(v)))
279-
t = Expr(:tuple)
314+
t = Expr(:call, lv(:promote))
280315
for n ∈ 1:N
281316
push!(t.args, :(getfield(vd, $n, false)))
282317
end

src/codegen/lower_store.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@ function reduce_expr_u₂(toreduct::Symbol, instr::Instruction, u₂::Int)
2626
end
2727
Expr(:call, lv(:reduce_tup), reduce_to_onevecunroll(instr), t)
2828
end
29-
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, u₁::Int, u₂::Int)
29+
function reduce_expr!(q::Expr, toreduct::Symbol, instr::Instruction, u₁::Int, u₂::Int, isu₁unrolled::Bool)
3030
if u₂ != -1
3131
_toreduct = Symbol(toreduct, 0)
3232
push!(q.args, Expr(:(=), _toreduct, reduce_expr_u₂(toreduct, instr, u₂)))
3333
else
3434
_toreduct = Symbol(toreduct, '_', u₁)
3535
end
36-
if u₁ == 1
36+
if (u₁ == 1) | (~isu₁unrolled)
3737
push!(q.args, Expr(:(=), Symbol(toreduct, "##onevec##"), _toreduct))
3838
else
3939
push!(q.args, Expr(:(=), Symbol(toreduct, "##onevec##"), Expr(:call, lv(reduction_to_single_vector(instr)), _toreduct)))
@@ -105,7 +105,7 @@ function lower_store!(
105105

106106
omop = offsetloadcollection(ls)
107107
batchid, opind = omop.batchedcollectionmap[identifier(op)]
108-
if ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(op, vloop, omop.batchedcollections[batchid]))
108+
if ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(ls, op, vloop, omop.batchedcollections[batchid]))
109109
(opind == 1) && lower_store_collection!(q, ls, op, ua, mask, inds_calc_by_ptr_offset)
110110
return
111111
end
@@ -142,7 +142,8 @@ function lower_store!(
142142
else
143143
Expr(:call, lv(:_vstore!), lv(reductfunc), vptr(op), mvaru, inds)
144144
end
145-
add_memory_mask!(storeexpr, op, ua, mask & ((u == u₁) | isvectorized(op)))
145+
domask = mask && (isvectorized(op) & ((u == u₁) | (vloopsym !== u₁loopsym)))
146+
add_memory_mask!(storeexpr, op, ua, domask)# & ((u == u₁) | isvectorized(op)))
146147
push!(storeexpr.args, falseexpr, trueexpr, falseexpr, rs)
147148
push!(q.args, storeexpr)
148149
end
@@ -176,7 +177,7 @@ function donot_tile_store(ls::LoopSet, op::Operation, vloop::Loop, reductfunc::S
176177

177178
omop = offsetloadcollection(ls)
178179
batchid, opind = omop.batchedcollectionmap[identifier(op)]
179-
return ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(op, vloop, omop.batchedcollections[batchid]))
180+
return ((batchid ≠ 0) && isvectorized(op)) && (!rejectinterleave(ls, op, vloop, omop.batchedcollections[batchid]))
180181
end
181182

182183
# VectorizationBase implements optimizations for certain grouped stores
@@ -212,7 +213,7 @@ function lower_tiled_store!(blockq::Expr, op::Operation, ls::LoopSet, ua::Unroll
212213
mvar = Symbol(variable_name(opp, t), '_', u)
213214
push!(tup.args, mvar)
214215
end
215-
vut = :(VecUnroll($tup)) # `VecUnroll` of `VecUnroll`s
216+
vut = Expr(:call, lv(:VecUnroll), tup) # `VecUnroll` of `VecUnroll`s
216217
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false)
217218
unrollcurl₂ = unrolled_curly(op, u₂, u₂loop, vloop, mask)
218219
falseexpr = Expr(:call, lv(:False)); trueexpr = Expr(:call, lv(:True)); rs = staticexpr(reg_size(ls));

src/codegen/lowering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
540540
var = name(op)
541541
mvar = mangledvar(op)
542542
instr = instruction(op)
543-
reduce_expr!(q, mvar, instr, u1f, u2f)
543+
reduce_expr!(q, mvar, instr, u1f, u2f, isu₁unrolled(op))
544544
if !iszero(length(ls.opdict))
545545
if (isu₁unrolled(op) | isu₂unrolled(op))
546546
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), Symbol(mvar, "##onevec##"), var)))

0 commit comments

Comments
 (0)