Skip to content

Commit 9a04bce

Browse files
committed
2 parents 84c85e8 + 77b45cf commit 9a04bce

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

src/determinestrategy.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -274,27 +274,27 @@ function solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range)
274274
u₁best, u₂best, bestcost
275275
end
276276

277-
function solve_unroll(X, R, u₁L, u₂L)
277+
function solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
278278
X₁, X₂, X₃, X₄ = X[1], X[2], X[3], X[4]
279279
R₁, R₂, R₃, R₄, R₅ = R[1], R[2], R[3], R[4], R[5]
280-
iszero(R₅) || return solve_unroll_iter(X, R, u₁L, u₂L, 1:10, 1:10)
280+
iszero(R₅) || return solve_unroll_iter(X, R, u₁L, u₂L, u₁step:u₁step:10, u₂step:u₂step:10)
281281
RR = REGISTER_COUNT - R₃ - R₄
282282
a = R₂^2*X₃ -R₁*X₄ * R₂ - R₁*X₂*RR
283283
b = R₁ * X₄ * RR - R₁ * X₄ * RR - 2X₃*RR*R₂
284284
c = X₃*RR^2
285285
discriminant = b^2 - 4a*c
286286
discriminant < 0 && return -1,-1,Inf
287-
u₁float = max(1.0, (sqrt(discriminant) + b) / (-2a)) # must be at least 1
287+
u₁float = max(float(u₁step), (sqrt(discriminant) + b) / (-2a)) # must be at least 1
288288
u₂float = (RR - u₁float*R₂)/(u₁float*R₁)
289289
if !(isfinite(u₂float) && isfinite(u₁float))
290290
return 4, 4, unroll_cost(X, 4, 4, u₁L, u₂L)
291291
# return itertilesize(X, u₁L, u₂L)
292292
end
293293
u₁low = floor(Int, u₁float)
294-
u₂low = max(1, floor(Int, u₂float)) # must be at least 1
295-
u₁high = solve_unroll_constT(R, u₂low) + 1
296-
u₂high = solve_unroll_constU(R, u₁low) + 1
297-
solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁high, u₂low:u₂high)
294+
u₂low = max(u₂step, floor(Int, u₂float)) # must be at least 1
295+
u₁high = solve_unroll_constT(R, u₂low) + u₁step
296+
u₂high = solve_unroll_constU(R, u₁low) + u₂step
297+
solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
298298
end
299299

300300
function solve_unroll_constU(R::AbstractVector, u₁::Int)
@@ -308,9 +308,9 @@ function solve_unroll_constT(ls::LoopSet, u₂::Int)
308308
floor(Int, (REGISTER_COUNT - R[3] - R[4] - u₂*R[5]) / (u₂ * R[1] + R[2]))
309309
end
310310
# Tiling here is about alleviating register pressure for the UxT
311-
function solve_unroll(X, R, u₁max, u₂max, u₁L, u₂L)
311+
function solve_unroll(X, R, u₁max, u₂max, u₁L, u₂L, u₁step, u₂step)
312312
# iszero(first(R)) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
313-
u₁, u₂, cost = solve_unroll(X, R, u₁L, u₂L)
313+
u₁, u₂, cost = solve_unroll(X, R, u₁L, u₂L, u₁step, u₂step)
314314
# u₂ -= u₂ & 1
315315
# u₁ = min(u₁, u₂)
316316
u₁_too_large = u₁ > u₁max
@@ -354,12 +354,19 @@ function solve_unroll(
354354
ls::LoopSet, u₁loopsym::Symbol, u₂loopsym::Symbol,
355355
cost_vec::AbstractVector{Float64},
356356
reg_pressure::AbstractVector{Float64},
357-
W::Int, vectorized::Symbol
357+
W::Int, vectorized::Symbol, rounduᵢ::Int
358358
)
359+
(u₁step, u₂step) = if rounduᵢ == 1 # max is to safeguard against some weird arch I've never heard of.
360+
(max(1,VectorizationBase.CACHELINE_SIZE ÷ VectorizationBase.REGISTER_SIZE), 1)
361+
elseif rounduᵢ == 2
362+
(1, max(1,VectorizationBase.CACHELINE_SIZE ÷ VectorizationBase.REGISTER_SIZE))
363+
else
364+
(1, 1)
365+
end
359366
u₁loop = getloop(ls, u₁loopsym)
360367
u₂loop = getloop(ls, u₂loopsym)
361368
solve_unroll(
362-
u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, u₁loop, u₂loop
369+
u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, u₁loop, u₂loop, u₁step, u₂step
363370
)
364371
end
365372

