Skip to content

Commit c71b0be

Browse files
committed
Make LoopSet a mutable struct and handle dynamic unroll factors
1 parent 7301b41 commit c71b0be

16 files changed

+260
-240
lines changed

src/broadcast.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ end
400400
ls = LoopSet(Mod)
401401
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, threads = UNROLL
402402
set_hw!(ls, rs, rc, cls, l1, l2, l3)
403-
ls.isbroadcast[] = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
403+
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
404404
loopsyms = [gensym!(ls, "n") for n 1:N]
405405
add_broadcast_loops!(ls, loopsyms, :dest)
406406
elementbytes = sizeof(T)
@@ -419,7 +419,7 @@ end
419419
ls = LoopSet(Mod)
420420
inline, u₁, u₂, isbroadcast, W, rs, rc, cls, l1, l2, l3, threads = UNROLL
421421
set_hw!(ls, rs, rc, cls, l1, l2, l3)
422-
ls.isbroadcast[] = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
422+
ls.isbroadcast = isbroadcast # maybe set `false` in a DiffEq-like `@..` macro
423423
loopsyms = [gensym!(ls, "n") for n 1:N]
424424
pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
425425
add_broadcast_loops!(ls, loopsyms, :dest′)

src/codegen/loopstartstopmanager.jl

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
otherindexunrolled(loopsym::Symbol, ind::Symbol, loopdeps::Vector{Symbol}) = (loopsym !== ind) && (loopsym loopdeps)
3030
function otherindexunrolled(ls::LoopSet, ind::Symbol, ref::ArrayReferenceMeta)
31-
us = ls.unrollspecification[]
31+
us = ls.unrollspecification
3232
u₁sym = names(ls)[us.u₁loopnum]
3333
u₂sym = us.u₂loopnum > 0 ? names(ls)[us.u₂loopnum] : Symbol("##undefined##")
3434
otherindexunrolled(u₁sym, ind, loopdependencies(ref)) || otherindexunrolled(u₂sym, ind, loopdependencies(ref))
@@ -37,18 +37,18 @@ multiple_with_name(n::Symbol, v::Vector{ArrayReferenceMeta}) = sum(ref -> n ===
3737
# TODO: DRY between indices_calculated_by_pointer_offsets and use_loop_induct_var
3838
function indices_calculated_by_pointer_offsets(ls::LoopSet, ar::ArrayReferenceMeta)
3939
indices = getindices(ar)
40-
ls.isbroadcast[] && return fill(false, length(indices))
40+
ls.isbroadcast && return fill(false, length(indices))
4141
looporder = names(ls)
4242
offset = isdiscontiguous(ar)
4343
gespinds = Expr(:tuple)
4444
out = Vector{Bool}(undef, length(indices))
4545
li = ar.loopedindex
46-
# @show ls.vector_width[]
46+
# @show ls.vector_width
4747
for i eachindex(li)
4848
ii = i + offset
4949
ind = indices[ii]
50-
if (!li[i]) || (ind === CONSTANTZEROINDEX) || multiple_with_name(vptr(ar), ls.lssm[].uniquearrayrefs) ||
51-
(iszero(ls.vector_width[]) && isstaticloop(getloop(ls, ind)))# ||
50+
if (!li[i]) || (ind === CONSTANTZEROINDEX) || multiple_with_name(vptr(ar), ls.lssm.uniquearrayrefs) ||
51+
(iszero(ls.vector_width) && isstaticloop(getloop(ls, ind)))# ||
5252
out[i] = false
5353
elseif (isone(ii) && (first(looporder) === ind))
5454
out[i] = otherindexunrolled(ls, ind, ar)
@@ -81,7 +81,7 @@ A value > 0 indicates which loop number that index corresponds to when increment
8181
A value < 0 indicates that abs(value) is the corresponding loop, and a `loopvalue` will be used.
8282
"""
8383
function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, allarrayrefs::Vector{ArrayReferenceMeta}, includeinlet::Bool)
84-
us = ls.unrollspecification[]
84+
us = ls.unrollspecification
8585
li = ar.loopedindex
8686
looporder = reversenames(ls)
8787
uliv = Vector{Int}(undef, length(li))
@@ -92,7 +92,7 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
9292
println(ar)
9393
throw("Length of indices and length of offset do not match!")
9494
end
95-
isbroadcast = ls.isbroadcast[]
95+
isbroadcast = ls.isbroadcast
9696
gespinds = Expr(:tuple)
9797
offsetprecalc_descript = Expr(:tuple)
9898
use_offsetprecalc = false
@@ -111,7 +111,7 @@ function use_loop_induct_var!(ls::LoopSet, q::Expr, ar::ArrayReferenceMeta, alla
111111
elseif isbroadcast ||
112112
((isone(ii) && (last(looporder) === ind)) && !(otherindexunrolled(ls, ind, ar)) ||
113113
multiple_with_name(vptr(ar), allarrayrefs)) ||
114-
(iszero(ls.vector_width[]) && isstaticloop(getloop(ls, ind)))# ||
114+
(iszero(ls.vector_width) && isstaticloop(getloop(ls, ind)))# ||
115115
# ((ls.align_loops[] > 0) && (first(names(ls)) == ind))
116116

117117
# Not doing normal offset indexing
@@ -168,13 +168,19 @@ end
168168
# Plan here is that we increment every unique array
169169
function add_loop_start_stop_manager!(ls::LoopSet)
170170
q = Expr(:block)
171-
us = ls.unrollspecification[]
171+
us = ls.unrollspecification
172172
# Presence of an explicit use of a loopinducation var means we should use that, so we look for one
173173
# TODO: replace first with only once you add Compat as a dep or drop support for older Julia versions
174-
loopinductvars = map(op -> first(loopdependencies(op)), filter(isloopvalue, operations(ls)))
174+
loopinductvars = Symbol[]
175+
for op operations(ls)
176+
isloopvalue(op) && push!(loopinductvars, first(loopdependencies(op)))
177+
end
175178
# Filtered ArrayReferenceMetas, we must increment each
176179
arrayrefs, includeinlet = uniquearrayrefs(ls)
177-
use_livs = map((ar,iil) -> use_loop_induct_var!(ls, q, ar, arrayrefs, iil), arrayrefs, includeinlet)
180+
use_livs = Vector{Vector{Int}}(undef, length(arrayrefs))
181+
for i eachindex(arrayrefs)
182+
use_livs[i] = use_loop_induct_var!(ls, q, arrayrefs[i], arrayrefs, includeinlet[i])
183+
end
178184
# @show use_livs,
179185
# loops, sorted from outer-most to inner-most
180186
looporder = reversenames(ls)
@@ -208,7 +214,7 @@ function add_loop_start_stop_manager!(ls::LoopSet)
208214
last(ric[argmin(first.(ric))]) # index corresponds to array ref's position in loopstart
209215
end
210216
end
211-
ls.lssm[] = LoopStartStopManager(
217+
ls.lssm = LoopStartStopManager(
212218
terminators, loopstarts, arrayrefs
213219
)
214220
q
@@ -395,7 +401,7 @@ end
395401

396402
function startloop(ls::LoopSet, us::UnrollSpecification, n::Int, submax = maxunroll(us, n))
397403
@unpack u₁loopnum, u₂loopnum, vloopnum, u₁, u₂ = us
398-
lssm = ls.lssm[]
404+
lssm = ls.lssm
399405
termind = lssm.terminators[n]
400406
ptrdefs = lssm.incrementedptrs[n]
401407
loopstart = Expr(:block)
@@ -436,7 +442,7 @@ function offset_ptr(
436442
end
437443
function incrementloopcounter!(q::Expr, ls::LoopSet, us::UnrollSpecification, n::Int, UF::Int)
438444
@unpack u₁loopnum, u₂loopnum, vloopnum, u₁, u₂ = us
439-
lssm = ls.lssm[]
445+
lssm = ls.lssm
440446
ptrdefs = lssm.incrementedptrs[n]
441447
looporder = names(ls)
442448
loopsym = looporder[n]
@@ -452,7 +458,7 @@ function incrementloopcounter!(q::Expr, ls::LoopSet, us::UnrollSpecification, n:
452458
nothing
453459
end
454460
function terminatecondition(ls::LoopSet, us::UnrollSpecification, n::Int, inclmask::Bool, UF::Int)
455-
lssm = ls.lssm[]
461+
lssm = ls.lssm
456462
termind = lssm.terminators[n]
457463
if iszero(termind)
458464
loop = getloop(ls, n)

src/codegen/lower_compute.jl

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11

22

3-
function load_constrained(op, u₁loop, u₂loop, innermost_loop_or_vloop, forprefetch = false)
3+
function load_constrained(op::Operation, u₁loop::Symbol, u₂loop::Symbol, innermost_loop_or_vloop::Symbol, forprefetch::Bool = false)
44
dependsonu₁ = isu₁unrolled(op)
55
dependsonu₂ = isu₂unrolled(op)
66
if forprefetch
@@ -21,8 +21,8 @@ function load_constrained(op, u₁loop, u₂loop, innermost_loop_or_vloop, forpr
2121
isload(opp) && all(in(loopdependencies(opp)), unrolleddeps)
2222
end
2323
end
24-
function check_if_remfirst(ls, ua)
25-
usorig = ls.unrollspecification[]
24+
function check_if_remfirst(ls::LoopSet, ua::UnrollArgs)
25+
usorig = ls.unrollspecification
2626
@unpack u₁, u₁loopsym, u₂loopsym, u₂max = ua
2727
u₁loop = getloop(ls, u₁loopsym)
2828
u₂loop = getloop(ls, u₂loopsym)
@@ -39,20 +39,23 @@ function sub_fmas(ls::LoopSet, op::Operation, ua::UnrollArgs)
3939
!(load_constrained(op, u₁loopsym, u₂loopsym, vloopsym) || check_if_remfirst(ls, ua))
4040
end
4141

42-
struct FalseCollection end
43-
Base.getindex(::FalseCollection, i...) = false
44-
function parent_unroll_status(op::Operation, u₁loop::Symbol)
45-
map(opp -> isunrolled_sym(opp, u₁loop), parents(op)), fill(false, length(parents(op)))
42+
function parent_unroll_status(op::Operation, u₁loop::Symbol, us::UnrollSpecification)
43+
parentsop = parents(op)
44+
u2 = fill(false, length(parentsop))
45+
u1 = similar(u2)
46+
for i eachindex(parentsop)
47+
u1[i] = isunrolled_sym(parentsop[i], u₁loop, us)
48+
end
49+
u1, u2
4650
end
47-
function parent_unroll_status(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol, u₂max::Int)
48-
# u₂max ≥ 0 || return parent_unroll_status(op, u₁loop)
49-
u₂max == -1 && return parent_unroll_status(op, u₁loop)
51+
function parent_unroll_status(op::Operation, u₁loop::Symbol, u₂loop::Symbol, vloop::Symbol, u₂max::Int, us::UnrollSpecification)
52+
u₂max == -1 && return parent_unroll_status(op, u₁loop, us)
5053
vparents = parents(op);
5154
# parent_names = Vector{Symbol}(undef, length(vparents))
5255
parents_u₁syms = Vector{Bool}(undef, length(vparents))
5356
parents_u₂syms = Vector{Bool}(undef, length(vparents))
5457
for i eachindex(vparents)
55-
parents_u₁syms[i], parents_u₂syms[i] = isunrolled_sym(vparents[i], u₁loop, u₂loop, vloop)#, u₂max)
58+
parents_u₁syms[i], parents_u₂syms[i] = isunrolled_sym(vparents[i], u₁loop, u₂loop, vloop, us)#, u₂max)
5659
end
5760
# parent_names, parents_u₁syms, parents_u₂syms
5861
parents_u₁syms, parents_u₂syms
@@ -100,7 +103,10 @@ vecunrolllen(::Type{VecUnroll{N,W,T,V}}) where {N,W,T,V} = (N::Int + 1)
100103
vecunrolllen(_) = -1
101104
function ifelselastexpr(hasf::Bool, M::Int, vargtypes, K::Int, S::Int, maskearly::Bool)
102105
q = Expr(:block, Expr(:meta,:inline))
103-
vargs = map(k -> Symbol(:varg_,k), 1:K)
106+
vargs = Vector{Symbol}(undef, K)
107+
for k 1:K
108+
vargs[k] = Symbol(:varg_,k)
109+
end
104110
lengths = Vector{Int}(undef, K);
105111
for k 1:K
106112
lengths[k] = l = vecunrolllen(vargtypes[k])
@@ -295,7 +301,7 @@ function parent_op_name(
295301
parent, u
296302
end
297303
function getuouterreduct(ls::LoopSet, op::Operation, suffix)
298-
us = ls.unrollspecification[]
304+
us = ls.unrollspecification
299305
if us.vloopnum === us.u₁loopnum # unroll u₂
300306
suffix
301307
else # unroll u₁
@@ -306,7 +312,7 @@ end
306312
function getu₁full(ls::LoopSet, u₁::Int)
307313
Ureduct = ureduct(ls)
308314
ufull = if Ureduct == -1 # no reducing
309-
ls.unrollspecification[].u₁
315+
ls.unrollspecification.u₁
310316
else
311317
Ureduct
312318
end
@@ -326,9 +332,9 @@ function getu₁forreduct(ls::LoopSet, op::Operation, u₁::Int)
326332
end
327333
if isu₁unrolled(op)
328334
return u₁
329-
elseif (ls.unrollspecification[].u₂ != -1) && length(ls.outer_reductions) > 0
335+
elseif (ls.unrollspecification.u₂ != -1) && length(ls.outer_reductions) > 0
330336
# then `ureduct` doesn't tell us what we need, so
331-
return ls.unrollspecification[].u₁
337+
return ls.unrollspecification.u₁
332338
else # we need to find u₁-full
333339
return getu₁full(ls, u₁)
334340
end
@@ -357,13 +363,12 @@ function lower_compute!(
357363
instr = instruction(op)
358364
parents_op = parents(op)
359365
nparents = length(parents_op)
360-
# __u₂max = ls.unrollspecification[].u₂
366+
# __u₂max = ls.unrollspecification.u₂
361367
# TODO: perhaps allos for swithcing unrolled axis again
362-
# mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix)
363-
mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix)
368+
mvar, u₁unrolledsym, u₂unrolledsym = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
364369
opunrolled = u₁unrolledsym || isu₁unrolled(op)
365-
# parent_names, parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loop, u₂loop, suffix)
366-
parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loopsym, u₂loopsym, vloopsym, u₂max)
370+
us = ls.unrollspecification
371+
parents_u₁syms, parents_u₂syms = parent_unroll_status(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, us)
367372
# tiledouterreduction = if num_loops(ls) == 1# (suffix == -1)# || (vloopsym === u₂loopsym)
368373
tiledouterreduction = if (suffix == -1)# || (vloopsym === u₂loopsym)
369374
suffix_ = Symbol("")
@@ -411,8 +416,8 @@ function lower_compute!(
411416
# if isreduct
412417
# @show u₁unrolledsym, u₂unrolledsym, isu₁unrolled(op), isu₂unrolled(op) op
413418
# end
414-
if Base.libllvm_version < v"11.0.0" && (suffix -1) && isreduct# && (iszero(suffix) || (ls.unrollspecification[].u₂ - 1 == suffix))
415-
# if (length(reduceddependencies(op)) > 0) | (length(reducedchildren(op)) > 0)# && (iszero(suffix) || (ls.unrollspecification[].u₂ - 1 == suffix))
419+
if Base.libllvm_version < v"11.0.0" && (suffix -1) && isreduct# && (iszero(suffix) || (ls.unrollspecification.u₂ - 1 == suffix))
420+
# if (length(reduceddependencies(op)) > 0) | (length(reducedchildren(op)) > 0)# && (iszero(suffix) || (ls.unrollspecification.u₂ - 1 == suffix))
416421
# instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
417422
instrfid = findfirst(Base.Fix2(===,instr.instr), (:vfmadd_fast, :vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
418423
# instrfid = findfirst(isequal(instr.instr), (:vfnmadd_fast, :vfmsub_fast, :vfnmsub_fast))
@@ -449,13 +454,13 @@ function lower_compute!(
449454
# modsuffix = ((u + suffix*(Uiter + 1)) & 7)
450455
isouterreduct = true
451456
# if u₁unrolledsym
452-
# modsuffix = ls.unrollspecification[].u₁#getu₁full(ls, u₁)#u₁
457+
# modsuffix = ls.unrollspecification.u₁#getu₁full(ls, u₁)#u₁
453458
# Symbol(mangledvar(op), '_', modsuffix)
454459
# else
455460
if u₁unrolledsym
456461
modsuffix = 0
457462
else
458-
modsuffix = suffix % ls.ureduct[]
463+
modsuffix = suffix % ls.ureduct
459464
end
460465
Symbol(mangledvar(op), modsuffix)
461466
# end
@@ -538,7 +543,7 @@ function lower_compute!(
538543
# end
539544
# push!(q.args, (isreduct, u₁, (!u₁unrolledsym), isu₁unrolled(op), dopartialmap, varsym))
540545
if maskreduct
541-
ifelsefunc = if ls.unrollspecification[].u₁ == 1
546+
ifelsefunc = if us.u₁ == 1
542547
:ifelse # don't need to be fancy
543548
elseif (u₁loopsym !== vloopsym)
544549
:ifelsepartial # ifelse all the early ones
@@ -574,9 +579,9 @@ function lower_compute!(
574579
make_partial_map!(instrcall, selfopname, u₁, selfdepreduce)
575580
end
576581
elseif selfdep != 0 && (dopartialmap ||
577-
(isouterreduct && (opunrolled) && (u₁ < ls.unrollspecification[].u₁)) ||
582+
(isouterreduct && (opunrolled) && (u₁ < us.u₁)) ||
578583
(isreduct & (u₁ > 1) & (!u₁unrolledsym) & isu₁unrolled(op))) # TODO: DRY `selfdepreduce` definition
579-
# first possibility (`isouterreduct && opunrolled && (u₁ < ls.unrollspecification[].u₁)`):
584+
# first possibility (`isouterreduct && opunrolled && (u₁ < ls.unrollspecification.u₁)`):
580585
# checks if we're in the "reduct" part of an outer reduction
581586
#
582587
# second possibility (`(isreduct & (u₁ > 1) & (!u₁unrolledsym) & isu₁unrolled(op))`):

src/codegen/lower_constant.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,7 @@ function lower_zero!(
5050
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs, zerotyp::NumberType = zerotype(ls, op)
5151
)
5252
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
53-
# mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix)
54-
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix)
53+
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
5554
!opu₂ && suffix > 0 && return
5655
# TODO: for u₁, needs to consider if reducedchildren are u₁-unrolled
5756
# reductions need to consider reduct-status
@@ -98,8 +97,7 @@ function lower_constant!(
9897
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs
9998
)
10099
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
101-
# mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix)
102-
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix)
100+
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
103101
!opu₂ && suffix > 0 && return
104102
mvar = Symbol(mvar, '_', Core.ifelse(opu₁, u₁, 1))
105103
instruction = op.instruction

0 commit comments

Comments
 (0)