Skip to content

Commit 331288a

Browse files
committed
Fix bug when unrolls are dynamic
1 parent 704e23c commit 331288a

File tree

6 files changed

+182
-51
lines changed

6 files changed

+182
-51
lines changed

src/codegen/lower_load.jl

Lines changed: 101 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,6 @@ function lower_load_no_optranslation!(
117117
loopdeps = loopdependencies(op)
118118
# @assert isvectorized(op)
119119
opu₁ = isu₁unrolled(op)
120-
121120
u = ifelse(opu₁, u₁, 1)
122121
mvar = Symbol(variable_name(op, Core.ifelse(isu₂unrolled(op), suffix,-1)), '_', u)
123122
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls))
@@ -195,22 +194,16 @@ function lower_load_for_optranslation!(
195194
q::Expr, op::Operation, posindicator::UInt8, ls::LoopSet, td::UnrollArgs, mask::Bool, translationind::Int
196195
)
197196
@unpack u₁loop, u₂loop, vloop, u₁, u₂max, suffix = td
198-
199197
# @unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = td
200198
iszero(suffix) || return
201-
202199
total_unroll = u₁ + u₂max - 1
203-
204-
205200
mref = op.ref
206201
inds_by_ptroff = indices_calculated_by_pointer_offsets(ls, mref)
207202
# initial offset pointer
208-
209203
# Unroll directions can be + or -
210204
# we want to start at minimum position.
211205
step₁ = gethint(step(u₁loop))
212206
step₂ = gethint(step(u₂loop))
213-
214207
# abs of steps are equal
215208
equal_steps = (step₁ == step₂) (posindicator 0x03)
216209
# @show step₁, step₂, posindicator, equal_steps
@@ -227,9 +220,7 @@ function lower_load_for_optranslation!(
227220
end
228221
end
229222
push!(q.args, Expr(:(=), gptr, Expr(:call, lv(:gesp), ptr, gespinds)))
230-
231223
fill!(inds_by_ptroff, true)
232-
233224
@unpack ref, loopedindex = mref
234225
indices = copy(getindices(ref))
235226
# old_translation_index = indices[translationind]
@@ -262,9 +253,7 @@ function lower_load_for_optranslation!(
262253
op.ref = mref
263254
# loopedindex[translationind] = false
264255
# indices[translationind] = old_translation_index
265-
266256
shouldbroadcast = (!isvectorized(op)) && any(isvectorized, children(op))
267-
268257
# now we need to assign the `Vec`s from the `VecUnroll` to the correct name.
269258
variable_name_u = Symbol(variable_name(op, -1), '_', total_unroll)
270259
variable_name_data = Symbol(variable_name_u, "##data##")
@@ -399,46 +388,132 @@ function rejectinterleave(ls::LoopSet, op::Operation, vloop::Loop, idsformap::Su
399388
end
400389
(first(getindices(op)) === vloopsym) && (length(idsformap) first(getstrides(op)) * gethint(strd))
401390
end
391+
# function lower_load_collection_manual_u₁unroll!(
392+
# q::Expr, ls::LoopSet, opidmap::Vector{Int},
393+
# idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true},
394+
# ua::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}, op::Operation
395+
# )
396+
# @unpack u₁, u₁loop, u₁loopsym, u₂loopsym, vloopsym, vloop, suffix = ua
397+
# _mvar = mangledvar(op)
398+
# op.mangledvariable = gensym!(ls,_mvar)
399+
# for u ∈ 0:u₁-1
400+
# lower_load_collection!(
401+
# q, ls, opidmap, idsformap, ua, mask, inds_calc_by_ptr_offset
402+
# )
403+
# end
404+
# op.mangledvariable = _mvar
405+
# end
402406
function lower_load_collection!(
403407
q::Expr, ls::LoopSet, opidmap::Vector{Int},
404408
idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true},
405409
ua::UnrollArgs, mask::Bool, inds_calc_by_ptr_offset::Vector{Bool}
406410
)
407-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, vloop, suffix = ua
411+
@unpack u₁, u₁loop, u₁loopsym, u₂loopsym, vloopsym, vloop, suffix = ua
412+
408413
ops = operations(ls)
409414
nouter = length(idsformap)
410415
# ua = UnrollArgs(nouter, unrollsyms, u₂, 0)
411416
# idsformap contains (index, offset) pairs
412417
op = ops[opidmap[first(first(idsformap))]]
418+
# if isu₁unrolled(op) && u₁ > 1 && !isknown(step(u₁loop))
419+
# return lower_load_collection_manual_u₁unroll!(
420+
# q, ls, opidmap, idsformap, ua,
421+
# mask, inds_calc_by_ptr_offset, op
422+
# )
423+
# end
413424
opindices = getindices(op)
414425
interleave = first(opindices) === vloopsym
415426
# construct dummy unrolled loop
416427
offset_dummy_loop = Loop(first(opindices), MaybeKnown(1), MaybeKnown(1024), MaybeKnown(1), Symbol(""), Symbol(""))
417428
unrollcurl₂ = unrolled_curly(op, nouter, offset_dummy_loop, vloop, mask, 1) # interleave always 1 here
418429
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false)
419430
falseexpr = Expr(:call, lv(:False)); rs = staticexpr(reg_size(ls));
420-
if isu₁unrolled(op) && u₁ > 1 # both unrolled
421-
if interleave # TODO: handle this better than using `rejectinterleave`
422-
interleaveval = -nouter
431+
432+
manualunrollu₁ = if isu₁unrolled(op) && u₁ > 1 # both unrolled
433+
if isknown(step(u₁loop)) && sum(Base.Fix2(===,u₁loopsym), getindicesonly(op)) == 1
434+
if interleave # TODO: handle this better than using `rejectinterleave`
435+
interleaveval = -nouter
436+
else
437+
interleaveval = 0
438+
end
439+
unrollcurl₁ = unrolled_curly(op, u₁, ua.u₁loop, vloop, mask, interleaveval)
440+
inds = Expr(:call, unrollcurl₁, inds)
441+
false
423442
else
424-
interleaveval = 0
443+
true # u₁ > 1 already checked to reach here
425444
end
426-
unrollcurl₁ = unrolled_curly(op, u₁, ua.u₁loop, vloop, mask, interleaveval)
427-
inds = Expr(:call, unrollcurl₁, inds)
445+
else
446+
false
428447
end
429448
uinds = Expr(:call, unrollcurl₂, inds)
430449
vp = vptr(op)
431450
loadexpr = Expr(:call, lv(:_vload), vp, uinds)
432451
# not using `add_memory_mask!(storeexpr, op, ua, mask)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
433-
(mask && isvectorized(op)) && push!(loadexpr.args, MASKSYMBOL)
452+
u₁vectorized = u₁loopsym === vloopsym
453+
if (mask && isvectorized(op))
454+
if !(manualunrollu₁ & u₁vectorized)
455+
push!(loadexpr.args, MASKSYMBOL)
456+
end
457+
end
434458
push!(loadexpr.args, falseexpr, rs)
435459
collectionname = Symbol(vp, "##collection##number#", opidmap[first(first(idsformap))], "#", suffix, "##size##", nouter, "##u₁##", u₁)
436-
# getfield to extract data from `VecUnroll` object, so we have a tuple
437-
push!(q.args, Expr(:(=), collectionname, Expr(:call, :getfield, loadexpr, 1)))
438-
u = Core.ifelse(isu₁unrolled(op), u₁, 1)
439-
for (i,(opid,o)) enumerate(idsformap)
440-
_op = ops[opidmap[opid]]
441-
mvar = Symbol(variable_name(_op, Core.ifelse(isu₂unrolled(_op), suffix, -1)), '_', u)
442-
push!(q.args, Expr(:(=), mvar, Expr(:call, :getfield, collectionname, i, false)))
460+
gf = GlobalRef(Core,:getfield)
461+
if manualunrollu₁
462+
masklast = mask & u₁vectorized & isvectorized(op)
463+
extractedvs = Vector{Expr}(undef, length(idsformap))
464+
for i eachindex(extractedvs)
465+
extractedvs[i] = Expr(:tuple)
466+
end
467+
for u 0:u₁-1
468+
collectionname_u = Symbol(collectionname, :_, u)
469+
if u 0
470+
inds = mem_offset_u(op, ua, inds_calc_by_ptr_offset, false, u)
471+
uinds = Expr(:call, unrollcurl₂, inds)
472+
loadexpr = copy(loadexpr)
473+
loadexpr.args[3] = Expr(:call, unrollcurl₂, inds)
474+
(((u+1) == u₁) & masklast) && push!(loadexpr.args, MASKSYMBOL)
475+
end
476+
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, false)
477+
push!(q.args, Expr(:(=), collectionname_u, Expr(:call, gf, loadexpr, 1)))
478+
# getfield to extract data from `VecUnroll` object, so we have a tuple
479+
for (i,(opid,o)) enumerate(idsformap)
480+
ext = extractedvs[i]
481+
if (u+1) == u₁
482+
_op = ops[opidmap[opid]]
483+
mvar = Symbol(variable_name(_op, Core.ifelse(isu₂unrolled(_op), suffix, -1)), '_', u₁)
484+
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), ext)))
485+
end
486+
push!(ext.args, Expr(:call, gf, collectionname_u, i, false))
487+
end
488+
end
489+
else
490+
push!(q.args, Expr(:(=), collectionname, Expr(:call, gf, loadexpr, 1)))
491+
# getfield to extract data from `VecUnroll` object, so we have a tuple
492+
u = Core.ifelse(isu₁unrolled(op), u₁, 1)
493+
for (i,(opid,o)) enumerate(idsformap)
494+
extractedv = Expr(:call, gf, collectionname, i, false)
495+
496+
_op = ops[opidmap[opid]]
497+
mvar = Symbol(variable_name(_op, Core.ifelse(isu₂unrolled(_op), suffix, -1)), '_', u)
498+
push!(q.args, Expr(:(=), mvar, extractedv))
499+
end
500+
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, true)
443501
end
444502
end
503+
# function unpack_collection!(
504+
# q::Expr, ls::LoopSet, opidmap::Vector{Int},
505+
# idsformap::SubArray{Tuple{Int,Int}, 1, Vector{Tuple{Int,Int}}, Tuple{UnitRange{Int}}, true},
506+
# ua::UnrollArgs, loadexpr::Expr, collectionname::Symbol, op::Operation
507+
# )
508+
# gf = GlobalRef(Core,:getfield)
509+
# push!(q.args, Expr(:(=), collectionname, Expr(:call, gf, loadexpr, 1)))
510+
# # getfield to extract data from `VecUnroll` object, so we have a tuple
511+
# u = Core.ifelse(isu₁unrolled(op), u₁, 1)
512+
# for (i,(opid,o)) ∈ enumerate(idsformap)
513+
# extractedv = Expr(:call, gf, collectionname, i, false)
514+
515+
# _op = ops[opidmap[opid]]
516+
# mvar = Symbol(variable_name(_op, Core.ifelse(isu₂unrolled(_op), suffix, -1)), '_', u)
517+
# push!(q.args, Expr(:(=), mvar, extractedv))
518+
# end
519+
# end

