Skip to content

Commit 272c856

Browse files
committed
A fix for lowering when an outerreduction isn't unrolled, and changes to determinestrategy.jl that should (hopefully) improve performance more often than not.
1 parent 886887a commit 272c856

File tree

6 files changed

+75
-52
lines changed

6 files changed

+75
-52
lines changed

src/condense_loopset.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,17 +157,23 @@ end
157157
function loopset_return_value(ls::LoopSet, ::Val{extract}) where {extract}
158158
@assert !iszero(length(ls.outer_reductions))
159159
if isone(length(ls.outer_reductions))
160+
op = getop(ls, ls.outer_reductions[1])
160161
if extract
161-
Expr(:call, :extract_data, Symbol(mangledvar(getop(ls, ls.outer_reductions[1])), 0))
162+
if (isu₁unrolled(op) | isu₂unrolled(op))
163+
Expr(:call, :extract_data, Symbol(mangledvar(op), 0))
164+
else
165+
Expr(:call, :extract_data, mangledvar(op))
166+
end
162167
else
163-
Symbol(mangledvar(getop(ls, ls.outer_reductions[1])), 0)
168+
Symbol(mangledvar(op), 0)
164169
end
165170
else#if length(ls.outer_reductions) > 1
166171
ret = Expr(:tuple)
167172
ops = operations(ls)
168173
for or ls.outer_reductions
174+
op = ops[or]
169175
if extract
170-
push!(ret.args, Expr(:call, :extract_data, Symbol(mangledvar(ops[or]), 0)))
176+
push!(ret.args, Expr(:call, :extract_data, Symbol(mangledvar(op), 0)))
171177
else
172178
push!(ret.args, Symbol(mangledvar(ops[or]), 0))
173179
end

src/determinestrategy.jl

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ function cost(ls::LoopSet, op::Operation, vectorized::Symbol, Wshift::Int, size_
6767
srt, sl, srp = opisvectorized ? vector_cost(instr, Wshift, size_T) : scalar_cost(instr)
6868
if accesses_memory(op)
6969
# either vbroadcast/reductionstore, vmov(a/u)pd, or gather/scatter
70-
# @show instr, vectorized, loopdependencies(op), unitstride(op, vectorized)
7170
if opisvectorized
7271
if !unitstride(ls, op, vectorized)# || !isdense(op) # need gather/scatter
7372
r = (1 << Wshift)
@@ -131,12 +130,11 @@ function evaluate_cost_unroll(
131130
rd = reduceddependencies(op)
132131
hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)])) && return Inf
133132
included_vars[id] = true
134-
# @show op first(cost(op, vectorized, Wshift, size_T)), iter
135133
total_cost += iter * first(cost(ls, op, vectorized, Wshift, size_T))
136134
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
137135
end
138136
end
139-
total_cost + stride_penalty(ls, order)
137+
total_cost + stride_penalty(ls, order) - 1.0 # -1.0 to place finger on scale in its favor
140138
end
141139

