Skip to content

Commit 9e7a783

Browse files
committed
fix a couple strided loop bugs, fixes #348
1 parent ce37b6b commit 9e7a783

12 files changed

+333
-241
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.12.89"
4+
version = "0.12.90"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/codegen/lower_compute.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ function parent_unroll_status(op::Operation, u₁loop::Symbol, u₂loop::Symbol,
6161
parents_u₁syms, parents_u₂syms
6262
end
6363

64-
function _add_loopvalue!(ex::Expr, loopval::Symbol, vloop::Loop, u::Int)
64+
function _add_loopvalue!(ex::Expr, loopval::Symbol, vloop::Loop, u::Int, loop::Loop)
6565
vloopsym = vloop.itersymbol
6666
if loopval === vloopsym
6767
if iszero(u)
@@ -77,24 +77,26 @@ function _add_loopvalue!(ex::Expr, loopval::Symbol, vloop::Loop, u::Int)
7777
end
7878
elseif u == 0
7979
push!(ex.args, loopval)
80+
elseif isknown(step(loop))
81+
push!(ex.args, Expr(:call, lv(:vadd_nsw), loopval, staticexpr(u*gethint(step(loop)))))
8082
else
81-
push!(ex.args, Expr(:call, lv(:vadd_nsw), loopval, staticexpr(u)))
83+
push!(ex.args, Expr(:call, lv(:vadd_nsw), loopval, mulexpr(step(loop), u)))
8284
end
8385
end
84-
function add_loopvalue!(instrcall::Expr, loopval, ua::UnrollArgs, u₁::Int)
86+
function add_loopvalue!(instrcall::Expr, loopval, ua::UnrollArgs, u₁::Int, loop::Loop)
8587
@unpack u₁loopsym, u₂loopsym, vloopsym, vloop, suffix = ua
8688
if loopval === u₁loopsym #parentsunrolled[n]
8789
if isone(u₁)
88-
_add_loopvalue!(instrcall, loopval, vloop, 0)
90+
_add_loopvalue!(instrcall, loopval, vloop, 0, loop)
8991
else
9092
t = Expr(:tuple)
9193
for u ∈ 0:u₁-1
92-
_add_loopvalue!(t, loopval, vloop, u)
94+
_add_loopvalue!(t, loopval, vloop, u, loop)
9395
end
9496
push!(instrcall.args, Expr(:call, lv(:VecUnroll), t))
9597
end
9698
elseif suffix > 0 && loopval === u₂loopsym
97-
_add_loopvalue!(instrcall, loopval, vloop, suffix)
99+
_add_loopvalue!(instrcall, loopval, vloop, suffix, loop)
98100
elseif loopval === vloopsym
99101
push!(instrcall.args, _MMind(loopval, step(vloop)))
100102
else
@@ -519,7 +521,7 @@ function lower_compute!(
519521
opp = parents_op[n]
520522
if isloopvalue(opp)
521523
loopval = first(loopdependencies(opp))
522-
add_loopvalue!(instrcall, loopval, ua, u₁)
524+
add_loopvalue!(instrcall, loopval, ua, u₁, getloop(ls,loopval))
523525
elseif name(opp) === name(op)
524526
selfdep = n
525527
if ((isvectorized(opp) && !isvectorized(op))) ||

src/codegen/lower_load.jl

Lines changed: 86 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ function lower_load_no_optranslation!(
151151
loadexpr = Expr(:call, lv(:_vload), sptr(op), inds)
152152
add_memory_mask!(loadexpr, op, td, mask, ls, 0)
153153
push!(loadexpr.args, falseexpr, rs) # unaligned load
154-
# @show op loadexpr
155154
push!(q.args, Expr(:(=), mvar, loadexpr))
156155
elseif (u₁ > 1) & opu₁
157156
t = Expr(:tuple)
@@ -417,7 +416,6 @@ function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::Su
417416
end
418417
end
419418
vloopsym = vloop.itersymbol;
420-
# @show op first(getindices(op)) length(idsformap), first(getstrides(op)), gethint(strd)
421419
(first(getindices(op)) === vloopsym) && (length(idsformap) ≠ abs(first(getstrides(op)) * gethint(strd)))
422420
end
423421
# function lower_load_collection_manual_u₁unroll!(
@@ -436,100 +434,99 @@ end
436434
# op.mangledvariable = _mvar
437435
# end
438436
function lower_load_collection!(
439-
q::Expr, ls::LoopSet, opidmap::Vector{Int},
440-
idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true},
441-
ua::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}
437+
q::Expr, ls::LoopSet, opidmap::Vector{Int},
438+
idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true},
439+
ua::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}
442440
)
443-
@unpack u₁, u₁loop, u₁loopsym, u₂loopsym, vloopsym, vloop, suffix = ua
441+
@unpack u₁, u₁loop, u₁loopsym, u₂loopsym, vloopsym, vloop, suffix = ua
444442

