Skip to content

Commit 51ed926

Browse files
committed
More 4 to 2 spacing.
1 parent b81bbbf commit 51ed926

File tree

6 files changed

+102
-96
lines changed

6 files changed

+102
-96
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ Static = "0.2, 0.3"
3232
StrideArraysCore = "0.2"
3333
ThreadingUtilities = "0.4.5"
3434
UnPack = "1"
35-
VectorizationBase = "0.21.3"
35+
VectorizationBase = "0.21.4"
3636
julia = "1.5"

src/modeling/costs.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -388,8 +388,12 @@ end
388388
end
389389
@inline (ier::IfElseReduceToMirror)(a::VecUnroll, b::VecUnroll) = VecUnroll(VectorizationBase.fmap(ier, VectorizationBase.data(a), VectorizationBase.data(b)))
390390

391-
@inline (iec::IfElseCollapserMirror)(a) = getfield(VectorizationBase.ifelse_collapse_mirror(iec.f, a, iec.a), 1, false)
392-
@inline (iec::IfElseCollapserMirror)(a, ::StaticInt{1}) = getfield(VectorizationBase.ifelse_collapse_mirror(iec.f, a, iec.a), 1, false)
391+
# @inline (iec::IfElseCollapserMirror)(a) = getfield(VectorizationBase.ifelse_collapse_mirror(iec.f, a, iec.a), 1, false)
392+
# @inline (iec::IfElseCollapserMirror)(a, ::StaticInt{N}) where {N} = getfield(VectorizationBase.ifelse_collapse_mirror(iec.f, a, iec.a, StaticInt{N}()), 1, false)
393+
394+
@inline (iec::IfElseCollapserMirror)(a) = VectorizationBase.ifelse_collapse_mirror(iec.f, a, iec.a)
395+
@inline (iec::IfElseCollapserMirror)(a, ::StaticInt{N}) where {N} = VectorizationBase.ifelse_collapse_mirror(iec.f, a, iec.a, StaticInt{N}())
396+
393397
# @inline function (iec::IfElseCollapserMirror)(a, ::StaticInt{C}) where {C}
394398
# VectorizationBase.contract(IfElseOp(iec.f), a, StaticInt{C}())
395399
# end

src/modeling/determinestrategy.jl

Lines changed: 89 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -283,112 +283,113 @@ function unroll_no_reductions(ls, order, vloopsym)
283283
# (iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / rt) ) ))), unrolled
284284
end
285285
function determine_unroll_factor(
286-
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, vloopsym::Symbol
286+
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, vloopsym::Symbol
287287
)
288-
cacheunrolled!(ls, unrolled, Symbol(""), vloopsym)
289-
size_T = biggest_type_size(ls)
290-
W, Wshift = lsvecwidthshift(ls, vloopsym, size_T)
288+
cacheunrolled!(ls, unrolled, Symbol(""), vloopsym)
289+
size_T = biggest_type_size(ls)
290+
W, Wshift = lsvecwidthshift(ls, vloopsym, size_T)
291291