142140
# only covers vectorized ops; everything else considered lifted?
@@ -163,13 +161,16 @@ function parentsnotreduction(op::Operation)
163161
end
164162
return true
165163
end
166-
function roundpow2(i::Integer)
167-
u = VectorizationBase.nextpow2(i)
168-
l = u >>> 1
169-
ud = u - i
170-
ld = i - l
171-
ud > ld ? l : u
172-
end
164+
# function roundpow2(i::Integer)
165+
# u = VectorizationBase.nextpow2(i)
166+
# l = u >>> 1
167+
# ud = u - i
168+
# ld = i - l
169+
# ud > ld ? l : u
170+
# end
171+
# function roundpow2(x::Float64)
172+
# 1 << round(Int, log2(x))
173+
# end
173174
function unroll_no_reductions(ls, order, vectorized)
174175
size_T = biggest_type_size(ls)
175176
W, Wshift = VectorizationBase.pick_vector_width_shift(length(ls, vectorized), size_T)::Tuple{Int,Int}
@@ -190,10 +191,10 @@ function unroll_no_reductions(ls, order, vectorized)
190191
end
191192
end
192193
# heuristic guess
193-
# @show compute_rt, load_rt
194194
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
195195
rt = max(compute_rt, load_rt)
196-
(iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
196+
# (iszero(rt) ? 4 : max(1, roundpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
197+
(iszero(rt) ? 4 : max(1, VectorizationBase.nextpow2( min( 4, round(Int, 16 / rt) ) ))), unrolled
197198
end
198199
function determine_unroll_factor(
199200
ls::LoopSet, order::Vector{Symbol}, unrolled::Symbol, vectorized::Symbol
@@ -204,17 +205,24 @@ function determine_unroll_factor(
204205
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
205206
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
206207
# We also make sure register pressure is not too high.
207-
latency = 0
208+
latency = 1
209+
# compute_recip_throughput_u = 0.0
208210
compute_recip_throughput = 0.0
209211
visited_nodes = fill(false, length(operations(ls)))
210212
load_recip_throughput = 0.0
211213
store_recip_throughput = 0.0
212214
for op operations(ls)
213-
dependson(op, unrolled) || continue
215+
# dependson(op, unrolled) || continue
214216
if isreduction(op)
215217
rt, sl = depchain_cost!(ls, visited_nodes, op, vectorized, Wshift, size_T)
216-
latency = max(sl, latency)
218+
if isouterreduction(op) != -1 || unrolled reduceddependencies(op)
219+
latency = max(sl, latency)
220+
end
221+
# if unrolled ∈ loopdependencies(op)
222+
# compute_recip_throughput_u += rt
223+
# else
217224
compute_recip_throughput += rt
225+
# end
218226
elseif isload(op)
219227
load_recip_throughput += first(cost(ls, op, vectorized, Wshift, size_T))
220228
elseif isstore(op)
@@ -247,19 +255,20 @@ function determine_unroll_factor(ls::LoopSet, order::Vector{Symbol}, vectorized:
247255
# if more than 1 loop, there is some cost. Picking 2 here as a heuristic.
248256
return unroll_no_reductions(ls, order, vectorized)
249257
end
250-
258+
innermost_loop = last(order)
251259
rt = Inf; rtcomp = Inf; latency = Inf; best_unrolled = Symbol("")
252260
for unrolled order
253261
rttemp, ltemp = determine_unroll_factor(ls, order, unrolled, vectorized)
254-
rtcomptemp = rttemp + (0.01 * (vectorized === unrolled))
262+
rtcomptemp = rttemp + (0.01 * ((vectorized === unrolled) + (unrolled === innermost_loop) - latency))
255263
if rtcomptemp < rtcomp
256264
rt = rttemp
257265
rtcomp = rtcomptemp
258266
latency = ltemp
259267
best_unrolled = unrolled
260268
end
261269
end
262-
min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
270+
# min(8, roundpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
271+
min(8, VectorizationBase.nextpow2(max(1, round(Int, latency / (rt * num_reductions) ) ))), best_unrolled
263272
end
264273

265274
function unroll_cost(X, u₁, u₂, u₁L, u₂L)
@@ -273,7 +282,6 @@ end
273282
# u₁b = 1; u₂b = 1
274283
# for u₁ ∈ 1:4, u₂ ∈ 1:4
275284
# c = unroll_cost(X, u₁, u₂, u₁L, u₂L)
276-
# @show u₁, u₂, c
277285
# if cb > c
278286
# cb = c
279287
# u₁b = u₁; u₂b = u₂
@@ -679,7 +687,6 @@ function evaluate_cost_tile(
679687
# cost_mat[2] / ( u₂loopsym)
680688
# cost_mat[3] / ( unrolled)
681689
# cost_mat[4]
682-
# @show order
683690
cost_vec = cost_vec_buf(ls)
684691
reg_pressure = reg_pres_buf(ls)
685692
# @inbounds reg_pressure[2] = 1
@@ -708,8 +715,6 @@ function evaluate_cost_tile(
708715
included_vars[id] && continue
709716
# it must also be a subset of defined symbols
710717
all(ld -> ld nested_loop_syms, loopdependencies(op)) || continue
711-
# # @show nested_loop_syms
712-
# # @show reduceddependencies(op)
713718
rd = reduceddependencies(op)
714719
hasintersection(rd, @view(nested_loop_syms[1:end-length(rd)])) && return 0,0,Inf,false
715720
included_vars[id] = true
@@ -720,7 +725,6 @@ function evaluate_cost_tile(
720725
# reduced_by_unrolling[2,id] = (u₂reached | depends_on_u₁) & !depends_on_u₂
721726
reduced_by_unrolling[1,id] = (u₁reached) & !depends_on_u₁
722727
reduced_by_unrolling[2,id] = (u₂reached) & !depends_on_u₂
723-
# @show op iter, unrolledu₂loopsym[:,id]
724728
iters[id] = iter
725729
innerloop loopdependencies(op) && set_upstream_family!(descendentsininnerloop, op, true)
726730
end
@@ -730,7 +734,6 @@ function evaluate_cost_tile(
730734
opisininnerloop = descendentsininnerloop[id]
731735

732736
u₁reduces, u₂reduces = reduced_by_unrolling[1,id], reduced_by_unrolling[2,id]
733-
# @show op, u₁reduces, u₂reduces
734737
if isload(op)
735738
if add_constant_offset_load_elmination_cost!(cost_vec, reg_pressure, choose_to_inline, ls, op, iters[id], unrollsyms, u₁reduces, u₂reduces, Wshift, size_T, opisininnerloop)
736739
continue
@@ -743,34 +746,26 @@ function evaluate_cost_tile(
743746
rt += 0.5VectorizationBase.REGISTER_SIZE / VectorizationBase.CACHELINE_SIZE
744747
prefetch_good_idea = true
745748
end
746-
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
747749
rp = (opisininnerloop && !(loadintostore(ls, op))) ? rp : zero(rp) # we only care about register pressure within the inner most loop
748750
# rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
749751
rto = rt
750752
rt *= iters[id]
751753
if u₁reduces & u₂reduces
752-
# @show op 4, rto, iters[id], lat, rp
753754
cost_vec[4] += rt
754755
reg_pressure[4] += rp
755756
elseif u₂reduces # cost decreased by unrolling u₂loop
756-
# @show op 2, rto, iters[id], lat, rp
757757
cost_vec[2] += rt
758758
reg_pressure[2] += rp
759759
elseif u₁reduces # cost decreased by unrolling u₁loop
760-
# @show op 3, rto, iters[id], lat, rp
761760
cost_vec[3] += rt
762761
reg_pressure[3] += rp
763762
else # no cost decrease; cost must be repeated
764-
# @show op 1, rto, iters[id], lat, rp
765763
cost_vec[1] += rt
766764
reg_pressure[1] += rp
767765
end
768766
end
769767
# @inbounds ((cost_vec[4] > 0) || ((cost_vec[2] > 0) & (cost_vec[3] > 0))) || return 0,0,Inf,false
770-
# @show cost_vec reg_pressure
771768
costpenalty = (sum(reg_pressure) > REGISTER_COUNT) ? 2 : 1
772-
# @show order, vectorized cost_vec reg_pressure
773-
# @show solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure)
774769
u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
775770
round_uᵢ = prefetch_good_idea ? (u₁v ? 1 : (u₂v ? 2 : 0)) : 0
776771
u₁, u₂, ucost = solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, round_uᵢ)
@@ -820,7 +815,6 @@ end
820815
# that I could come up with.
821816
function Base.iterate(lo::LoopOrders, state)
822817
advance_state!(state) || return nothing
823-
# # @show state
824818
syms = copyto!(lo.buff, lo.syms)
825819
for i eachindex(state)
826820
sᵢ = state[i]

src/graphs.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,11 @@ function UnrollSpecification(ls::LoopSet, u₁loop::Symbol, u₂loop::Symbol, ve
703703
nv = findfirst(isequal(vectorized), order)::Int
704704
UnrollSpecification(nu₁, nu₂, nv, u₁, u₂)
705705
end
706+
707+
# function getunrolled(ls::LoopSet)
708+
# order = names(ls)
709+
# us = ls.unrollspecification[]
710+
# @unpack u₁loopnum, u₂loopnum = us
711+
# order[u₁loopnum], order[u₂loopnum]
712+
# end
713+

src/lowering.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,13 +440,21 @@ function add_upper_outer_reductions(ls::LoopSet, loopq::Expr, Ulow::Int, Uhigh::
440440
Expr(:if, ncomparison, ifq)
441441
end
442442
function reduce_expr!(q::Expr, ls::LoopSet, U::Int)
443+
us = ls.unrollspecification[]
444+
# u₁loop, u₂loop = getunrolled(ls)
443445
for or ls.outer_reductions
444446
op = ls.operations[or]
445447
var = name(op)
446448
mvar = mangledvar(op)
447449
instr = instruction(op)
448450
reduce_expr!(q, mvar, instr, U)
449-
length(ls.opdict) == 0 || push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), var, Symbol(mvar, 0))))
451+
if !iszero(length(ls.opdict))
452+
if (isu₁unrolled(op) | isu₂unrolled(op))
453+
push!(q.args, Expr(:(=), var, Expr(:call, lv(reduction_scalar_combine(instr)), var, Symbol(mvar, 0))))
454+
else
455+
push!(q.args, Expr(:(=), var, mvar))
456+
end
457+
end
450458
end
451459
end
452460
function gc_preserve(ls::LoopSet, q::Expr)

test/gemv.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Test
33
# T = Float32
44
@testset "GEMV" begin
55
# Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 6)
6-
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (2, 6) : (2, 10)
6+
Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (1, 6) : (1, 10)
77
gemvq = :(for i eachindex(y)
88
yᵢ = 0.0
99
for j eachindex(x)

test/miscellaneous.jl

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Test
55
@testset "Miscellaneous" begin
66

77
# Unum, Tnum = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
8-
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (2, 6) : (2, 10)
8+
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (1, 6) : (1, 10)
99
dot3q = :(for m 1:M, n 1:N
1010
s += x[m] * A[m,n] * y[n]
1111
end);
@@ -69,9 +69,12 @@ using Test
6969
B[j,i] = A[j,i] - x[j]
7070
end)
7171
lssubcol = LoopVectorization.LoopSet(subcolq);
72-
if LoopVectorization.REGISTER_COUNT != 8
73-
@test LoopVectorization.choose_order(lssubcol) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)
74-
end
72+
# if LoopVectorization.REGISTER_COUNT != 8
73+
# # @test LoopVectorization.choose_order(lssubcol) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)
74+
# @test LoopVectorization.choose_order(lssubcol) == (Symbol[:j,:i], :j, :i, :j, 1, 1)
75+
# end
76+
@test LoopVectorization.choose_order(lssubcol) == (Symbol[:j,:i], :i, Symbol("##undefined##"), :j, 4, -1)
77+
# @test LoopVectorization.choose_order(lssubcol) == (Symbol[:j,:i], :j, Symbol("##undefined##"), :j, 4, -1)
7578
## @avx is SLOWER!!!!
7679
## need to fix!
7780
function mysubcol!(B, A, x)
@@ -96,9 +99,11 @@ using Test
9699
x[j] += A[j,i] - 0.25
97100
end)
98101
lscolsum = LoopVectorization.LoopSet(colsumq);
99-
if LoopVectorization.REGISTER_COUNT != 8
100-
@test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)
101-
end
102+
# if LoopVectorization.REGISTER_COUNT != 8
103+
# # @test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)
104+
# @test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, :i, :j, 1, 1)
105+
# end
106+
@test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, Symbol("##undefined##"), :j, 8, -1)
102107
# my colsum is wrong (by 0.25), but slightly more interesting
103108
function mycolsum!(x, A)
104109
@. x = 0
@@ -133,11 +138,13 @@ using Test
133138
end)
134139
lsvar = LoopVectorization.LoopSet(varq);
135140
# LoopVectorization.choose_order(lsvar)
136-
if LoopVectorization.REGISTER_COUNT == 32
137-
@test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, :i, :j, 2, 10)
138-
elseif LoopVectorization.REGISTER_COUNT == 16
139-
@test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, :i, :j, 2, 6)
140-
end
141+
# @test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)
142+
@test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, Symbol("##undefined##"), :j, 8, -1)
143+
# if LoopVectorization.REGISTER_COUNT == 32
144+
# @test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, :i, :j, 2, 10)
145+
# elseif LoopVectorization.REGISTER_COUNT == 16
146+
# @test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, :i, :j, 2, 6)
147+
# end
141148

142149
function myvar!(s², A, x̄)
143150
@.= 0
@@ -686,8 +693,8 @@ using Test
686693
basis = rand(r, (dim, nbasis));
687694
coeffs = rand(T, nbasis);
688695
P = rand(T, dim, maxdeg+1);
689-
mvp(P, basis, coeffs)
690-
mvpavx(P, basis, coeffs)
696+
# mvp(P, basis, coeffs)
697+
# mvpavx(P, basis, coeffs)
691698
mvpv = mvp(P, basis, coeffs)
692699
@test mvpv mvpavx(P, basis, coeffs)
693700
if VERSION > v"1.1"

0 commit comments

Comments
 (0)