src/codegen/lower_memory_common.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ end
101101
function addvectoroffset!(
102102
ret::Expr, mm::Bool, unrolledsteps::Int, unrolledloopstride, vloopstride, indexstride::Integer, index, offset::Integer, calcbypointeroffset::Bool, indvectorized::Bool
103103
) # 10 -> (7 or 8) args
104+
# if !isknown(unrolledloopstride)
105+
# @show unrolledsteps, calcbypointeroffset, _isone(unrolledloopstride)
106+
# end
104107
if unrolledsteps == 0 # neither unrolledloopstride or indexstride can be 0
105108
addoffset!(ret, mm, vloopstride, indexstride, index, offset, calcbypointeroffset) # 7 arg
106109
elseif indvectorized
@@ -112,8 +115,10 @@ function addvectoroffset!(
112115
else
113116
addvectoroffset!(ret, mm, mulexpr(unrolledloopstride,unrolledsteps), vloopstride, indexstride, index, offset, calcbypointeroffset) # 8 arg
114117
end
115-
else
118+
elseif _isone(unrolledloopstride)
116119
addoffset!(ret, mm, vloopstride, indexstride, index, offset + unrolledsteps, calcbypointeroffset) # 7 arg
120+
else
121+
addoffset!(ret, mm, vloopstride, mulexpr(unrolledloopstride,indexstride), index, addexpr(offset, lazymulexpr(unrolledloopstride, unrolledsteps)), calcbypointeroffset) # 7 arg
117122
end
118123
end
119124

@@ -261,7 +266,9 @@ function unrolledindex(op::Operation, td::UnrollArgs, mask::Bool, inds_calc_by_p
261266
Expr(:call, unrollcurl, ind)
262267
end
263268

264-
function mem_offset_u(op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vector{Bool}, _mm::Bool, incr₁::Int = 0)
269+
function mem_offset_u(
270+
op::Operation, td::UnrollArgs, inds_calc_by_ptr_offset::Vector{Bool}, _mm::Bool, incr₁::Int = 0
271+
)
265272
@assert accesses_memory(op) "Computing memory offset only makes sense for operations that access memory."
266273
@unpack u₁loopsym, u₂loopsym, vloopsym, u₁step, u₂step, vstep, suffix = td
267274

src/codegen/lower_store.jl

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,12 @@ function lower_store_collection!(
5252
opidmap = offsetloadcollection(ls).opids[collectionid]
5353
idsformap = omop.batchedcollections[batchid]
5454

55-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, vloop, u₂max, suffix = ua
55+
@unpack u₁, u₁loop, u₁loopsym, u₂loopsym, vloopsym, vloop, u₂max, suffix = ua
5656
ops = operations(ls)
5757
# __u₂max = ls.unrollspecification[].u₂
5858
nouter = length(idsformap)
5959

6060
t = Expr(:tuple)
61-
# u = Core.ifelse(isu₁unrolled(op), u₁, 1)
62-
6361
for (i,(opid,_)) enumerate(idsformap)
6462
opp = first(parents(ops[opidmap[opid]]))
6563

@@ -79,22 +77,65 @@ function lower_store_collection!(
7977
falseexpr = Expr(:call, lv(:False));
8078
trueexpr = Expr(:call, lv(:True));
8179
rs = staticexpr(reg_size(ls));
82-
if isu₁unrolled(op) && u₁ > 1 # both unrolled
83-
if first(getindices(op)) === vloopsym
84-
interleaveval = -nouter
80+
manualunrollu₁ = if isu₁unrolled(op) && u₁ > 1 # both unrolled
81+
if isknown(step(u₁loop)) && sum(Base.Fix2(===,u₁loopsym), getindicesonly(op)) == 1
82+
if first(getindices(op)) === vloopsym
83+
interleaveval = -nouter
84+
else
85+
interleaveval = 0
86+
end
87+
unrollcurl₁ = unrolled_curly(op, u₁, ua.u₁loop, vloop, mask, interleaveval)
88+
inds = Expr(:call, unrollcurl₁, inds)
89+
false
8590
else
86-
interleaveval = 0
91+
true
8792
end
88-
unrollcurl₁ = unrolled_curly(op, u₁, ua.u₁loop, vloop, mask, interleaveval)
89-
inds = Expr(:call, unrollcurl₁, inds)
93+
else
94+
false
9095
end
9196
uinds = Expr(:call, unrollcurl₂, inds)
9297
vp = vptr(op)
9398
storeexpr = Expr(:call, lv(:_vstore!), vp, Expr(:call, lv(:VecUnroll), t), uinds)
9499
# not using `add_memory_mask!(storeexpr, op, ua, mask)` because we checked `isconditionalmemop` earlier in `lower_load_collection!`
95-
mask && push!(storeexpr.args, MASKSYMBOL)
100+
u₁vectorized = u₁loopsym === vloopsym
101+
if mask# && isvectorized(op))
102+
if !(manualunrollu₁ & u₁vectorized)
103+
push!(storeexpr.args, MASKSYMBOL)
104+
end
105+
end
96106
push!(storeexpr.args, falseexpr, trueexpr, falseexpr, rs)
97-
push!(q.args, storeexpr)
107+
if manualunrollu₁
108+
masklast = mask & u₁vectorized
109+
gf = GlobalRef(Core,:getfield)
110+
tv = Vector{Symbol}(undef, length(t.args))
111+
for i eachindex(tv)
112+
s = gensym!(ls, "##tmp##collection##store##")
113+
tv[i] = s
114+
push!(q.args, Expr(:(=), s, Expr(:call, gf, t.args[i], 1)))
115+
end
116+
# @show u₁, t
117+
for u 0:u₁-1
118+
lastiter = (u+1) == u₁
119+
storeexpr_tmp = if lastiter
120+
storeexpr
121+
(((u+1) == u₁) & masklast) && push!(storeexpr.args, MASKSYMBOL)
122+
storeexpr
123+
else
124+
copy(storeexpr)
125+
end
126+
vut = Expr(:tuple)
127+
for i eachindex(tv)
128+
push!(vut.args, Expr(:call, gf, tv[i], u+1, false))
129+
end
130+
storeexpr_tmp.args[3] = Expr(:call, lv(:VecUnroll), vut)
131+
if u 0
132+
storeexpr_tmp.args[4] = Expr(:call, unrollcurl₂, mem_offset_u(op, ua, inds_calc_by_ptr_offset, false, u))
133+
end
134+
push!(q.args, storeexpr_tmp)
135+
end
136+
else
137+
push!(q.args, storeexpr)
138+
end
98139
nothing
99140
end
100141
function lower_store!(

src/modeling/determinestrategy.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -753,6 +753,7 @@ end
753753
function maxnegativeoffset(ls::LoopSet, op::Operation, u::Symbol)
754754
mno::Int = typemin(Int)
755755
id = 0
756+
isknown(step(getloop(ls, u))) || return mno, id
756757
omop = offsetloadcollection(ls)
757758
collectionid, opind = omop.opidcollectionmap[identifier(op)]
758759
collectionid == 0 && return mno, id
@@ -1153,6 +1154,14 @@ function choose_unroll_order(ls::LoopSet, lowest_cost::Float64 = Inf)
11531154
end
11541155

11551156

1157+
1158+
function reject_reorder(ls::LoopSet, reordered::Symbol)
1159+
length(ls.outer_reductions) > 0 || return false
1160+
for op operations(ls)
1161+
reordered loopdependencies(op) && any(opp -> (iscompute(opp) && isanouterreduction(ls, opp)), parents(op)) && return true
1162+
end
1163+
false
1164+
end
11561165
"""
11571166
This function searches for unrolling combinations that will cause LoopVectorization to generate invalid code.
11581167
@@ -1203,14 +1212,6 @@ function reject_candidate(op::Operation, u₁loopsym::Symbol, u₂loopsym::Symbo
12031212
end
12041213
false
12051214
end
1206-
1207-
function reject_reorder(ls::LoopSet, reordered::Symbol)
1208-
length(ls.outer_reductions) > 0 || return false
1209-
for op operations(ls)
1210-
reordered loopdependencies(op) && any(opp -> (iscompute(opp) && isanouterreduction(ls, opp)), parents(op)) && return true
1211-
end
1212-
false
1213-
end
12141215
function reject_candidate(ls::LoopSet, u₁loopsym::Symbol, u₂loopsym::Symbol)
12151216
for op operations(ls)
12161217
reject_candidate(op, u₁loopsym, u₂loopsym) && return true

src/parse/memory_ops_common.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,8 +230,15 @@ function muladd_index!(
230230
ls::LoopSet, parents, loopdependencies, reduceddeps, indices, offsets, strides, loopedindex,
231231
mlt::Int, symop::Operation, offset::Int
232232
)
233-
indop = muladd_op!(ls, mlt, symop, offset)
234-
addopindex!(parents, loopdependencies, reduceddeps, indices, offsets, strides, loopedindex, indop)
233+
if byterepresentable(offset) & byterepresentable(mlt)
234+
addopindex!(
235+
parents, loopdependencies, reduceddeps, indices,
236+
offsets, strides, loopedindex, symop, mlt, offset
237+
)
238+
else
239+
indop = muladd_op!(ls, mlt, symop, offset)
240+
addopindex!(parents, loopdependencies, reduceddeps, indices, offsets, strides, loopedindex, indop)
241+
end
235242
end
236243

237244
function checkforoffset!(

src/reconstruct_loopset.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ function avx_body(ls::LoopSet, UNROLL::Tuple{Bool,Int8,Int8,Bool,Int,Int,Int,Int
575575
end
576576

577577
function _avx_loopset_debug(::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, _vargs::Tuple{LB,V}) where {UNROLL, OPS, ARF, AM, LPSYM, LB, V}
578-
@show OPS ARF AM LPSYM _vargs
578+
# @show OPS ARF AM LPSYM _vargs
579579
_avx_loopset(OPS, ARF, AM, LPSYM, _vargs[1].parameters, V.parameters, UNROLL)
580580
end
581581
function tovector(@nospecialize(t))
@@ -630,7 +630,7 @@ Execute an `@avx` block. The block's code is represented via the arguments:
630630
@generated function _avx_!(
631631
::Val{UNROLL}, ::Val{OPS}, ::Val{ARF}, ::Val{AM}, ::Val{LPSYM}, var"#lv#tuple#args#"::Tuple{LB,V}
632632
) where {UNROLL, OPS, ARF, AM, LPSYM, LB, V}
633-
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
633+
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
634634
ls = _avx_loopset(OPS, ARF, AM, LPSYM, LB.parameters, V.parameters, UNROLL)
635635
# return @show avx_body(ls, UNROLL)
636636
if last(UNROLL) > 1

0 commit comments

Comments
 (0)