Skip to content

Commit 7852f60

Browse files
committed
Check for 0 loopdeps before trying to find array in case of CartesianIndices, fixes #256. Place masks in correct order when manually unrolling collections, fixes #257. Add more canonicalization/indexing cleanup.
1 parent a3cdfff commit 7852f60

File tree

9 files changed

+299
-213
lines changed

9 files changed

+299
-213
lines changed

src/codegen/loopstartstopmanager.jl

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -230,22 +230,27 @@ function normalize_offsets!(
230230
end
231231
return Int(minoffset)
232232
end
233-
function isloopvalue(ls::LoopSet, ind::Symbol)
234-
for op ∈ operations(ls)
233+
function isloopvalue(ls::LoopSet, ind::Symbol, isrooted::Union{Nothing,Vector{Bool}} = nothing)
234+
for (i,op) ∈ enumerate(operations(ls))
235+
if (isrooted ≢ nothing)
236+
isrooted[i] || continue
237+
end
235238
iscompute(op) || continue
236239
for opp ∈ parents(op)# this is to confirm `ind` still has children
237-
(isloopvalue(opp) && instruction(opp).instr === ind) && return true
240+
# (isloopvalue(opp) && instruction(opp).instr === ind) && return true
241+
if (isloopvalue(opp) && instruction(opp).instr === ind)
242+
return true
243+
end
238244
end
239245
end
240246
return false
241247
end
242248
function cse_constant_offsets!(
243-
ls::LoopSet, allarrayrefs::Vector{ArrayReferenceMeta}, allarrayrefsind::Int, name_to_array_map::Vector{Vector{Int}},
244-
arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}, shouldindbyind::Vector{Bool}
249+
ls::LoopSet, allarrayrefs::Vector{ArrayReferenceMeta}, allarrayrefsind::Int, name_to_array_map::Vector{Vector{Int}}, arrayref_to_name_op_collection::Vector{Vector{Tuple{Int,Int,Int}}}
245250
)
246251
ar = allarrayrefs[allarrayrefsind]
247252
# @show ar
248-
vptrar = vptr(ar)
253+
# vptrar = vptr(ar)
249254
arrayref_to_name_op = arrayref_to_name_op_collection[allarrayrefsind]
250255
array_refs_with_same_name = name_to_array_map[first(first(arrayref_to_name_op))]
251256
us = ls.unrollspecification
@@ -254,7 +259,7 @@ function cse_constant_offsets!(
254259
strides = getstrides(ar)
255260
offset = first(indices) === DISCONTIGUOUS
256261
# gespindoffsets = fill(Symbol(""), length(li))
257-
gespinds = Expr(:tuple)
262+
gespindsummary = Vector{Tuple{Symbol,Int}}(undef, length(li))
258263
for i ∈ eachindex(li)
259264
gespsymbol::Symbol = Symbol("")
260265
ii = i + offset
@@ -372,9 +377,10 @@ function cse_constant_offsets!(
372377
end
373378
end
374379
constoffset = normalize_offsets!(ls, i, allarrayrefs, array_refs_with_same_name, arrayref_to_name_op_collection)
375-
pushgespind!(gespinds, ls, gespsymbol, constoffset, ind, li, i, check_shouldindbyind(ls, ind, shouldindbyind), true)
380+
gespindsummary[i] = (gespsymbol, constoffset)
381+
# pushgespind!(gespinds, ls, gespsymbol, constoffset, ind, li, i, check_shouldindbyind(ls, ind, shouldindbyind), true)
376382
end
377-
return gespinds
383+
return gespindsummary
378384
end
379385
@inline similardims(_, i) = i
380386
@inline similardims(::CartesianIndices{N}, i) where {N} = VectorizationBase.CartesianVIndex(ntuple(_ -> i, Val{N}()))
@@ -393,10 +399,22 @@ end
393399
# end
394400
# return nothing
395401
# end
402+
function calcgespinds(ls::LoopSet, ar::ArrayReferenceMeta, gespindsummary::Vector{Tuple{Symbol,Int}}, shouldindbyind::Vector{Bool})
403+
gespinds = Expr(:tuple)
404+
li = ar.loopedindex
405+
indices = getindicesonly(ar)
406+
for i ∈ eachindex(li)
407+
ind = indices[i]
408+
gespsymbol, constoffset = gespindsummary[i]
409+
pushgespind!(gespinds, ls, gespsymbol, constoffset, ind, li[i], check_shouldindbyind(ls, ind, shouldindbyind), true)
410+
end
411+
gespinds
412+
end
413+
396414
function pushgespind!(
397-
gespinds::Expr, ls::LoopSet, gespsymbol::Symbol, constoffset::Int, ind::Symbol, li::Vector{Bool}, i::Int, index_by_index::Bool, fromgsp::Bool
415+
gespinds::Expr, ls::LoopSet, gespsymbol::Symbol, constoffset::Int, ind::Symbol, isli::Bool, index_by_index::Bool, fromgsp::Bool
398416
)
399-
if li[i]
417+
if isli
400418
if ind === CONSTANTZEROINDEX
401419
if gespsymbol === Symbol("")
402420
push!(gespinds.args, staticexpr(constoffset))
@@ -448,13 +466,21 @@ function pushgespind!(
448466
elseif fromgsp # from gsp means that a loop could be a CartesianIndices, so we may need to expand
449467
#TODO: broadcast dimensions in case of cartesian indices
450468
rangesym = ind
469+
foundind = false
451470
for op ∈ operations(ls)
452471
if name(op) === ind
453-
loopsym = first(loopdependencies(op))
454-
rangesym = getloop(ls, loopsym).rangesym
472+
loopdeps = loopdependencies(op)
473+
foundind = true
474+
if length(loopdeps) ≠ 0
475+
rangesym = getloop(ls, first(loopdeps)).rangesym
476+
else
477+
isconstantop(op) || throw(LoopError("Please file an issue with LoopVectorization.jl with a reproducer; tried to eliminate a non-constant operation."))
478+
rangesym = name(op)
479+
end
480+
break
455481
end
456482
end
457-
@assert rangesym ≢ ind
483+
@assert foundind
458484
if rangesym === Symbol("") # there is no rangesym, must be statically sized.
459485
pushgespsym!(gespinds, gespsymbol, constoffset)
460486
else
@@ -518,32 +544,32 @@ function use_loop_induct_var!(
518544
vptrar = vptr(ar)
519545
# @show ar
520546
Wisz = false#ls.vector_width == 0
521-
for i ∈ eachindex(li)
547+
for (i,isli) ∈ enumerate(li)
522548
ii = i + offset
523549
ind = indices[ii]
524550
Wisz && push!(gespinds.args, staticexpr(0)) # wrong for `@_avx`...
525551
if !li[i] # if it wasn't set
526552
uliv[i] = 0
527553
push!(offsetprecalc_descript.args, 0)
528-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, li, i, true, false)
554+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, true, false)
529555
elseif ind === CONSTANTZEROINDEX
530556
uliv[i] = 0
531557
push!(offsetprecalc_descript.args, 0)
532-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, li, i, true, false)
558+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, true, false)
533559
elseif isbroadcast ||
534560
((isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) ||
535561
multiple_with_name(vptrar, allarrayrefs)) ||
536562
(iszero(ls.vector_width) && isstaticloop(getloop(ls, ind)))# ||
537563
# Not doing normal offset indexing
538564
uliv[i] = -findfirst(Base.Fix2(===,ind), looporder)::Int
539565
push!(offsetprecalc_descript.args, 0) # not doing offset indexing, so push 0
540-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, li, i, true, false)
566+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, true, false)
541567
else
542568
uliv[i] = findfirst(Base.Fix2(===,ind), looporder)::Int
543569
loop = getloop(ls, ind)
544570
push!(offsetprecalc_descript.args, max(5,us.u₁+1,us.u₂+1))
545571
use_offsetprecalc = true
546-
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, li, i, false, false)
572+
Wisz || pushgespind!(gespinds, ls, Symbol(""), 0, ind, isli, false, false)
547573
end
548574
# cases for pushgespind! and loopval!
549575
# if !isloopval, same as before