@@ -368,7 +375,8 @@ function solve_unroll(
368375
cost_vec::AbstractVector{Float64},
369376
reg_pressure::AbstractVector{Float64},
370377
W::Int, vectorized::Symbol,
371-
u₁loop::Loop, u₂loop::Loop
378+
u₁loop::Loop, u₂loop::Loop,
379+
u₁step::Int, u₂step::Int
372380
)
373381
maxu₂base = maxu₁base = REGISTER_COUNT == 32 ? 10 : 6#8
374382
maxu₂ = maxu₂base#8
@@ -393,7 +401,7 @@ function solve_unroll(
393401
u₁L = u₁loopsym === vectorized ? cld(u₁L,W) : u₁L
394402
maxu₁ = min(4maxu₁, u₁L)
395403
end
396-
u₁, u₂, cost = solve_unroll(cost_vec, reg_pressure, maxu₁, maxu₂, length(u₁loop), length(u₂loop))
404+
u₁, u₂, cost = solve_unroll(cost_vec, reg_pressure, maxu₁, maxu₂, length(u₁loop), length(u₂loop), u₁step, u₂step)
397405
# heuristic to more evenly divide small numbers of iterations
398406
if isstaticloop(u₂loop)
399407
u₂ = maybedemotesize(u₂, length(u₂loop), u₁, u₁loop, maxu₂base)
@@ -637,6 +645,7 @@ function evaluate_cost_tile(
637645
u₁reached = u₂reached = false
638646
choose_to_inline = Ref(false)
639647
copyto!(names(ls), order); reverse!(names(ls))
648+
prefetch_good_idea = false
640649
for n ∈ 1:N
641650
itersym = order[n]
642651
if itersym == u₁loopsym
@@ -688,6 +697,7 @@ function evaluate_cost_tile(
688697
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
689698
if isload(op) && !iszero(prefetchisagoodidea(ls, op, UnrollArgs(4, unrollsyms, 4, 0)))
690699
rt += 0.5VectorizationBase.REGISTER_SIZE / VectorizationBase.CACHELINE_SIZE
700+
prefetch_good_idea = true
691701
end
692702
# @show isunrolled₁, isunrolled₂, op rt, lat, rp
693703
rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
@@ -710,10 +720,11 @@ function evaluate_cost_tile(
710720
costpenalty = (sum(reg_pressure) > REGISTER_COUNT) ? 2 : 1
711721
# @show order, vectorized cost_vec reg_pressure
712722
# @show solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure)
713-
u₁, u₂, ucost = solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized)
723+
u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
724+
round_uᵢ = prefetch_good_idea ? (u₁v ? 1 : (u₂v ? 2 : 0)) : 0
725+
u₁, u₂, ucost = solve_unroll(ls, u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vectorized, round_uᵢ)
714726
outer_reduct_penalty = length(ls.outer_reductions) * (u₁ + isodd(u₁))
715727
favor_bigger_u₂ = u₁ - u₂
716-
u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
717728
favor_smaller_vectorized = u₁v ? ( u₁ - u₂ ) : (u₂v ? ( u₂ - u₁ ) : 0 )
718729
favor_u₁_vectorized = -0.2u₁v
719730
favoring_heuristics = favor_bigger_u₂ + 0.5favor_smaller_vectorized + favor_u₁_vectorized

test/gemm.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
@testset "GEMM" begin
2-
# using LoopVectorization, LinearAlgebra, Test; T = Float64
3-
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (3, 4) : (3, 9)
2+
# using LoopVectorization, LinearAlgebra, Test; T = Float64
3+
Unum, Tnum = LoopVectorization.REGISTER_COUNT == 16 ? (2, 6) : (3, 9)
44
Unumt, Tnumt = LoopVectorization.REGISTER_COUNT == 16 ? (3, 4) : (5, 5)
55
if LoopVectorization.REGISTER_COUNT != 8
66
@test LoopVectorization.mᵣ == Unum
@@ -476,8 +476,6 @@
476476
end
477477
return C
478478
end
479-
A = rand(35,23); B = rand(35,17); C1 = A' * B; C2 = similar(C1);
480-
mulCAtB_2x2block_avx!(C2, A, B);
481479

482480
function mulCAtB_2x2blockavx_noinline!(C, A, B)
483481
M, N = size(C); K = size(B,1)
@@ -846,7 +844,6 @@
846844
fill!(C, 9999.999); rank2AmulBavx_noinline!(C, Aₘ, Aₖ′', B)
847845
@test C ≈ C2
848846
end
849-
850847
end
851848
end
852849

0 commit comments

Comments
 (0)