292-
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
293-
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
294-
# We also make sure register pressure is not too high.
295-
latency = 1.0
296-
# compute_recip_throughput_u = 0.0
297-
compute_recip_throughput = 0.0
298-
visited_nodes = fill(false, length(operations(ls)))
299-
load_recip_throughput = 0.0
300-
store_recip_throughput = 0.0
301-
for op ∈ operations(ls)
302-
if isreduction(op)
303-
rt, sl = depchain_cost!(ls, visited_nodes, op, unrolled, vloopsym, Wshift, size_T)
304-
if isouterreduction(ls, op) ≠ -1 || unrolled ∉ reduceddependencies(op)
305-
latency = max(sl, latency)
306-
end
307-
# if unrolled ∈ loopdependencies(op)
308-
# compute_recip_throughput_u += rt
309-
# else
310-
compute_recip_throughput += rt
311-
# end
312-
elseif isload(op)
313-
load_recip_throughput += first(cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T))
314-
elseif isstore(op)
315-
store_recip_throughput += first(cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T))
316-
end
292+
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
293+
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
294+
# We also make sure register pressure is not too high.
295+
latency = 1.0
296+
# compute_recip_throughput_u = 0.0
297+
compute_recip_throughput = 0.0
298+
visited_nodes = fill(false, length(operations(ls)))
299+
load_recip_throughput = 0.0
300+
store_recip_throughput = 0.0
301+
for op ∈ operations(ls)
302+
if isreduction(op)
303+
rt, sl = depchain_cost!(ls, visited_nodes, op, unrolled, vloopsym, Wshift, size_T)
304+
if isouterreduction(ls, op) ≠ -1 || unrolled ∉ reduceddependencies(op)
305+
latency = max(sl, latency)
306+
end
307+
# if unrolled ∈ loopdependencies(op)
308+
# compute_recip_throughput_u += rt
309+
# else
310+
compute_recip_throughput += rt
311+
# end
312+
elseif isload(op)
313+
load_recip_throughput += first(cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T))
314+
elseif isstore(op)
315+
store_recip_throughput += first(cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T))
317316
end
318-
recip_throughput = max(
319-
compute_recip_throughput,
320-
load_recip_throughput,
321-
store_recip_throughput
322-
)
323-
recip_throughput, latency
317+
end
318+
recip_throughput = max(
319+
compute_recip_throughput,
320+
load_recip_throughput,
321+
store_recip_throughput
322+
)
323+
# @show latency, recip_throughput
324+
recip_throughput, latency
324325
end
325326
function count_reductions(ls::LoopSet)
326-
num_reductions = 0
327-
for op ∈ operations(ls)
328-
if isreduction(op) & iscompute(op) && parentsnotreduction(op)
329-
num_reductions += 1
330-
end
327+
num_reductions = 0
328+
for op ∈ operations(ls)
329+
if isreduction(op) & iscompute(op) && parentsnotreduction(op)
330+
num_reductions += 1
331331
end
332-
num_reductions
332+
end
333+
num_reductions
333334
end
334335

335336
demote_unroll_factor(ls::LoopSet, UF, loop::Symbol) = demote_unroll_factor(ls, UF, getloop(ls, loop))
336337
function demote_unroll_factor(ls::LoopSet, UF, loop::Loop)
337-
W = ls.vector_width
338-
if !iszero(W) && isstaticloop(loop)
339-
UFW = maybedemotesize(UF*W, length(loop))
340-
UF = cld(UFW, W)
341-
end
342-
UF
338+
W = ls.vector_width
339+
if !iszero(W) && isstaticloop(loop)
340+
UFW = maybedemotesize(UF*W, length(loop))
341+
UF = cld(UFW, W)
342+
end
343+
UF
343344
end
344345