src/codegen/lower_load.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ function lower_load_collection!(
477477
uinds = Expr(:call, unrollcurl₂, inds)
478478
loadexpr = copy(loadexpr)
479479
loadexpr.args[3] = Expr(:call, unrollcurl₂, inds)
480-
(((u+1) == u₁) & masklast) && push!(loadexpr.args, MASKSYMBOL)
480+
(((u+1) == u₁) & masklast) && insert!(loadexpr.args, length(loadexpr.args)-1, MASKSYMBOL) # 1 for `falseexpr` pushed at end
481481
end
482482
# unpack_collection!(q, ls, opidmap, idsformap, ua, loadexpr, collectionname, op, false)
483483
push!(q.args, Expr(:(=), collectionname_u, Expr(:call, gf, loadexpr, 1)))

src/codegen/lower_store.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ function lower_store_collection!(
122122
lastiter = (u+1) == u₁
123123
storeexpr_tmp = if lastiter
124124
storeexpr
125-
(((u+1) == u₁) & masklast) && push!(storeexpr.args, MASKSYMBOL)
125+
(((u+1) == u₁) & masklast) && insert!(storeexpr.args, length(storeexpr.args)-3, MASKSYMBOL) # 3 for falseexpr, aliasexpr, falseexpr
126126
storeexpr
127127
else
128128
copy(storeexpr)

src/condense_loopset.jl

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ end
159159
@inline zerorangestart(r::AbstractUnitRange) = Zero():One():(maybestaticlast(r)-maybestaticfirst(r))
160160
@inline zerorangestart(r::AbstractRange) = Zero():static_step(r):(maybestaticlast(r)-maybestaticfirst(r))
161161
@inline zerorangestart(r::CartesianIndices) = CartesianIndices(map(zerorangestart, r.indices))
162+
@inline zerorangestart(r::ArrayInterface.OptionallyStaticUnitRange{StaticInt{1}}) = CloseOpen(last(r))
162163

163164
function loop_boundary!(q::Expr, ls::LoopSet, loop::Loop, shouldindbyind::Bool)
164165
if isstaticloop(loop) || loop.rangesym === Symbol("")
@@ -309,12 +310,12 @@ function findfirstcontaining(ref, ind)
309310
end
310311
0
311312
end
312-
function should_zerorangestart(ls::LoopSet, allarrayrefs::Vector{ArrayReferenceMeta}, name_to_array_map::Vector{Vector{Int}})
313+
function should_zerorangestart(ls::LoopSet, allarrayrefs::Vector{ArrayReferenceMeta}, name_to_array_map::Vector{Vector{Int}}, isrooted::Vector{Bool})
313314
loops = ls.loops
314315
shouldindbyind = fill(false, length(loops))
315316
for (i,loop) ∈ enumerate(loops)
316317
ind = loop.itersymbol
317-
if isloopvalue(ls, ind)
318+
if isloopvalue(ls, ind, isrooted)
318319
# we don't zero the range if it is used as a loopvalue
319320
shouldindbyind[i] = true
320321
continue
@@ -355,10 +356,11 @@ end
355356
# 2) decide whether to gesp that loopstart inside `add_grouped_strided_pointer`
356357
function add_grouped_strided_pointer!(extra_args::Expr, ls::LoopSet)
357358
allarrayrefs, name_to_array_map, unique_to_name_and_op_map = uniquearrayrefs_csesummary(ls)
358-
shouldindbyind = should_zerorangestart(ls, allarrayrefs, name_to_array_map)
359359
# @show allarrayrefs
360360
gsp = Expr(:call, lv(:grouped_strided_pointer))
361361
tgarrays = Expr(:tuple)
362+
# refs_to_gesp = ArrayReferenceMeta[]
363+
gespsummaries = Tuple{Int,Vector{Tuple{Symbol,Int}}}[]
362364
i = 0
363365
preserve_assignment = Expr(:tuple); preserve = Symbol[];
364366
@unpack equalarraydims, refs_aliasing_syms = ls
@@ -383,17 +385,24 @@ function add_grouped_strided_pointer!(extra_args::Expr, ls::LoopSet)
383385
duplicate && continue
384386
duplicate_map[j] = (i += 1)
385387
found = false
386-
for j ∈ eachindex(allarrayrefs)
387-
if sameref(allarrayrefs[j], ref)
388-
gespinds = cse_constant_offsets!(ls, allarrayrefs, j, name_to_array_map, unique_to_name_and_op_map, shouldindbyind)
389-
push!(tgarrays.args, Expr(:call, lv(:gespf1), vpref, gespinds))
388+
for k ∈ eachindex(allarrayrefs)
389+
if sameref(allarrayrefs[k], ref)
390+
gespindsummary = cse_constant_offsets!(ls, allarrayrefs, k, name_to_array_map, unique_to_name_and_op_map)
391+
push!(gespsummaries, (k, gespindsummary))
390392
found = true
391393
break
392394
end
393395
end
394396
@assert found
395397
push!(preserve, presbufsym(ref.ref.array))
396398
end
399+
roots = getroots(ls)
400+
shouldindbyind = should_zerorangestart(ls, allarrayrefs, name_to_array_map, roots)
401+
for (k,gespindsummary) ∈ gespsummaries
402+
ref = allarrayrefs[k]
403+
gespinds = calcgespinds(ls, ref, gespindsummary, shouldindbyind)
404+
push!(tgarrays.args, Expr(:call, lv(:gespf1), vptr(ref), gespinds))
405+
end
397406
push!(gsp.args, tgarrays)
398407
matcheddims = Expr(:tuple)
399408
for (vptrs,dims) ∈ equalarraydims
@@ -409,7 +418,7 @@ function add_grouped_strided_pointer!(extra_args::Expr, ls::LoopSet)
409418
gsps = gensym!(ls, "#grouped#strided#pointer#")
410419
push!(extra_args.args, gsps)
411420
pushpreamble!(ls, Expr(:(=), gsps, Expr(:call, GlobalRef(Core,:getfield), gsp, 1)))
412-
preserve, shouldindbyind
421+
preserve, shouldindbyind, roots
413422
end
414423

415424
# first_cache() = ifelse(gt(num_cache_levels(), StaticInt{2}()), StaticInt{2}(), StaticInt{1}())
@@ -468,12 +477,11 @@ end
468477
# Try to condense in type stable manner
469478
function generate_call(ls::LoopSet, (inline,u₁,u₂)::Tuple{Bool,Int8,Int8}, thread::UInt, debug::Bool = false)
470479
extra_args = Expr(:tuple)
471-
preserve, shouldindbyind = add_grouped_strided_pointer!(extra_args, ls)
480+
preserve, shouldindbyind, roots = add_grouped_strided_pointer!(extra_args, ls)
472481

473482
operation_descriptions = Expr(:tuple)
474483
varnames = Symbol[]; ids = Vector{Int}(undef, length(operations(ls)))
475484
ops = operations(ls)
476-
roots = getroots(ls)
477485
length(ls.includedactualarrays) == 0 || remove_outer_reducts!(roots, ls)
478486
for op ∈ ops
479487
instr::Instruction = instruction(op)

src/modeling/graphs.jl

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -746,15 +746,35 @@ end
746746

747747
operations(ls::LoopSet) = ls.operations
748748
function pushop!(ls::LoopSet, op::Operation, var::Symbol = name(op))
749-
for opp ∈ operations(ls)
750-
if matches(op, opp)
751-
ls.opdict[var] = opp
752-
return opp
753-
end
749+
if iscompute(op) && length(loopdependencies(op)) == 0
750+
op.node_type = constant
751+
opdef = callexpr(instruction(op))
752+
opparents = parents(op)
753+
mangledname = Symbol('#', instruction(op).instr, '#')
754+
while length(opparents) > 0
755+
oppname = name(popfirst!(opparents))
756+
mangledname = Symbol(mangledname, oppname, '#')
757+
push!(opdef.args, oppname)
758+
# if opp.instruction == LOOPCONSTANT
759+
# push!(opdef.args, name(opp))
760+
# else
761+
762+
# end
763+
end
764+
op.mangledvariable = mangledname
765+
pushpreamble!(ls, Expr(:(=), name(op), opdef))
766+
op.instruction = LOOPCONSTANT
767+
push!(ls.preamble_symsym, (identifier(op), name(op)))
768+
end
769+
for opp ∈ operations(ls)
770+
if matches(op, opp)
771+
ls.opdict[var] = opp
772+
return opp
754773
end
755-
push!(ls.operations, op)
756-
ls.opdict[var] = op
757-
op
774+
end
775+
push!(ls.operations, op)
776+
ls.opdict[var] = op
777+
op
758778
end
759779
function add_block!(ls::LoopSet, ex::Expr, elementbytes::Int, position::Int)
760780
for x ∈ ex.args

src/parse/add_constants.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@ function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
1616
op = Operation(length(operations(ls)), gensym!(ls, "loopconstnumber"), elementbytes, LOOPCONSTANT, constant, NODEPENDENCY, Symbol[], NOPARENTS)
1717
ops = operations(ls)
1818
typ = var isa Integer ? HardInt : HardFloat
19-
rop = pushop!(ls, op)
20-
rop === op || return rop
2119
if iszero(var)
2220
for (id,typ_) ∈ ls.preamble_zeros
2321
(instruction(ops[id]) == LOOPCONSTANT && typ == typ_) && return ops[id]
@@ -37,6 +35,9 @@ function add_constant!(ls::LoopSet, var::Number, elementbytes::Int = 8)
3735
end
3836
push!(ls.preamble_symfloat, (identifier(op), var))
3937
end
38+
rop = pushop!(ls, op)
39+
rop === op || return rop
40+
pushpreamble!(ls, Expr(:(=), name(op), var))
4041
rop
4142
end
4243
function add_constant!(ls::LoopSet, mpref::ArrayReferenceMetaPosition, elementbytes::Int)

0 commit comments

Comments
 (0)