445-
ops = operations(ls)
446-
nouter = length(idsformap)
447-
# ua = UnrollArgs(nouter, unrollsyms, u₂, 0)
448-
# idsformap contains (index, offset) pairs
449-
op = ops[opidmap[first(first(idsformap))]]
450-
# if isu₁unrolled(op) && u₁ > 1 && !isknown(step(u₁loop))
451-
# return lower_load_collection_manual_u₁unroll!(
452-
# q, ls, opidmap, idsformap, ua,
453-
# mask, inds_calc_by_ptr_offset, op
454-
# )
455-
# end
456-
opindices = getindices(op)
457-
interleave = first(opindices) === vloopsym
458-
# construct dummy unrolled loop
459-
offset_dummy_loop = Loop(first(opindices), MaybeKnown(1), MaybeKnown(1024), MaybeKnown(1), Symbol(""), Symbol(""))
460-
unrollcurl₂ = unrolled_curly(op, nouter, offset_dummy_loop, vloop, mask, 1) # interleave always 1 here
461-
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false, 0, ls, false)
462-
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls));
443+
ops = operations(ls)
444+
nouter = length(idsformap)
445+
# ua = UnrollArgs(nouter, unrollsyms, u₂, 0)
446+
# idsformap contains (index, offset) pairs
447+
op = ops[opidmap[first(first(idsformap))]]
448+
# if isu₁unrolled(op) && u₁ > 1 && !isknown(step(u₁loop))
449+
# return lower_load_collection_manual_u₁unroll!(
450+
# q, ls, opidmap, idsformap, ua,
451+
# mask, inds_calc_by_ptr_offset, op
452+
# )
453+
# end
454+
opindices = getindices(op)
455+
# construct dummy unrolled loop
456+
offset_dummy_loop = Loop(first(opindices), MaybeKnown(1), MaybeKnown(1024), MaybeKnown(1), Symbol(""), Symbol(""))
457+
unrollcurl₂ = unrolled_curly(op, nouter, offset_dummy_loop, vloop, mask, 1) # interleave always 1 here
458+
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false, 0, ls, false)
459+
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls));
463460

