Skip to content

Commit 415167c

Browse files
committed
Demote max single unroll static loop unroll
1 parent 4d6bc23 commit 415167c

File tree

2 files changed

+74
-62
lines changed

2 files changed

+74
-62
lines changed

src/modeling/determinestrategy.jl

Lines changed: 68 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,9 @@ function unroll_no_reductions(ls, order, vloopsym)
269269
# isstore(op) && isu₁unrolled(op)
270270
# end
271271
# end
272-
if unrolled === vloopsym
273-
u = demote_unroll_factor(ls, u, vloopsym)
274-
end
272+
# if unrolled === vloopsym
273+
# u = demote_unroll_factor(ls, u, vloopsym)
274+
# end
275275
remaining_reg = max(8, (reg_count(ls) - round(Int,rpc))) # spilling a few consts isn't so bad
276276
if compute_l ≥ 4compute_rt ≥ 4rpp
277277
# motivation for skipping division by loads here: https://github.com/microhh/stencilbuilder/blob/master/julia/stencil_julia_4th.jl
@@ -285,7 +285,7 @@ function unroll_no_reductions(ls, order, vloopsym)
285285
else
286286
reg_constraint = max(1, remaining_reg ÷ max(1,round(Int,rpp)))
287287
end
288-
clamp(u, 1, reg_constraint), unrolled
288+
maybe_demote_unroll(ls, clamp(u, 1, reg_constraint), unrolled, vloopsym), unrolled
289289
# rt = max(compute_rt, load_rt + store_rt)
290290
# # (iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
291291
# (iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / rt) ) ))), unrolled
@@ -369,6 +369,8 @@ function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::S
369369
UF = min(8, VectorizationBase.nextpow2(max(1, round(Int, ltemp / (rttemp) ) )))
370370
UFfactor = 8 ÷ ls.vector_width
371371
cld(UF, UFfactor)*UFfactor, vloopsym
372+
# UF2 = cld(UF, UFfactor)*UFfactor, vloopsym
373+
# maybe_demote_unroll(ls, UF2, vloopsym, vloopsym), vloopsym
372374
end
373375
end
374376
# function scale_unrolled()
@@ -394,11 +396,17 @@ function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::S
394396
else
395397
UF = VectorizationBase.nextpow2(round(Int, clamp(lrtratio, 1.0, 4.0), RoundUp))
396398
end
397-
if best_unrolled === vloopsym
398-
UF = demote_unroll_factor(ls, UF, vloopsym)
399-
end
399+
UF = maybe_demote_unroll(ls, UF, best_unrolled, vloopsym)
400400
UF, best_unrolled
401401
end
402+
function maybe_demote_unroll(ls::LoopSet, UF::Int, unrollsym::Symbol, vloopsym::Symbol)::Int
403+
if unrollsym === vloopsym
404+
return demote_unroll_factor(ls, UF, vloopsym)
405+
else
406+
ul = getloop(ls, unrollsym)
407+
isstaticloop(ul) ? min(length(ul), UF) : UF
408+
end
409+
end
402410

