Skip to content

Commit 8b5f640

Browse files
committed
Some more of the endless stream of unrolling naming fixes...
1 parent 905ee6c commit 8b5f640

File tree

8 files changed

+197
-138
lines changed

8 files changed

+197
-138
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ SLEEFPirates = "0.6.14"
2828
Static = "0.2"
2929
ThreadingUtilities = "0.4.1"
3030
UnPack = "1"
31-
VectorizationBase = "0.19.29"
31+
VectorizationBase = "0.19.30"
3232
julia = "1.5"
3333

3434
[extras]

src/codegen/lower_compute.jl

Lines changed: 45 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -274,31 +274,47 @@ end
274274
q
275275
end
276276

277-
function parent_op_name(
278-
ls::LoopSet, parents_op::Vector{Operation}, n::Int, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction
277+
function parent_op_name!(
278+
q, ls::LoopSet, parents_op::Vector{Operation}, n::Int, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, u₂max, u₂unrolledsym, op, tiledouterreduction
279279
)
280-
opp = parents_op[n]
281-
parent = mangledvar(opp)
282-
u = 0
283-
if n == tiledouterreduction# && isvectorized(opp)
284-
parent = Symbol(parent, modsuffix)
280+
opp = parents_op[n]
281+
opisvectorized = isvectorized(op)
282+
parent = mangledvar(opp)
283+
u = 0
284+
if n == tiledouterreduction# && isvectorized(opp)
285+
parent = Symbol(parent, modsuffix)
286+
else
287+
u = if !parents_u₁syms[n]
288+
1
289+
elseif isouterreduction(ls, opp) ≠ -1
290+
getu₁full(ls, u₁)
285291
else
286-
if parents_u₂syms[n]
287-
parent = Symbol(parent, suffix_)
288-
end
289-
u = if !parents_u₁syms[n]
290-
1
291-
elseif isouterreduction(ls, opp) ≠ -1
292-
getu₁full(ls, u₁)
293-
else
294-
getu₁forreduct(ls, opp, u₁)
295-
end
296-
parent = Symbol(parent, '_', u)
292+
getu₁forreduct(ls, opp, u₁)
297293
end
298-
if opisvectorized && isload(opp) && (!isvectorized(opp))
299-
parent = Symbol(parent, "##broadcasted##")
294+
if parents_u₂syms[n]
295+
if isu₂unrolled(op) # u₂unrolledsym ||
296+
parent = Symbol(parent, suffix_, '_', u)
297+
elseif u₂max > 1
298+
t = Expr(:tuple)
299+
reduction = Expr(:call, GlobalRef(ArrayInterface, :reduce_tup), reduce_to_onevecunroll(instruction(opp)), t)
300+
for u₂ ∈ 0:u₂max-1
301+
push!(t.args, Symbol(parent, u₂, "__", u))
302+
end
303+
parent = gensym!(ls, parent)
304+
push!(q.args, Expr(:(=), parent, reduction))
305+
parent
306+
else
307+
# parent = Symbol(parent, '_', u)
308+
parent = Symbol(parent, 0, "__", u)
309+
end
310+
else
311+
parent = Symbol(parent, '_', u)
300312
end
301-
parent, u
313+
end
314+
if opisvectorized && isload(opp) && (!isvectorized(opp))
315+
parent = Symbol(parent, "##broadcasted##")
316+
end
317+
parent, u
302318
end
303319
function getuouterreduct(ls::LoopSet, op::Operation, suffix)
304320
us = ls.unrollspecification
@@ -413,9 +429,6 @@ function lower_compute!(
413429
# parentsyms = [opp.variable for opp ∈ parents(op)]
414430
Uiter = opunrolled ? u₁ - 1 : 0
415431
isreduct = isreduction(op)
416-
# if isreduct
417-
# @show u₁unrolledsym, u₂unrolledsym, isu₁unrolled(op), isu₂unrolled(op) op
418-
# end
419432
if Base.libllvm_version < v"11.0.0" && (suffix ≠ -1) && isreduct# && (iszero(suffix) || (ls.unrollspecification.u₂ - 1 == suffix))
420433
# if (length(reduceddependencies(op)) > 0) | (length(reducedchildren(op)) > 0)# && (iszero(suffix) || (ls.unrollspecification.u₂ - 1 == suffix))
421434
# instrfid = findfirst(isequal(instr.instr), (:vfmadd, :vfnmadd, :vfmsub, :vfnmsub))
@@ -474,7 +487,6 @@ function lower_compute!(
474487
# isouterreduct = true
475488
isouterreduct = isanouterreduction(ls, op)
476489
u₁reduct = isouterreduct ? getu₁full(ls, u₁) : getu₁forreduct(ls, op, u₁)
477-
# @show isouterreduct, u₁reduct, op
478490
dopartialmap = u₁reduct ≠ u₁
479491
Symbol(mvar, '_', u₁reduct)
480492
else
@@ -484,47 +496,35 @@ function lower_compute!(
484496
Symbol(mvar, '_', 1)
485497
end
486498
selfopname = varsym
487-
selfdep = 0
499+
selfdep = 0
488500
for n ∈ 1:nparents
489501
opp = parents_op[n]
490502
if isloopvalue(opp)
491503
loopval = first(loopdependencies(opp))
492504
add_loopvalue!(instrcall, loopval, ua, u₁)
493505
elseif name(opp) === name(op)
506+
494507
selfdep = n
495-
# @show mangledvar(op), name(opp), name(op)
496508
if ((isvectorized(opp) && !isvectorized(op))) ||
497509
(parents_u₁syms[n] != u₁unrolledsym) || (parents_u₂syms[n] != u₂unrolledsym)
498-
499-
selfopname, uₚ = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
500-
# if (uₚ ≠ 0) & (uₚ ≠ u₁)
501-
# dopartialmap = true
502-
# end
503-
# @show selfopname, instr
504-
push!(instrcall.args, selfopname)
510+
511+
selfopname, uₚ = parent_op_name!(q, ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, u₂max, u₂unrolledsym, op, tiledouterreduction)
512+
push!(instrcall.args, selfopname)
505513
else
506-
push!(instrcall.args, varsym)
514+
push!(instrcall.args, varsym)
507515
end
508516
elseif ((!isu₂unrolled(op)) & isu₂unrolled(opp)) && (parents_u₂syms[n] & (!u₂unrolledsym))
509-
# elseif parents_u₂syms[n] & (!u₂unrolledsym)
517+
# elseif parents_u₂syms[n] & (!u₂unrolledsym)
510518
#&& (isouterreduction(ls, opp) != -1)
511519
# this checks if the parent is u₂ unrolled but this operation is not, in which case we need to reduce it.
512-
# @show op opp
513520
reduced_u₂ = reduce_expr_u₂(mangledvar(opp), instruction(opp), ureduct(ls))
514521
reducedparentname = gensym!(ls, "reducedop")
515522
push!(q.args, Expr(:(=), reducedparentname, reduced_u₂))
516523
reduced_u₂ = reduce_parent!(q, ls, op, opp, reducedparentname)
517524
push!(instrcall.args, reduced_u₂)
518525
else
519-
parent, uₚ = parent_op_name(ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, opisvectorized, tiledouterreduction)
526+
parent, uₚ = parent_op_name!(q, ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, u₂max, u₂unrolledsym, op, tiledouterreduction)
520527
parent = reduce_parent!(q, ls, op, opp, parent)
521-
# if instr.instr === :vfmadd_fast && tiledouterreduction > 0
522-
# @show mvar, varsym, selfopname
523-
# end
524-
# @show opp
525-
# if instr.instr === :identity
526-
# @show isvectorized(op) isvectorized(opp)
527-
# end
528528
if (selfdep == 0) && search_tree(parents(opp), name(op))
529529
selfdep = n
530530
push!(instrcall.args, parent)
@@ -536,12 +536,6 @@ function lower_compute!(
536536
end
537537
end
538538
selfdepreduce = ifelse(((!u₁unrolledsym) & isu₁unrolled(op)) & (u₁ > 1), selfdep, 0)
539-
# if selfdep ≠ 0
540-
# @show mvar
541-
# # @show isu₁unrolled(op), u₁unrolledsym, u₁, u₂max
542-
# # @show selfdep, selfdepreduce#, op
543-
# end
544-
# push!(q.args, (isreduct, u₁, (!u₁unrolledsym), isu₁unrolled(op), dopartialmap, varsym))
545539
if maskreduct
546540
ifelsefunc = if us.u₁ == 1
547541
:ifelse # don't need to be fancy
@@ -575,7 +569,6 @@ function lower_compute!(
575569
end
576570
return
577571
elseif selfdep != 0
578-
# @show op, isouterreduct, maskreduct, instr
579572
make_partial_map!(instrcall, selfopname, u₁, selfdepreduce)
580573
end
581574
elseif selfdep != 0 && (dopartialmap ||

src/codegen/lower_constant.jl

Lines changed: 89 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -48,37 +48,41 @@ end
4848
function lower_zero!(
4949
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs, zerotyp::NumberType = zerotype(ls, op)
5050
)
51-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
52-
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
53-
!opu₂ && suffix > 0 && return
54-
# TODO: for u₁, needs to consider if reducedchildren are u₁-unrolled
55-
# reductions need to consider reduct-status
56-
# if !opu₁
57-
# opu₁ = u₁loopsym ∈ reducedchildren(op)
58-
# end
59-
mvar = Symbol(mvar, '_', Core.ifelse(opu₁, u₁, 1))
60-
typeT = typeof_sym(ls, op, zerotyp)
61-
# TODO: make should_broadcast_op handle everything.
62-
if isvectorized(op) || vloopsym ∈ reducedchildren(op) || vloopsym ∈ reduceddependencies(op) || should_broadcast_op(op)
63-
if opu₁ && u₁ > 1
64-
call = Expr(:call, lv(:zero_vecunroll), staticexpr(u₁), VECTORWIDTHSYMBOL, typeT, staticexpr(reg_size(ls)))
65-
else
66-
call = Expr(:call, lv(:_vzero), VECTORWIDTHSYMBOL, typeT, staticexpr(reg_size(ls)))
67-
end
51+
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
52+
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
53+
!opu₂ && suffix > 0 && return
54+
# TODO: for u₁, needs to consider if reducedchildren are u₁-unrolled
55+
# reductions need to consider reduct-status
56+
# if !opu₁
57+
# opu₁ = u₁loopsym ∈ reducedchildren(op)
58+
# end
59+
typeT = typeof_sym(ls, op, zerotyp)
60+
# TODO: make should_broadcast_op handle everything.
61+
if isvectorized(op) || vloopsym ∈ reducedchildren(op) || vloopsym ∈ reduceddependencies(op) || should_broadcast_op(op)
62+
if opu₁ && u₁ > 1
63+
call = Expr(:call, lv(:zero_vecunroll), staticexpr(u₁), VECTORWIDTHSYMBOL, typeT, staticexpr(reg_size(ls)))
6864
else
69-
call = Expr(:call, :zero, typeT)
70-
if opu₁ && u₁ > 1
71-
# broadcastsym = Symbol(mvar, "_#init#")
72-
# pushpreamble!(ls, Expr(:(=), broadcastsym, call))
73-
t = Expr(:tuple)
74-
for u ∈ 1:u₁
75-
push!(t.args, call)
76-
end
77-
call = Expr(:call, lv(:VecUnroll), t)
78-
end
65+
call = Expr(:call, lv(:_vzero), VECTORWIDTHSYMBOL, typeT, staticexpr(reg_size(ls)))
66+
end
67+
else
68+
call = Expr(:call, :zero, typeT)
69+
if opu₁ && u₁ > 1
70+
t = Expr(:tuple)
71+
for u ∈ 1:u₁
72+
push!(t.args, call)
73+
end
74+
call = Expr(:call, lv(:VecUnroll), t)
75+
end
76+
end
77+
if (suffix == -1) && opu₂
78+
for u ∈ 0:u₂max-1
79+
push!(q.args, Expr(:(=), Symbol(mvar, u, "__", Core.ifelse(opu₁, u₁, 1)), call))
7980
end
81+
else
82+
mvar = Symbol(mvar, '_', Core.ifelse(opu₁, u₁, 1))
8083
push!(q.args, Expr(:(=), mvar, call))
81-
nothing
84+
end
85+
nothing
8286
end
8387
# Have to awkwardly search through `operations(ls)` to try and find op's child
8488
function getparentsreductzero(ls::LoopSet, op::Operation)::Float64
@@ -95,52 +99,65 @@ vecbasefunc(f) = Expr(:(.), Expr(:(.), :LoopVectorization, QuoteNode(:Vectorizat
9599
function lower_constant!(
96100
q::Expr, op::Operation, ls::LoopSet, ua::UnrollArgs
97101
)
98-
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
99-
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
100-
!opu₂ && suffix > 0 && return
101-
mvar = Symbol(mvar, '_', Core.ifelse(opu₁, u₁, 1))
102-
instruction = op.instruction
103-
constsym = instruction.instr
104-
# constsym = Symbol(instruction.instr, '_', 1)
105-
reducedchildvectorized = vloopsym ∈ reducedchildren(op)
106-
if reducedchildvectorized || isvectorized(op) || vloopsym ∈ reduceddependencies(op) || should_broadcast_op(op)
107-
# call = Expr(:call, lv(:vbroadcast), W, Expr(:call, lv(:maybeconvert), typeT, constsym))
108-
call = if reducedchildvectorized && vloopsym ∉ loopdependencies(op)
109-
instrclass = getparentsreductzero(ls, op)
110-
if instrclass == ADDITIVE_IN_REDUCTIONS
111-
Expr(:call, vecbasefunc(:addscalar), Expr(:call, lv(:vzero), VECTORWIDTHSYMBOL, ELTYPESYMBOL), constsym)
112-
elseif instrclass == MULTIPLICATIVE_IN_REDUCTIONS
113-
Expr(:call, vecbasefunc(:mulscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :one, ELTYPESYMBOL)), constsym)
114-
elseif instrclass == MAX
115-
Expr(:call, vecbasefunc(:maxscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemin, ELTYPESYMBOL)), constsym)
116-
elseif instrclass == MIN
117-
Expr(:call, vecbasefunc(:minscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemax, ELTYPESYMBOL)), constsym)
118-
else
119-
throw("Reductions of type $(reduction_zero(reinstrclass)) not yet supported; please file an issue as a reminder to take care of this.")
120-
end
121-
else
122-
Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, constsym)
123-
end
124-
if opu₁ && u₁ > 1
125-
# broadcastsym = Symbol(mvar, "_#init#")
126-
# push!(q.args, Expr(:(=), broadcastsym, call))
127-
t = Expr(:tuple)
128-
for u ∈ 1:u₁
129-
push!(t.args, call)
130-
end
131-
call = Expr(:call, lv(:VecUnroll), t)
132-
end
133-
push!(q.args, Expr(:(=), mvar, call))
134-
elseif opu₁ && u₁ > 1
135-
t = Expr(:tuple)
136-
for u ∈ 1:u₁
137-
push!(t.args, constsym)
138-
end
139-
push!(q.args, Expr(:(=), mvar, Expr(:call, lv(:VecUnroll), t)))
102+
@unpack u₁, u₁loopsym, u₂loopsym, vloopsym, u₂max, suffix = ua
103+
mvar, opu₁, opu₂ = variable_name_and_unrolled(op, u₁loopsym, u₂loopsym, vloopsym, suffix, ls)
104+
!opu₂ && suffix > 0 && return
105+
instruction = op.instruction
106+
constsym = instruction.instr
107+
# constsym = Symbol(instruction.instr, '_', 1)
108+
reducedchildvectorized = vloopsym ∈ reducedchildren(op)
109+
if reducedchildvectorized || isvectorized(op) || vloopsym ∈ reduceddependencies(op) || should_broadcast_op(op)
110+
# call = Expr(:call, lv(:vbroadcast), W, Expr(:call, lv(:maybeconvert), typeT, constsym))
111+
call = if reducedchildvectorized && vloopsym ∉ loopdependencies(op)
112+
instrclass = getparentsreductzero(ls, op)
113+
if instrclass == ADDITIVE_IN_REDUCTIONS
114+
Expr(:call, vecbasefunc(:addscalar), Expr(:call, lv(:vzero), VECTORWIDTHSYMBOL, ELTYPESYMBOL), constsym)
115+
elseif instrclass == MULTIPLICATIVE_IN_REDUCTIONS
116+
Expr(:call, vecbasefunc(:mulscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :one, ELTYPESYMBOL)), constsym)
117+
elseif instrclass == MAX
118+
Expr(:call, vecbasefunc(:maxscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemin, ELTYPESYMBOL)), constsym)
119+
elseif instrclass == MIN
120+
Expr(:call, vecbasefunc(:minscalar), Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, Expr(:call, :typemax, ELTYPESYMBOL)), constsym)
121+
else
122+
throw("Reductions of type $(reduction_zero(reinstrclass)) not yet supported; please file an issue as a reminder to take care of this.")
123+
end
140124
else
141-
push!(q.args, Expr(:(=), mvar, constsym))
125+
Expr(:call, lv(:vbroadcast), VECTORWIDTHSYMBOL, constsym)
142126
end
143-
nothing
127+
if opu₁ && u₁ > 1
128+
# broadcastsym = Symbol(mvar, "_#init#")
129+
# push!(q.args, Expr(:(=), broadcastsym, call))
130+
t = Expr(:tuple)
131+
for u ∈ 1:u₁
132+
push!(t.args, call)
133+
end
134+
call = Expr(:call, lv(:VecUnroll), t)
135+
end
136+
elseif opu₁ && u₁ > 1
137+
t = Expr(:tuple)
138+
for u ∈ 1:u₁
139+
push!(t.args, constsym)
140+
end
141+
call = Expr(:call, lv(:VecUnroll), t)
142+
elseif opu₂ & (suffix == -1)
143+
for u ∈ 0:u₂max-1
144+
push!(q.args, Expr(:(=), Symbol(mvar, u, "__", 1), constsym))
145+
end
146+
return nothing
147+
else
148+
push!(q.args, Expr(:(=), Symbol(mvar, '_', 1), constsym))
149+
return nothing
150+
end
151+
u₁tag = Core.ifelse(opu₁, u₁, 1)
152+
if opu₂ & (suffix == -1)
153+
for u ∈ 0:u₂max-1
154+
push!(q.args, Expr(:(=), Symbol(mvar, u, "__", u₁tag), call))
155+
end
156+
else
157+
mvar = Symbol(mvar, '_', u₁tag)
158+
push!(q.args, Expr(:(=), mvar, call))
159+
end
160+
nothing
144161
end
145162

146163
isconstantop(op::Operation) = (instruction(op) === LOOPCONSTANT) || (isconstant(op) && length(loopdependencies(op)) == 0)

src/modeling/determinestrategy.jl

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -578,22 +578,30 @@ function solve_unroll(
578578
u₁L = length(u₁loop)
579579
u₂L = length(u₂loop)
580580
if isstaticloop(u₂loop)
581-
if u₂loopsym !== vloopsym && u₂L ≤ 4
582-
u₁ = max(1, solve_unroll_constT(reg_pressure, u₂L))
583-
u₁ = isstaticloop(u₁loop) ? maybedemotesize(u₁, u₁loopsym === vloopsym ? cld(u₁L,W) : u₁L) : u₁
584-
return u₁, u₂L, unroll_cost(cost_vec, u₁, u₂L, u₁L, u₂L)
581+
if u₂loopsym !== vloopsym && u₂L ≤ 4
582+
if isstaticloop(u₁loop)
583+
u₁ = max(solve_unroll_constT(reg_pressure, u₂L), 1)
584+
u₁ = maybedemotesize(u₁, u₁loopsym === vloopsym ? cld(u₁L,W) : u₁L)
585+
else
586+
u₁ = clamp(solve_unroll_constT(reg_pressure, u₂L), 1, 8)
585587
end
586-
u₂Ltemp = u₂loopsym === vloopsym ? cld(u₂L, W) : u₂L
587-
maxu₂ = min(4maxu₂, u₂Ltemp)
588+
return u₁, u₂L, unroll_cost(cost_vec, u₁, u₂L, u₁L, u₂L)
589+
end
590+
u₂Ltemp = u₂loopsym === vloopsym ? cld(u₂L, W) : u₂L
591+
maxu₂ = min(4maxu₂, u₂Ltemp)
588592
end
589593
if isstaticloop(u₁loop)
590-
if u₁loopsym !== vloopsym && u₁L ≤ 4
591-
u₂ = max(1, solve_unroll_constU(reg_pressure, u₁L))
592-
u₂ = isstaticloop(u₂loop) ? maybedemotesize(u₂, u₂loopsym === vloopsym ? cld(u₂L,W) : u₂L) : u₂
593-
return u₁L, u₂, unroll_cost(cost_vec, u₁L, u₂, u₁L, u₂L)
594+
if u₁loopsym !== vloopsym && u₁L ≤ 4
595+
if isstaticloop(u₂loop)
596+
u₂ = max(solve_unroll_constU(reg_pressure, u₁L), 1)
597+
u₂ = maybedemotesize(u₂, u₂loopsym === vloopsym ? cld(u₂L,W) : u₂L)
598+
else
599+
u₂ = clamp(solve_unroll_constU(reg_pressure, u₁L), 1, 8)
594600
end
595-
u₁Ltemp = u₁loopsym === vloopsym ? cld(u₁L, W) : u₁L
596-
maxu₁ = min(4maxu₁, u₁Ltemp)
601+
return u₁L, u₂, unroll_cost(cost_vec, u₁L, u₂, u₁L, u₂L)
602+
end
603+
u₁Ltemp = u₁loopsym === vloopsym ? cld(u₁L, W) : u₁L
604+
maxu₁ = min(4maxu₁, u₁Ltemp)
597605
end
598606
if u₁loopsym === vloopsym
599607
u₁Lf = u₁L / W

0 commit comments

Comments
 (0)