Skip to content

Commit 2472156

Browse files
committed
Handle offsets alongside +/- expressions in indices, and still unroll when we have load-into-stores; stridepenalty function needs work based on actual size of the arrays.
1 parent 15a5507 commit 2472156

File tree

4 files changed

+56
-45
lines changed

4 files changed

+56
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LoopVectorization"
22
uuid = "bdcacae8-1622-11e9-2a5c-532679323890"
33
authors = ["Chris Elrod <[email protected]>"]
4-
version = "0.8.9"
4+
version = "0.8.10"
55

66
[deps]
77
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"

src/determinestrategy.jl

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ function unitstride(ls::LoopSet, op::Operation, s::Symbol)
2626
fi = first(inds)
2727
if fi === Symbol("##DISCONTIGUOUSSUBARRAY##")
2828
return false
29-
elseif !first(li)
30-
# We must check if this
31-
parent = findparent(ls, fi)
32-
indexappearences(parent, s) > 1 && return false
29+
# elseif !first(li)
30+
# # We must check if this
31+
# parent = findparent(ls, fi)
32+
# indexappearences(parent, s) > 1 && return false
3333
end
3434
for i ∈ 2:length(inds)
3535
if li[i]
@@ -217,8 +217,7 @@ function unroll_no_reductions(ls, order, vectorized)
217217
elseif isstore(op)
218218
store_rt += first(cost(ls, op, vectorized, Wshift, size_T))
219219
end
220-
end
221-
# @show compute_rt, load_rt, store_rt
220+
end
222221
# heuristic guess
223222
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
224223
memory_rt = load_rt + store_rt
@@ -374,7 +373,7 @@ function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
374373
u₂low = max(u₂step, floor(Int, u₂float)) # must be at least 1
375374
u₁high = solve_unroll_constT(R, u₂low) + u₁step
376375
u₂high = solve_unroll_constU(R, u₁low) + u₂step
377-
maxunroll = REGISTER_COUNT == 32 ? 10 : 6
376+
maxunroll = REGISTER_COUNT == 32 ? (((X₂ > 0) & (X₃ > 0)) ? 10 : 8) : 6
378377
u₁low = min(u₁low, maxunroll)
379378
u₂low = min(u₂low, maxunroll)
380379
u₁high = min(u₁high, maxunroll)
@@ -534,18 +533,19 @@ function stride_penalty(ls::LoopSet, op::Operation, order::Vector{Symbol}, loopf
534533
penalty
535534
end
536535
function stride_penalty(ls::LoopSet, order::Vector{Symbol})
537-
stridepenalty = 0.0
536+
stridepenaltydict = Dict{Symbol,Vector{Float64}}()
538537
loopfreqs = Vector{Int}(undef, length(order))
539538
loopfreqs[1] = 1
540539
for i ∈ 2:length(order)
541540
loopfreqs[i] = loopfreqs[i-1] * length(getloop(ls, order[i]))
542541
end
543542
for op ∈ operations(ls)
544543
if accesses_memory(op)
545-
stridepenalty += stride_penalty(ls, op, order, loopfreqs)
544+
v = get!(() -> Float64[], stridepenaltydict, op.ref.ref.array)
545+
push!(v, stride_penalty(ls, op, order, loopfreqs))
546546
end
547547
end
548-
stridepenalty# * 1e-9
548+
sum(maximum, values(stridepenaltydict)) #* prod(length, ls.loops) / 1024^length(order)
549549
end
550550
function isoptranslation(ls::LoopSet, op::Operation, unrollsyms::UnrollSymbols)
551551
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
@@ -627,6 +627,7 @@ function load_elimination_cost_factor!(
627627
@unpack u₁loopsym, u₂loopsym, vectorized = unrollsyms
628628
if !iszero(first(isoptranslation(ls, op, unrollsyms)))
629629
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
630+
rto = rt
630631
rt *= iters
631632
# rt *= factor1; rp *= factor2;
632633
choose_to_inline[] = true
@@ -645,8 +646,8 @@ function load_elimination_cost_factor!(
645646
# end
646647
# # (0.25, REGISTER_COUNT == 32 ? 1.2 : 1.0)
647648
# (0.25, 1.0)
648-
cost_vec[1] += 0.1rt
649-
reg_pressure[1] += 0.51rp
649+
cost_vec[1] -= 0.1prod(length, ls.loops)
650+
reg_pressure[1] += 0.25rp
650651
cost_vec[2] += rt
651652
reg_pressure[2] += rp
652653
cost_vec[3] += rt
@@ -658,13 +659,13 @@ function load_elimination_cost_factor!(
658659
false
659660
end
660661
end
661-
# function loadintostore(ls::LoopSet, op::Operation)
662-
# isload(op) || return false # leads to bad behavior more than it helps
663-
# for opp ∈ operations(ls)
664-
# isstore(opp) && opp.ref == op.ref && return true
665-
# end
666-
# false
667-
# end
662+
function loadintostore(ls::LoopSet, op::Operation)
663+
isload(op) || return false # leads to bad behavior more than it helps
664+
for opp ∈ operations(ls)
665+
isstore(opp) && opp.ref == op.ref && return true
666+
end
667+
false
668+
end
668669
function store_load_deps!(deps::Vector{Symbol}, op::Operation, compref = op.ref)
669670
for opp ∈ parents(op)
670671
foreach(ld -> ((ld ∈ deps) || push!(deps, ld)), loopdependencies(opp))
@@ -799,7 +800,6 @@ function evaluate_cost_tile(
799800
all(ld -> ld ∈ nested_loop_syms, loopdependencies(op)) || continue
800801
rd = reduceddependencies(op)
801802
if hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)]))
802-
# @show rd, op itersym, nested_loop_syms @view(nested_loop_syms[1:end-length(rd)])
803803
return 0,0,Inf,false
804804
end
805805
if isstore(op)
@@ -860,7 +860,7 @@ function evaluate_cost_tile(
860860
costpenalty = (sum(reg_pressure) > REGISTER_COUNT) ? 2 : 1
861861
u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
862862
round_uᵢ = prefetch_good_idea ? (u₁v ? 1 : (u₂v ? 2 : 0)) : 0
863-
if irreducible_storecosts / sum(cost_vec) ≥ 0.25
863+
if (irreducible_storecosts / sum(cost_vec) ≥ 0.25) && !any(op -> loadintostore(ls, op), operations(ls))
864864
u₁, u₂ = (1, 1)
865865
ucost = unroll_cost(cost_vec, 1, 1, length(getloop(ls, u₁loopsym)), length(getloop(ls, u₂loopsym)))
866866
else

src/lowering.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -379,17 +379,17 @@ function lower_unrolled_dynamic(ls::LoopSet, us::UnrollSpecification, n::Int, in
379379
end
380380
elseif iszero(UFt)
381381
Expr( :block, q )
382-
# elseif !nisvectorized && !loopisstatic && UF ≥ 8
383-
# rem_uf = UF - 1
384-
# UF = rem_uf >> 1
385-
# UFt = rem_uf - UF
386-
# ust = nisunrolled ? UnrollSpecification(us, UFt, u₂) : UnrollSpecification(us, u₁, UFt)
387-
# newblock = lower_block(ls, ust, n, remmask, UFt)
388-
# # comparison = unrollremcomparison(ls, loop, UFt, n, nisvectorized, remfirst)
389-
# comparison = terminatecondition(ls, us, n, inclmask, UFt)
390-
# UFt = 1
391-
# UF += 1 - iseven(rem_uf)
392-
# Expr( :block, q, Expr(iseven(rem_uf) ? :while : :if, comparison, newblock), remblock )
382+
elseif !nisvectorized && !loopisstatic && UF ≥ 10
383+
rem_uf = UF - 1
384+
UF = rem_uf >> 1
385+
UFt = rem_uf - UF
386+
ust = nisunrolled ? UnrollSpecification(us, UFt, u₂) : UnrollSpecification(us, u₁, UFt)
387+
newblock = lower_block(ls, ust, n, remmask, UFt)
388+
# comparison = unrollremcomparison(ls, loop, UFt, n, nisvectorized, remfirst)
389+
comparison = terminatecondition(ls, us, n, inclmask, UFt)
390+
UFt = 1
391+
UF += 1 - iseven(rem_uf)
392+
Expr( :block, q, Expr(iseven(rem_uf) ? :while : :if, comparison, newblock), remblock )
393393
else
394394
# if (usorig.u₁ == us.u₁) && (usorig.u₂ == us.u₂) && !isstaticloop(loop) && !inclmask# && !ls.loadelimination[]
395395
# # Expr(:block, sl, assumeloopiteratesatleastonce(loop), Expr(:while, tc, body))

src/memory_ops_common.jl

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,35 +85,46 @@ function subset_vptr!(ls::LoopSet, vptr::Symbol, indnum::Int, ind, previndices,
8585
end
8686

8787
function addoffset!(ls, indices, offsets, loopedindex, loopdependencies, ind, offset)
88-
if typemin(Int8) ≤ offset ≤ typemax(Int8)
89-
push!(indices, ind);
90-
push!(offsets, offset % Int8)
91-
push!(loopedindex, true)
92-
push!(loopdependencies, ind)
93-
true
94-
else
95-
false
96-
end
88+
(typemin(Int8) ≤ offset ≤ typemax(Int8)) || return false
89+
push!(indices, ind);
90+
push!(offsets, offset % Int8)
91+
push!(loopedindex, true)
92+
push!(loopdependencies, ind)
93+
true
94+
end
95+
function addoffsetexpr!(ls, parents, indices, offsets, loopedindex, loopdependencies, reduceddeps, ind, offset, elementbytes)
96+
(typemin(Int8) ≤ offset ≤ typemax(Int8)) || return false
97+
parent = add_operation!(ls, gensym(:indexpr), ind, elementbytes, length(ls.loopsymbols))
98+
pushparent!(parents, loopdependencies, reduceddeps, parent)
99+
push!(indices, name(parent));
100+
push!(offsets, offset % Int8)
101+
push!(loopedindex, false)
102+
true
97103
end
98104

99105
function checkforoffset!(
100-
ls::LoopSet, indices::Vector{Symbol}, offsets::Vector{Int8}, loopedindex::Vector{Bool}, loopdependencies::Vector{Symbol}, ind::Expr
106+
ls::LoopSet, parents::Vector{Operation}, indices::Vector{Symbol}, offsets::Vector{Int8}, loopedindex::Vector{Bool},
107+
loopdependencies::Vector{Symbol}, reduceddeps::Vector{Symbol}, ind::Expr, elementbytes::Int
101108
)
102109
ind.head === :call || return false
103110
f = first(ind.args)
104111
(((f === :+) || (f === :-)) && (length(ind.args) == 3)) || return false
105112
factor = f === :+ ? 1 : -1
106113
arg1 = ind.args[2]
107114
arg2 = ind.args[3]
108-
if arg1 isa Integer && isone(factor)
115+
if arg1 isa Integer# && isone(factor)
109116
if arg2 isa Symbol && arg2 ∈ ls.loopsymbols
110117
addoffset!(ls, indices, offsets, loopedindex, loopdependencies, arg2, arg1 * factor)
118+
elseif arg2 isa Expr
119+
addoffsetexpr!(ls, parents, indices, offsets, loopedindex, loopdependencies, reduceddeps, arg2, arg1 * factor, elementbytes)
111120
else
112121
false
113122
end
114123
elseif arg2 isa Integer
115124
if arg1 isa Symbol && arg1 ∈ ls.loopsymbols
116125
addoffset!(ls, indices, offsets, loopedindex, loopdependencies, arg1, arg2 * factor)
126+
elseif arg1 isa Expr
127+
addoffsetexpr!(ls, parents, indices, offsets, loopedindex, loopdependencies, reduceddeps, arg1, arg2 * factor, elementbytes)
117128
else
118129
false
119130
end
@@ -158,7 +169,7 @@ function array_reference_meta!(ls::LoopSet, array::Symbol, rawindices, elementby
158169
length(indices) == 0 && push!(indices, DISCONTIGUOUS)
159170
elseif ind isa Expr
160171
#FIXME: position (in loopnest) wont be length(ls.loopsymbols) in general
161-
if !checkforoffset!(ls, indices, offsets, loopedindex, loopdependencies, ind)
172+
if !checkforoffset!(ls, parents, indices, offsets, loopedindex, loopdependencies, reduceddeps, ind, elementbytes)
162173
parent = add_operation!(ls, gensym(:indexpr), ind, elementbytes, length(ls.loopsymbols))
163174
pushparent!(parents, loopdependencies, reduceddeps, parent)
164175
push!(indices, name(parent));

0 commit comments

Comments
 (0)