403411
@inline function unroll_cost(X, u₁, u₂, u₁L, u₂L)
404412
u₂factor = (num_iterations(u₂L, u₂)/u₂L)
@@ -553,64 +561,64 @@ function solve_unroll(
553561
end
554562

555563
function solve_unroll(
556-
u₁loopsym::Symbol, u₂loopsym::Symbol,
557-
cost_vec::AbstractVector{Float64},
558-
reg_pressure::AbstractVector{Float64},
559-
W::Int, vloopsym::Symbol,
560-
u₁loop::Loop, u₂loop::Loop,
561-
u₁step::Int, u₂step::Int,
562-
atleast31registers::Bool
564+
u₁loopsym::Symbol, u₂loopsym::Symbol,
565+
cost_vec::AbstractVector{Float64},
566+
reg_pressure::AbstractVector{Float64},
567+
W::Int, vloopsym::Symbol,
568+
u₁loop::Loop, u₂loop::Loop,
569+
u₁step::Int, u₂step::Int,
570+
atleast31registers::Bool
563571
)
564-
maxu₂base = maxu₁base = atleast31registers ? 10 : 6#8
565-
maxu₂ = maxu₂base#8
566-
maxu₁ = maxu₁base#8
567-
u₁L = length(u₁loop)
568-
u₂L = length(u₂loop)
569-
if isstaticloop(u₂loop)
570-
if u₂loopsym !== vloopsym && u₂L ≤ 4
571-
if isstaticloop(u₁loop)
572-
u₁ = max(solve_unroll_constT(reg_pressure, u₂L), 1)
573-
u₁ = maybedemotesize(u₁, u₁loopsym === vloopsym ? cld(u₁L,W) : u₁L)
574-
else
575-
u₁ = clamp(solve_unroll_constT(reg_pressure, u₂L), 1, 8)
576-
end
577-
return u₁, u₂L, unroll_cost(cost_vec, u₁, u₂L, u₁L, u₂L)
572+
maxu₂base = maxu₁base = atleast31registers ? 10 : 6#8
573+
maxu₂ = maxu₂base#8
574+
maxu₁ = maxu₁base#8
575+
u₁L = length(u₁loop)
576+
u₂L = length(u₂loop)
577+
if isstaticloop(u₂loop)
578+
if u₂loopsym !== vloopsym && u₂L ≤ 4
579+
if isstaticloop(u₁loop)
580+
u₁ = max(solve_unroll_constT(reg_pressure, u₂L), 1)
581+
u₁ = maybedemotesize(u₁, u₁loopsym === vloopsym ? cld(u₁L,W) : u₁L)
582+
else
583+
u₁ = clamp(solve_unroll_constT(reg_pressure, u₂L), 1, 8)
578584
end
579-
u₂Ltemp = u₂loopsym === vloopsym ? cld(u₂L, W) : u₂L
580-
maxu₂ = min(4maxu₂, u₂Ltemp)
585+
return u₁, u₂L, unroll_cost(cost_vec, u₁, u₂L, u₁L, u₂L)
581586
end
582-
if isstaticloop(u₁loop)
583-
if u₁loopsym !== vloopsym && u₁L ≤ 4
584-
if isstaticloop(u₂loop)
585-
u₂ = max(solve_unroll_constU(reg_pressure, u₁L), 1)
586-
u₂ = maybedemotesize(u₂, u₂loopsym === vloopsym ? cld(u₂L,W) : u₂L)
587-
else
588-
u₂ = clamp(solve_unroll_constU(reg_pressure, u₁L), 1, 8)
589-
end
590-
return u₁L, u₂, unroll_cost(cost_vec, u₁L, u₂, u₁L, u₂L)
587+
u₂Ltemp = u₂loopsym === vloopsym ? cld(u₂L, W) : u₂L
588+
maxu₂ = min(4maxu₂, u₂Ltemp)
589+
end
590+
if isstaticloop(u₁loop)
591+
if u₁loopsym !== vloopsym && u₁L ≤ 4
592+
if isstaticloop(u₂loop)
593+
u₂ = max(solve_unroll_constU(reg_pressure, u₁L), 1)
594+
u₂ = maybedemotesize(u₂, u₂loopsym === vloopsym ? cld(u₂L,W) : u₂L)
595+
else
596+
u₂ = clamp(solve_unroll_constU(reg_pressure, u₁L), 1, 8)
591597
end
592-
u₁Ltemp = u₁loopsym === vloopsym ? cld(u₁L, W) : u₁L
593-
maxu₁ = min(4maxu₁, u₁Ltemp)
598+
return u₁L, u₂, unroll_cost(cost_vec, u₁L, u₂, u₁L, u₂L)
594599
end
595-
if u₁loopsym === vloopsym
596-
u₁Lf = u₁L / W
597-
else
598-
u₁Lf = Float64(u₁L)
599-
end
600-
if u₂loopsym === vloopsym
601-
u₂Lf = u₂L / W
602-
else
603-
u₂Lf = Float64(u₂L)
604-
end
605-
u₁, u₂, cost = solve_unroll(cost_vec, reg_pressure, maxu₁, maxu₂, u₁Lf, u₂Lf, u₁step, u₂step, atleast31registers)
606-
# heuristic to more evenly divide small numbers of iterations
607-
if isstaticloop(u₂loop)
608-
u₂ = maybedemotesize(u₂, length(u₂loop), u₁, u₁loop, maxu₂base)
609-
end
610-
if isstaticloop(u₁loop)
611-
u₁ = maybedemotesize(u₁, length(u₁loop), u₂, u₂loop, maxu₁base)
612-
end
613-
u₁, u₂, cost
600+
u₁Ltemp = u₁loopsym === vloopsym ? cld(u₁L, W) : u₁L
601+
maxu₁ = min(4maxu₁, u₁Ltemp)
602+
end
603+
if u₁loopsym === vloopsym
604+
u₁Lf = u₁L / W
605+
else
606+
u₁Lf = Float64(u₁L)
607+
end
608+
if u₂loopsym === vloopsym
609+
u₂Lf = u₂L / W
610+
else
611+
u₂Lf = Float64(u₂L)
612+
end
613+
u₁, u₂, cost = solve_unroll(cost_vec, reg_pressure, maxu₁, maxu₂, u₁Lf, u₂Lf, u₁step, u₂step, atleast31registers)
614+
# heuristic to more evenly divide small numbers of iterations
615+
if isstaticloop(u₂loop)
616+
u₂ = maybedemotesize(u₂, length(u₂loop), u₁, u₁loop, maxu₂base)
617+
end
618+
if isstaticloop(u₁loop)
619+
u₁ = maybedemotesize(u₁, length(u₁loop), u₂, u₂loop, maxu₁base)
620+
end
621+
u₁, u₂, cost
614622
end
615623

616624
function set_upstream_family!(adal::Vector{T}, op::Operation, val::T) where {T}

test/staticsize.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ end
8181
@testset "Statically Sized Arrays" begin
8282
@show @__LINE__
8383
for n1 ∈ 1:MAXTESTSIZE, n3 ∈ 1:MAXTESTSIZE
84+
# @show n1, n3
8485
output1 = StrideArray(undef, StaticInt(n1), StaticInt(n3))
8586
output2 = StrideArray(undef, StaticInt(n1), StaticInt(n3))
8687
output3 = StrideArray(undef, StaticInt(n1), StaticInt(n3))
@@ -91,8 +92,11 @@ end
9192
y = StrideArray(undef, StaticInt(n1)); y .= rand.();
9293
By0 = StrideArray(undef, StaticInt(n3))
9394
By1 = StrideArray(undef, StaticInt(n3))
94-
@test update_turbo!(By0, output1, y, 0.124) ≈ update!(By1, output1, y, 0.124)
95-
@test By0 ≈ By1
95+
GC.@preserve By0 By1 output1 y begin
96+
# @show update!(Vector{Float64}(undef, n3), output1, y, 0.124)
97+
@test update_turbo!(By0, output1, y, 0.124) ≈ update!(By1, output1, y, 0.124)
98+
@test By0 ≈ By1
99+
end
96100
end
97101
end
98102

0 commit comments

Comments
 (0)