345346
function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::Symbol)
346-
num_reductions = count_reductions(ls)
347-
# The strategy is to use an unroll factor of 1, unless there appears to be loop carried dependencies (ie, num_reductions > 0)
348-
# The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
349-
loopindexesbit = ls.loopindexesbit
350-
if iszero(length(loopindexesbit)) || ((!loopindexesbit[getloopid(ls, vloopsym)]))
351-
if iszero(num_reductions)
352-
return unroll_no_reductions(ls, order, vloopsym)
353-
else
354-
return determine_unroll_factor(ls, order, vloopsym, num_reductions)
355-
end
356-
elseif iszero(num_reductions) # handle `BitArray` loops w/out reductions
357-
return 8 ÷ ls.vector_width, vloopsym
358-
else # handle `BitArray` loops with reductions
359-
rttemp, ltemp = determine_unroll_factor(ls, order, vloopsym, vloopsym)
360-
UF = min(8, VectorizationBase.nextpow2(max(1, round(Int, ltemp / (rttemp) ) )))
361-
UFfactor = 8 ÷ ls.vector_width
362-
cld(UF, UFfactor)*UFfactor, vloopsym
363-
end
347+
num_reductions = count_reductions(ls)
348+
# The strategy is to use an unroll factor of 1, unless there appears to be loop carried dependencies (ie, num_reductions > 0)
349+
# The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
350+
loopindexesbit = ls.loopindexesbit
351+
if iszero(length(loopindexesbit)) || ((!loopindexesbit[getloopid(ls, vloopsym)]))
352+
if iszero(num_reductions)
353+
return unroll_no_reductions(ls, order, vloopsym)
354+
else
355+
return determine_unroll_factor(ls, order, vloopsym, num_reductions)
356+
end
357+
elseif iszero(num_reductions) # handle `BitArray` loops w/out reductions
358+
return 8 ÷ ls.vector_width, vloopsym
359+
else # handle `BitArray` loops with reductions
360+
rttemp, ltemp = determine_unroll_factor(ls, order, vloopsym, vloopsym)
361+
UF = min(8, VectorizationBase.nextpow2(max(1, round(Int, ltemp / (rttemp) ) )))
362+
UFfactor = 8 ÷ ls.vector_width
363+
cld(UF, UFfactor)*UFfactor, vloopsym
364+
end
364365
end
365366
# function scale_unrolled()
366367
# end
367368
function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vloopsym::Symbol, num_reductions::Int)
368-
innermost_loop = last(order)
369-
rt = Inf; rtcomp = Inf; latency = Inf; best_unrolled = Symbol("")
370-
for unrolled ∈ order
371-
reject_reorder(ls, unrolled, false) && continue
372-
rttemp, ltemp = determine_unroll_factor(ls, order, unrolled, vloopsym)
373-
rtcomptemp = rttemp + (0.01 * ((vloopsym === unrolled) + (unrolled === innermost_loop) - latency))
374-
if rtcomptemp < rtcomp
375-
rt = rttemp
376-
rtcomp = rtcomptemp
377-
latency = ltemp
378-
best_unrolled = unrolled
379-
end
380-
end
381-
# min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
382-
lrtratio = latency / rt
383-
if lrtratio ≥ 7.0
384-
UF = 8
385-
else
386-
UF = VectorizationBase.nextpow2(round(Int, clamp(lrtratio, 1.0, 4.0)))
387-
end
388-
if best_unrolled === vloopsym
389-
UF = demote_unroll_factor(ls, UF, vloopsym)
369+
innermost_loop = last(order)
370+
rt = Inf; rtcomp = Inf; latency = Inf; best_unrolled = Symbol("")
371+
for unrolled ∈ order
372+
reject_reorder(ls, unrolled, false) && continue
373+
rttemp, ltemp = determine_unroll_factor(ls, order, unrolled, vloopsym)
374+
rtcomptemp = rttemp + (0.01 * ((vloopsym === unrolled) + (unrolled === innermost_loop) - latency))
375+
if rtcomptemp < rtcomp
376+
rt = rttemp
377+
rtcomp = rtcomptemp
378+
latency = ltemp
379+
best_unrolled = unrolled
390380
end
391-
UF, best_unrolled
381+
end
382+
# min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
383+
lrtratio = latency / rt
384+
if lrtratio ≥ 7.0
385+
UF = 8
386+
else
387+
UF = VectorizationBase.nextpow2(round(Int, clamp(lrtratio, 1.0, 4.0), RoundUp))
388+
end
389+
if best_unrolled === vloopsym
390+
UF = demote_unroll_factor(ls, UF, vloopsym)
391+
end
392+
UF, best_unrolled
392393
end
393394

394395
@inline function unroll_cost(X, u₁, u₂, u₁L, u₂L)

test/grouptests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const START_TIME = time()
1111
end
1212

1313
@time if LOOPVECTORIZATION_TEST == "all" || LOOPVECTORIZATION_TEST == "part2"
14+
using Aqua
1415
@time Aqua.test_all(LoopVectorization, ambiguities = VERSION ≥ v"1.6")
1516
# @test isempty(detect_unbound_args(LoopVectorization))
1617

test/ifelsemasks.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -478,10 +478,10 @@ T = Float32
478478
end
479479
minval, indmin
480480
end
481-
function findminturbo_u2(x)
481+
function findminturbo_u4(x)
482482
indmin = 0
483483
minval = typemax(eltype(x))
484-
@turbo unroll=2 for i ∈ eachindex(x)
484+
@turbo unroll=4 for i ∈ eachindex(x)
485485
newmin = x[i] < minval
486486
minval = newmin ? x[i] : minval
487487
indmin = newmin ? i : indmin
@@ -495,11 +495,11 @@ T = Float32
495495
if T <: Integer
496496
a = rand(-T(100):T(100), N); b = rand(-T(100):T(100), N);
497497
mv, mi = findminturbo(a)
498-
mv2, mi2 = findminturbo_u2(a)
498+
mv2, mi2 = findminturbo_u4(a)
499499
@test mv == a[mi] == minimum(a) == mv2 == a[mi2]
500500
else
501501
a = rand(T, N); b = rand(T, N);
502-
@test findmin(a) == findminturbo(a) == findminturbo_u2(a)
502+
@test findmin(a) == findminturbo(a) == findminturbo_u4(a)
503503
end;
504504
c1 = similar(a); c2 = similar(a);
505505

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
include("testsetup.jl")
22

3-
import InteractiveUtils, Aqua
3+
import InteractiveUtils
44

55
InteractiveUtils.versioninfo(stdout; verbose = true)
66

0 commit comments

Comments
 (0)