Skip to content

Commit d19bffb

Browse files
committed
static reduction cleanup
1 parent 415167c commit d19bffb

File tree

3 files changed

+82
-64
lines changed

3 files changed

+82
-64
lines changed

src/codegen/lower_compute.jl

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ end
234234
push!(q.args, :($gf(vargs, $k, false)))
235235
end
236236
return Expr(:block, Expr(:meta, :inline), q)
237-
# return Expr(:block, Expr(:meta, :inline), :(@show($q)))
238237
end
239238
if Sreduced
240239
M = N
@@ -273,7 +272,6 @@ end
273272
push!(t.args, :($gf(dd, $m, false)))
274273
end
275274
push!(q.args, :(VecUnroll($t)))
276-
# push!(q.args, :(@show(VecUnroll($t))))
277275
q
278276
end
279277

@@ -366,20 +364,11 @@ function getu₁forreduct(ls::LoopSet, op::Operation, u₁::Int)
366364
end
367365
isidentityop(op::Operation) = iscompute(op) && (instruction(op).instr === :identity) && (length(parents(op)) == 1)
368366
function reduce_parent!(q::Expr, ls::LoopSet, op::Operation, opp::Operation, parent::Symbol)
369-
# if instruction(op).instr === :log_fast
370-
# @show op opp isvectorized(op) isvectorized(opp) dependent_outer_reducts(ls, op)
371-
# end
372367
isvectorized(op) && return parent
373-
# if dependent_outer_reducts(ls, op)
374-
375-
# return parent
376-
# end
377-
# @show op opp isvectorized(opp)
378368
if isvectorized(opp)
379369
oppt = opp
380370
elseif isidentityop(opp)
381371
oppt = parents(opp)[1]
382-
# @show oppt
383372
isvectorized(oppt) || return parent
384373
else
385374
return parent
@@ -510,13 +499,6 @@ function lower_compute!(
510499
modsuffix = suffix % ls.ureduct
511500
Symbol(mangledvar(op), modsuffix)
512501
end
513-
# @show op, u₁unrolledsym, u₂unrolledsym
514-
# end
515-
# dopartialmap = u₁ > 1
516-
517-
# Symbol(mvar, modsuffix)
518-
# elseif u₁unrolledsym
519-
# Symbol(mvar, u)
520502
elseif u₁unrolledsym
521503
if isreduct #(isanouterreduction(ls, op))
522504
# isouterreduct = true
@@ -532,7 +514,6 @@ function lower_compute!(
532514
end
533515
selfopname = varsym
534516
selfdep = 0
535-
# @show u₁unrolledsym
536517
for n ∈ 1:nparents
537518
opp = parents_op[n]
538519
if isloopvalue(opp)
@@ -544,7 +525,6 @@ function lower_compute!(
544525
(parents_u₁syms[n] != u₁unrolledsym) || (parents_u₂syms[n] != u₂unrolledsym)
545526

546527
selfopname, uₚ = parent_op_name!(q, ls, parents_op, n, modsuffix, suffix_, parents_u₁syms, parents_u₂syms, u₁, u₂max, u₂unrolledsym, op, tiledouterreduction)
547-
# @show selfopname, uₚ, tiledouterreduction opp op
548528
push!(instrcall.args, selfopname)
549529
else
550530
push!(instrcall.args, varsym)
@@ -574,7 +554,6 @@ function lower_compute!(
574554
end
575555
end
576556
selfdepreduce = ifelse(((!u₁unrolledsym) & isu₁unrolled(op)) & (u₁ > 1), selfdep, 0)
577-
# @show selfdepreduce, selfdep, maskreduct, op
578557
if maskreduct
579558
ifelsefunc = if us.u₁ == 1
580559
:ifelse # don't need to be fancy

src/codegen/lowering.jl

Lines changed: 60 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,19 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
292292
end
293293
end
294294
q = if unsigned(Ureduct) < unsigned(UF) # unsigned(-1) == typemax(UInt);
295-
add_cleanup = true
295+
add_cleanup = !loopisstatic# true
296296
if isone(Ureduct)
297297
UF_cleanup = 1
298298
if nisvectorized
299299
blockhead = :while
300300
else
301301
blockhead = if UF == 2
302-
if loopisstatic
303-
add_cleanup = UFt == 1
304-
:block
305-
else
306-
:if
307-
end
302+
if loopisstatic
303+
# add_cleanup = UFt == 1
304+
:block
305+
else
306+
:if
307+
end
308308
else
309309
:while
310310
end
@@ -319,6 +319,12 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
319319
end
320320
_q = if dynamicbounded
321321
initialize_outer_reductions!(q, ls, Ureduct); q
322+
elseif loopisstatic
323+
if length(loop) < UF*W
324+
Expr(:block)
325+
else
326+
Expr(:block, add_upper_outer_reductions(ls, q, Ureduct, UF, loop, nisvectorized))
327+
end
322328
else
323329
Expr(:block, add_upper_outer_reductions(ls, q, Ureduct, UF, loop, nisvectorized))
324330
end
@@ -587,6 +593,12 @@ function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::
587593
end
588594
push!(ifq.args, t)
589595
ifqfull = Expr(:let, ifqlet, ifq)
596+
if isstaticloop(unrolledloop)
597+
W = Core.ifelse(reductisvectorized, ls.vector_width, 1)
598+
if Uhigh*W*gethint(step(unrolledloop)) ≤ length(unrolledloop)
599+
return Expr(:(=), mvartl, ifqfull)
600+
end
601+
end
590602
ncomparison = if reductisvectorized
591603
add_upper_comp_check(unrolledloop, mulexpr(VECTORWIDTHSYMBOL, Uhigh, step(unrolledloop)))
592604
elseif isknown(step(unrolledloop))
@@ -848,44 +860,49 @@ end
848860
# cld(u₂, cld(u₂, unroll))
849861
# end
850862
function calc_Ureduct!(ls::LoopSet, us::UnrollSpecification)
851-
@unpack u₁loopnum, u₁, u₂, vloopnum = us
852-
ur = if iszero(length(ls.outer_reductions))
853-
-1
854-
elseif u₂ == -1
855-
if u₁loopnum == num_loops(ls)
856-
loopisstatic = isstaticloop(getloop(ls, u₁loopnum))
857-
loopisstatic &= ((vloopnum != u₁loopnum) | (!iszero(ls.vector_width)))
858-
# loopisstatic ? u₁ : min(u₁, 4) # much worse than the other two options, don't use this one
859-
if Sys.CPU_NAME === "znver1"
860-
loopisstatic ? u₁ : 1
861-
else
862-
loopisstatic ? u₁ : (u₁ ≥ 4 ? 2 : 1)
863-
end
864-
else
865-
-1
866-
end
863+
@unpack u₁loopnum, u₁, u₂, vloopnum = us
864+
ur = if iszero(length(ls.outer_reductions))
865+
-1
866+
elseif u₂ == -1
867+
if u₁loopnum == num_loops(ls)
868+
u₁loop = getloop(ls, u₁loopnum)
869+
loopisstatic = isstaticloop(u₁loop)
870+
loopisstatic &= ((vloopnum != u₁loopnum) | (!iszero(ls.vector_width)))
871+
# loopisstatic ? u₁ : min(u₁, 4) # much worse than the other two options, don't use this one
872+
if loopisstatic
873+
W = Core.ifelse(vloopnum == u₁loopnum, ls.vector_width, 1)
874+
UFt = cld(length(u₁loop) % (W*u₁), W)
875+
Core.ifelse(UFt == 0, u₁, UFt)
876+
# rem = length(u₁loop) -
877+
# max(1, cld(rem, u₁))
878+
else
879+
Core.ifelse(Sys.CPU_NAME === "znver1", 1, Core.ifelse(u₁ ≥ 4, 2, 1))
880+
end
867881
else
868-
u₁ui = u₂ui = -1
869-
u₁loopsym = getloop(ls, u₁loopnum).itersymbol
870-
u₂loopsym = getloop(ls, us.u₂loopnum).itersymbol
871-
vloopsym = getloop(ls, vloopnum).itersymbol
872-
for or ∈ ls.outer_reductions
873-
op = ls.operations[or]
874-
u₁u, u₂u = isunrolled_sym(op, u₁loopsym, u₂loopsym, vloopsym, us)
875-
if u₁ui == -1
876-
u₁ui = Int(u₁u)
877-
u₂ui = Int(u₁u)
878-
elseif !((u₁ui == Int(u₁u)) & (u₂ui == Int(u₁u)))
879-
throw(ArgumentError("Doesn't currenly handle differently unrolled reductions yet, please file an issue with an example."))
880-
end
881-
end
882-
if u₁ui % Bool
883-
u₁
884-
else
885-
u₂
886-
end
882+
-1
887883
end
888-
ls.ureduct = ur
884+
else
885+
u₁ui = u₂ui = -1
886+
u₁loopsym = getloop(ls, u₁loopnum).itersymbol
887+
u₂loopsym = getloop(ls, us.u₂loopnum).itersymbol
888+
vloopsym = getloop(ls, vloopnum).itersymbol
889+
for or ∈ ls.outer_reductions
890+
op = ls.operations[or]
891+
u₁u, u₂u = isunrolled_sym(op, u₁loopsym, u₂loopsym, vloopsym, us)
892+
if u₁ui == -1
893+
u₁ui = Int(u₁u)
894+
u₂ui = Int(u₁u)
895+
elseif !((u₁ui == Int(u₁u)) & (u₂ui == Int(u₁u)))
896+
throw(ArgumentError("Doesn't currenly handle differently unrolled reductions yet, please file an issue with an example."))
897+
end
898+
end
899+
if u₁ui % Bool
900+
u₁
901+
else
902+
u₂
903+
end
904+
end
905+
ls.ureduct = ur
889906
end
890907
ureduct(ls::LoopSet) = ls.ureduct
891908
function lower_unrollspec(ls::LoopSet)

test/staticsize.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,21 @@ function update_turbo!(B⁻¹yₖ, B⁻¹, yₖ, sₖᵀyₖ⁻¹)
7878
yₖᵀB⁻¹yₖ
7979
end
8080

81+
function maxabs(x)
82+
s = -Inf
83+
@turbo for i ∈ eachindex(x)
84+
s = max(s, abs(x[i]))
85+
end
86+
s
87+
end
88+
function sum_turbo(x)
89+
s = zero(eltype(x))
90+
@turbo for i ∈ eachindex(x)
91+
s += x[i]
92+
end
93+
s
94+
end
95+
8196
@testset "Statically Sized Arrays" begin
8297
@show @__LINE__
8398
for n1 ∈ 1:MAXTESTSIZE, n3 ∈ 1:MAXTESTSIZE
@@ -97,6 +112,13 @@ end
97112
@test update_turbo!(By0, output1, y, 0.124) ≈ update!(By1, output1, y, 0.124)
98113
@test By0 ≈ By1
99114
end
115+
116+
end
117+
for i in 1:65
118+
x = StrideArray(undef, StaticInt(i))
119+
x .= randn.();
120+
@test maxabs(x) == maximum(abs, x)
121+
@test sum_turbo(x) ≈ sum(x)
100122
end
101123
end
102124

0 commit comments

Comments
 (0)