464-
opu₁, opu₂ = isunrolled_sym(op, u₁loopsym, u₂loopsym, vloopsym, ls)
465-
manualunrollu₁ = if opu₁ && u₁ > 1 # both unrolled
466-
if isknown(step(u₁loop)) && sum(Base.Fix2(===,u₁loopsym), getindicesonly(op)) == 1
467-
if interleave # TODO: handle this better than using `rejectinterleave`
468-
interleaveval = -nouter
469-
else
470-
interleaveval = 0
471-
end
472-
unrollcurl₁ = unrolled_curly(op, u₁, ua.u₁loop, vloop, mask, interleaveval)
473-
inds = Expr(:call, unrollcurl₁, inds)
474-
false
475-
else
476-
true # u₁ > 1 already checked to reach here
477-
end
461+
opu₁, opu₂ = isunrolled_sym(op, u₁loopsym, u₂loopsym, vloopsym, ls)
462+
manualunrollu₁ = if opu₁ && u₁ > 1 # both unrolled
463+
if isknown(step(u₁loop)) && sum(Base.Fix2(===,u₁loopsym), getindicesonly(op)) == 1
464+
# if first(opindices) === u₁loopsym#vloopsym
465+
# interleaveval = -nouter
466+
# else
467+
interleaveval = 0
468+
# end
469+
unrollcurl₁ = unrolled_curly(op, u₁, ua.u₁loop, vloop, mask, interleaveval)
470+
inds = Expr(:call, unrollcurl₁, inds)
471+
false
478472
else
479-
false
473+
true # u₁ > 1 already checked to reach here
480474
end
481-
uinds = Expr(:call, unrollcurl₂, inds)
482-
sptrsym = sptr!(q, op)
483-
loadexpr = Expr(:call, lv(:_vload), sptrsym, uinds)
484-
# not using `add_memory_mask!(storeexpr, op, ua, mask, ls, 0)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
485-
u₁vectorized = u₁loopsym === vloopsym
486-
if (mask && isvectorized(op))
487-
if !(manualunrollu₁ & u₁vectorized)
488-
push!(loadexpr.args, MASKSYMBOL)
489-
end
475+
else
476+
false
477+
end
478+
uinds = Expr(:call, unrollcurl₂, inds)
479+
sptrsym = sptr!(q, op)
480+
loadexpr = Expr(:call, lv(:_vload), sptrsym, uinds)
481+
# not using `add_memory_mask!(storeexpr, op, ua, mask, ls, 0)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
482+
u₁vectorized = u₁loopsym === vloopsym
483+
if (mask && isvectorized(op))
484+
if !(manualunrollu₁ & u₁vectorized)
485+
push!(loadexpr.args, MASKSYMBOL)
490486
end
491-
push!(loadexpr.args, falseexpr, rs)
492-
collectionname = Symbol(vptr(op), "##collection##number#", opidmap[first(first(idsformap))], "#", suffix, "##size##", nouter, "##u₁##", u₁)
493-
gf = GlobalRef(Core,:getfield)
494-
if manualunrollu₁
495-
masklast = mask & u₁vectorized & isvectorized(op)
496-
extractedvs = Vector{Expr}(undef, length(idsformap))
497-
for i ∈ eachindex(extractedvs)
498-
extractedvs[i] = Expr(:tuple)
499-
end
500-
for u ∈ 0:u₁-1
501-
collectionname_u = Symbol(collectionname, :_, u)
502-
if u ≠ 0
503-
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false, u, ls, false)
504-
uinds = Expr(:call, unrollcurl₂, inds)
505-
loadexpr = copy(loadexpr)
506-
loadexpr.args[3] = Expr(:call, unrollcurl₂, inds)
507-
(((u+1) == u₁) & masklast) && insert!(loadexpr.args, length(loadexpr.args)-1, MASKSYMBOL) # 1 for `falseexpr` pushed at end
508-
end
509-
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, false)
510-
push!(q.args, Expr(:(=), collectionname_u, Expr(:call, gf, loadexpr, 1)))
511-
# getfield to extract data from `VecUnroll` object, so we have a tuple
512-
for (i,(opid,o)) ∈ enumerate(idsformap)
513-
ext = extractedvs[i]
514-
if (u+1) == u₁
515-
_op = ops[opidmap[opid]]
516-
mvar = Symbol(variable_name(_op, Core.ifelse(opu₂, suffix, -1)), '_', u₁)
517-
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), ext)))
518-
end
519-
push!(ext.args, Expr(:call, gf, collectionname_u, i, false))
520-
end
487+
end
488+
push!(loadexpr.args, falseexpr, rs)
489+
collectionname = Symbol(vptr(op), "##collection##number#", opidmap[first(first(idsformap))], "#", suffix, "##size##", nouter, "##u₁##", u₁)
490+
gf = GlobalRef(Core,:getfield)
491+
if manualunrollu₁
492+
masklast = mask & u₁vectorized & isvectorized(op)
493+
extractedvs = Vector{Expr}(undef, length(idsformap))
494+
for i ∈ eachindex(extractedvs)
495+
extractedvs[i] = Expr(:tuple)
496+
end
497+
for u ∈ 0:u₁-1
498+
collectionname_u = Symbol(collectionname, :_, u)
499+
if u ≠ 0
500+
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false, u, ls, false)
501+
uinds = Expr(:call, unrollcurl₂, inds)
502+
loadexpr = copy(loadexpr)
503+
loadexpr.args[3] = Expr(:call, unrollcurl₂, inds)
504+
(((u+1) == u₁) & masklast) && insert!(loadexpr.args, length(loadexpr.args)-1, MASKSYMBOL) # 1 for `falseexpr` pushed at end
505+
end
506+
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, false)
507+
push!(q.args, Expr(:(=), collectionname_u, Expr(:call, gf, loadexpr, 1)))
508+
# getfield to extract data from `VecUnroll` object, so we have a tuple
509+
for (i,(opid,o)) ∈ enumerate(idsformap)
510+
ext = extractedvs[i]
511+
if (u+1) == u₁
512+
_op = ops[opidmap[opid]]
513+
mvar = Symbol(variable_name(_op, Core.ifelse(opu₂, suffix, -1)), '_', u₁)
514+
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), ext)))
521515
end
522-
else
523-
push!(q.args, Expr(:(=), collectionname, Expr(:call, gf, loadexpr, 1)))
524-
# getfield to extract data from `VecUnroll` object, so we have a tuple
525-
u = Core.ifelse(opu₁, u₁, 1)
526-
for (i,(opid,o)) ∈ enumerate(idsformap)
527-
extractedv = Expr(:call, gf, collectionname, i, false)
516+
push!(ext.args, Expr(:call, gf, collectionname_u, i, false))
517+
end
518+
end
519+
else
520+
push!(q.args, Expr(:(=), collectionname, Expr(:call, gf, loadexpr, 1)))
521+
# getfield to extract data from `VecUnroll` object, so we have a tuple
522+
u = Core.ifelse(opu₁, u₁, 1)
523+
for (i,(opid,o)) ∈ enumerate(idsformap)
524+
extractedv = Expr(:call, gf, collectionname, i, false)
528525

