Skip to content

Commit b2ec589

Browse files
committed
Make index parsing able to handle more complicated expressions without falling back to operations.
1 parent 4074301 commit b2ec589

File tree

7 files changed

+286
-138
lines changed

7 files changed

+286
-138
lines changed

src/codegen/lower_load.jl

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -187,7 +187,7 @@ end
187187
@inline lastunroll(vu::VecUnroll) = last(getfield(vu,:data))
188188
@inline lastunroll(x) = x
189189
function lower_load_for_optranslation!(
190-
q::Expr, op::Operation, ls::LoopSet, td::UnrollArgs, mask::Bool, translationind::Int
190+
q::Expr, op::Operation, posindicator::UInt8, ls::LoopSet, td::UnrollArgs, mask::Bool, translationind::Int
191191
)
192192
@unpack u₁loop, u₂loop, vloop, u₁, u₂max, suffix = td
193193

@@ -206,10 +206,9 @@ function lower_load_for_optranslation!(
206206
step₁ = gethint(step(u₁loop))
207207
step₂ = gethint(step(u₂loop))
208208

209-
210209
# abs of steps are equal
211210
#
212-
equal_steps = step₁ == step₂
211+
equal_steps = (step₁ == step₂) ⊻ (posindicator ≠ 0x03)
213212
_td = UnrollArgs(u₁loop, u₂loop, vloop, total_unroll, u₂max, Core.ifelse(equal_steps, 0, u₂max - 1))
214213
gespinds = mem_offset(op, _td, inds_by_ptroff, false)
215214
ptr = vptr(op)
@@ -286,12 +285,11 @@ function lower_load!(
286285
if (suffix != -1) && ls.loadelimination[]
287286
if (u₁ > 1) & (u₂max > 1)
288287
istr, ispl = isoptranslation(ls, op, UnrollSymbols(u₁loopsym, u₂loopsym, vloopsym))
289-
else
290-
istr, ispl = 0, false
288+
if istr ≠ 0
289+
return lower_load_for_optranslation!(q, op, ispl, ls, td, mask, istr)
290+
end
291291
end
292-
if !iszero(istr) & ispl
293-
return lower_load_for_optranslation!(q, op, ls, td, mask, istr)
294-
elseif (suffix > 0) && (u₂loopsym !== vloopsym)
292+
if (suffix > 0) && (u₂loopsym !== vloopsym)
295293
mno, id = maxnegativeoffset(ls, op, u₂loopsym)
296294
if -suffix < mno < 0 # already checked that `suffix != -1` above
297295
varnew = variable_name(op, suffix)

src/codegen/lower_memory_common.jl

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -10,14 +10,7 @@ function symbolind(ind::Symbol, op::Operation, td::UnrollArgs)
1010
@unpack u₁, u₁loopsym, u₂loopsym, u₂max, suffix = td
1111
parent = parents(op)[id]
1212
pvar, u₁op, u₂op = variable_name_and_unrolled(parent, u₁loopsym, u₂loopsym, u₂max, suffix)
13-
# pvar = if u₂loopsym ∈ loopdependencies(parent)
14-
# variable_name(parent, suffix)
15-
# else
16-
# mangledvar(parent)
17-
# end
18-
u = u₁op ? u₁ : 1
19-
ex = Symbol(pvar, '_', u)
20-
Expr(:call, lv(:staticm1), ex), parent
13+
Symbol(pvar, '_', Core.ifelse(u₁op, u₁, 1)), parent
2114
end
2215

2316
staticexpr(x::Int) = Expr(:call, Expr(:curly, lv(:Static), x))
@@ -146,6 +139,7 @@ function mem_offset(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vect
146139
if loopedindex[n]
147140
addoffset!(ret, indvectorized, vstep, stride, ind, offset, inds_calc_by_ptr_offset[n] | (ind === CONSTANTZEROINDEX))
148141
else
142+
offset -= 1
149143
newname, parent = symbolind(ind, op, td)
150144
# _mmi = indvectorized && parent !== op && (!isvectorized(parent))
151145
# addoffset!(ret, newname, stride, offset, _mmi)
@@ -213,10 +207,10 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
213207
# Expr(:call, Expr(:curly, lv(:Unroll), AU, 1, u₁, AV, intvecsym, M, 1), ind)
214208
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, AV, intvecsym, M, X)
215209
else
216-
if isone(step(u₁loop))
210+
if isone(X)
217211
Expr(:curly, lv(:Unroll), AU, intvecsym, u₁, AV, intvecsym, M, X)
218212
else
219-
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, step(u₁loop)))))
213+
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, X))))
220214
Expr(:curly, lv(:Unroll), AU, unrollstepexpr, u₁, AV, intvecsym, M, X)
221215
end
222216
end
@@ -263,11 +257,16 @@ function mem_offset_u(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Ve
263257
elseif loopedindex[n]
264258
addoffset!(ret, indvectorized, vstep, stride, ind, offset, ind_by_offset)
265259
else
260+
offset -= 1
266261
newname, parent = symbolind(ind, op, td)
267-
# _mmi = _mm && indvectorized && parent !== op && (!isvectorized(parent))
262+
_mmi = _mm && indvectorized && parent !== op && (!isvectorized(parent))
268263
# addoffset!(ret, newname, 1, offset, _mmi)
269264
@assert !_mmi "Please file an issue with an example of how you got this."
270-
addoffset!(ret, newname, 1, offset, false)
265+
if stride == 1
266+
addoffset!(ret, 0, newname, offset, false)
267+
else
268+
addoffset!(ret, 0, mulexpr(newname,stride), offset, false)
269+
end
271270
end
272271
end
273272
end

src/codegen/lowering.jl

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,14 +42,15 @@ function lower_block(
4242
for prepost ∈ 1:2
4343
# !u₁ && !u₂
4444
lower!(blockq, ops[1,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, true)
45+
# isu₁unrolled, isu₂unrolled, after_loop, n
4546
opsv1 = ops[1,2,prepost,n]
4647
opsv2 = ops[2,2,prepost,n]
4748
if length(opsv1) + length(opsv2) > 0
4849
nstores = 0
4950
iszero(length(opsv1)) || (nstores += sum(isstore, opsv1))
5051
iszero(length(opsv2)) || (nstores += sum(isstore, opsv2))
5152
# if nstores
52-
if (length(opsv1) + length(opsv2) == nstores) # all_u₂_ops_store
53+
if (length(opsv1) + length(opsv2) == nstores) && u₂ > 1 # all_u₂_ops_store
5354
lower!(blockq, ops[2,1,prepost,n], ls, unrollsyms, u₁, u₂, -1, mask, true, true) # for u ∈ 0:u₁-1
5455
lower_tiled_store!(blockq, opsv1, opsv2, ls, unrollsyms, u₁, u₂, mask)
5556
else

src/modeling/determinestrategy.jl

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -674,17 +674,19 @@ function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
674674

675675
istranslation = 0
676676
inds = getindices(op); li = op.ref.loopedindex
677-
translationplus = false
678677
for i ∈ eachindex(li)
679678
if !li[i]
680679
opp = findparent(ls, inds[i + (first(inds) === DISCONTIGUOUS)])
681-
if instruction(opp).instr ∈ (:+, :-) && isu₁unrolled(opp) && isu₂unrolled(opp)
682-
istranslation = i
683-
translationplus = instruction(opp).instr === :+
680+
if isu₁unrolled(opp) && isu₂unrolled(opp)
681+
isadd = instruction(opp).instr === :(+)
682+
issub = instruction(opp).instr === :(-)
683+
if isadd | issub
684+
return i, isadd
685+
end
684686
end
685687
end
686688
end
687-
istranslation, translationplus
689+
0, false
688690
end
689691
function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol)
690692
mno = typemin(Int)

0 commit comments

Comments
 (0)