529-
_op = ops[opidmap[opid]]
530-
mvar = Symbol(variable_name(_op, Core.ifelse(opu₂, suffix, -1)), '_', u)
531-
push!(q.args, Expr(:(=), mvar, extractedv))
532-
end
533-
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, true)
526+
_op = ops[opidmap[opid]]
527+
mvar = Symbol(variable_name(_op, Core.ifelse(opu₂, suffix, -1)), '_', u)
528+
push!(q.args, Expr(:(=), mvar, extractedv))
534529
end
530+
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, true)
531+
end
535532
end

src/codegen/lower_memory_common.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,6 @@ end
113113
function addvectoroffset!(
114114
ret::Expr, mm::Bool, unrolledsteps::Int, unrolledloopstride, vloopstride, indexstride::Integer, index, offset::Integer, calcbypointeroffset::Bool, indvectorized::Bool
115115
) # 10 -> (7 or 8) args
116-
# if !isknown(unrolledloopstride)
117-
# @show unrolledsteps, calcbypointeroffset, _isone(unrolledloopstride)
118-
# end
119116
if unrolledsteps == 0 # neither unrolledloopstride or indexstride can be 0
120117
addoffset!(ret, mm, vloopstride, indexstride, index, offset, calcbypointeroffset) # 7 arg
121118
elseif indvectorized
@@ -203,7 +200,6 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
203200
# @unpack u₁, u₁loopsym, vloopsym = td
204201
AV = AU = -1
205202
for (n,ind) ∈ enumerate(indices)
206-
# @show AU, op, n, ind, vloopsym, u₁loopsym
207203
if li[n]
208204
if ind === vloopsym
209205
@assert AV == -1 # FIXME: these asserts should be replaced with checks that prevent using `unrolled_curly` in these cases (also to be reflected in cost modeling, to avoid those)
@@ -218,12 +214,13 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
218214
end
219215
else
220216
opp = findop(parents(op), ind)
221-
# @show opp
222217
if isvectorized(opp)
223218
@assert AV == -1
224219
AV = n
225220
end
226-
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (isu₁unrolled(opp) || (ind === u₁loopsym))
221+
# if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (isu₁unrolled(opp) || (ind === u₁loopsym))
222+
# can't check isu₁unrolled(opp) because we may be lying.
223+
if (u₁loopsym === CONSTANTZEROINDEX) ? (CONSTANTZEROINDEX ∈ loopdependencies(opp)) : (u₁loopsym ∈ loopdependencies(opp) || (ind === u₁loopsym))
227224
@assert AU == -1
228225
AU = n
229226
end
@@ -247,18 +244,24 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
247244
@assert isknown(step(u₁loop)) "Unrolled loops must have known steps to use `Unroll` type; this is a bug, shouldn't have reached here"
248245
if AV > 0
249246
@assert isknown(step(vloop)) "Vectorized loops must have known steps to use `Unroll` type; this is a bug, shouldn't have reached here."
247+
XU = convert(Int, getstrides(op)[AU]) * gethint(step(u₁loop))
250248
X = convert(Int, getstrides(op)[AV])
251249
X *= gethint(step(vloop))
252250
intvecsym = :(Int($VECTORWIDTHSYMBOL))
253251
if interleave > 0
254252
Expr(:curly, lv(:Unroll), AU, interleave, u₁, AV, intvecsym, M, X)
255253
elseif interleave < 0
256-
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, -interleave))))
257-
Expr(:curly, lv(:Unroll), AU, unrollstepexpr, u₁, AV, intvecsym, M, X)
254+
interleave *= -XU
255+
if AU == AV
256+
unrollstepexpr = :(Int($(mulexpr(VECTORWIDTHSYMBOL, interleave))))
257+
Expr(:curly, lv(:Unroll), AU, unrollstepexpr, u₁, AV, intvecsym, M, X)
258+
else
259+
Expr(:curly, lv(:Unroll), AU, interleave, u₁, AV, intvecsym, M, X)
260+
end
258261
else
259262
if vecnotunrolled
260263
# Expr(:call, Expr(:curly, lv(:Unroll), AU, 1, u₁, AV, intvecsym, M, 1), ind)
261-
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, AV, intvecsym, M, X)
264+
Expr(:curly, lv(:Unroll), AU, XU, u₁, AV, intvecsym, M, X)
262265
else
263266
if isone(X)
264267
Expr(:curly, lv(:Unroll), AU, intvecsym, u₁, AV, intvecsym, M, X)
@@ -269,7 +272,7 @@ function unrolled_curly(op::Operation, u₁::Int, u₁loop::Loop, vloop::Loop, m
269272
end
270273
end
271274
else
272-
Expr(:curly, lv(:Unroll), AU, gethint(step(u₁loop)), u₁, 0, 1, M, 1)
275+
Expr(:curly, lv(:Unroll), AU, convert(Int, getstrides(op)[AU])*gethint(step(u₁loop)), u₁, 0, 1, M, 1)
273276
end
274277
end
275278
function unrolledindex(op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}, ls::LoopSet)
@@ -309,9 +312,6 @@ function mem_offset_u(
309312
if ind === u₁loopsym
310313
addvectoroffset!(ret, indvectorizedmm, incr₁, u₁step, vstep, stride, ind, offset, ind_by_offset, indvectorized) # 9 arg
311314
elseif ind === u₂loopsym
312-
# if isstore(op)
313-
# @show indvectorized, ind === vloopsym, u₂loopsym, incr₂
314-
# end
315315
addvectoroffset!(ret, indvectorizedmm, incr₂, u₂step, vstep, stride, ind, offset, ind_by_offset, indvectorized) # 9 arg
316316
elseif loopedindex[n]
317317
addoffset!(ret, indvectorizedmm, vstep, stride, ind, offset, ind_by_offset) # 7 arg

0 commit comments

Comments
 (0)