Skip to content

Commit b81bbbf

Browse files
committed
ifelse reduction changes
1 parent 09798ea commit b81bbbf

File tree

4 files changed

+25
-17
lines changed

4 files changed

+25
-17
lines changed

src/codegen/lower_store.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,15 @@ function reduce_expr!(q::Expr, toreduct::Symbol, op::Operation, u₁::Int, u₂:
5050
end
5151
if (u₁ == 1) | (~isu₁unrolled)
5252
push!(q.args, Expr(:(=), Symbol(toreduct, "##onevec##"), _toreduct))
53-
else
53+
elseif instruction(op).instr :ifelse
5454
push!(q.args, Expr(:(=), Symbol(toreduct, "##onevec##"), Expr(:call, reduction_to_single_vector(op), _toreduct)))
55+
else
56+
fifelse = let u₁=u₁
57+
ifelse_reduction(:IfElseCollapser,op) do opv
58+
Symbol(mangledvar(opv), '_', u₁), tuple()
59+
end
60+
end
61+
push!(q.args, Expr(:(=), Symbol(toreduct, "##onevec##"), Expr(:call, fifelse, _toreduct, staticexpr(1))))
5562
end
5663
nothing
5764
end

src/modeling/determinestrategy.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ function solve_unroll_iter(X, R, u₁L, u₂L, u₁range, u₂range)
428428
u₁best, u₂best, bestcost
429429
end
430430

431-
function solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step::Int, u₂step::Int, atleast32registers::Bool)
431+
function solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step::Int, u₂step::Int, atleast31registers::Bool)
432432
X₁, X₂, X₃, X₄ = X[1], X[2], X[3], X[4]
433433
# If we don't have opmask registers, masks probably occupy a vector register (e.g., on CPUs with AVX but not AVX512)
434434
R₁, R₂, R₃, R₄ = R[1], R[2], R[3], R[4]
@@ -443,8 +443,8 @@ function solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step::Int, u₂step::Int,
443443
u₂float = (RR - u₁float*R₂)/(u₁float*R₁)
444444
if !(isfinite(u₂float) & isfinite(u₁float)) # brute force
445445
u₁low = u₂low = 1
446-
u₁high = iszero(X₂) ? 2 : (atleast32registers ? 8 : 6)
447-
u₂high = iszero(X₃) ? 2 : (atleast32registers ? 8 : 6)
446+
u₁high = iszero(X₂) ? 2 : (atleast31registers ? 8 : 6)
447+
u₂high = iszero(X₃) ? 2 : (atleast31registers ? 8 : 6)
448448
return solve_unroll_iter(X, R, u₁L, u₂L, u₁low:u₁step:u₁high, u₂low:u₂step:u₂high)
449449
end
450450
u₁low = floor(Int, u₁float)
@@ -457,7 +457,7 @@ function solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step::Int, u₂step::Int,
457457
if u₂low u₂high
458458
u₂low = solve_unroll_constU(R, u₁high)
459459
end
460-
maxunroll = atleast32registers ? (((X₂ > 0) & (X₃ > 0)) ? 10 : 8) : 6
460+
maxunroll = atleast31registers ? (((X₂ > 0) & (X₃ > 0)) ? 10 : 8) : 6
461461
u₁low = (clamp(u₁low, 1, maxunroll) ÷ u₁step) * u₁step
462462
u₂low = (clamp(u₂low, 1, maxunroll) ÷ u₂step) * u₂step
463463
u₁high = clamp(u₁high, 1, maxunroll)
@@ -482,9 +482,9 @@ end
482482
# floor(Int, (dynamic_register_count() - R[3] - R[4] - u₂*R[5]) / (u₂ * R[1] + R[2]))
483483
# end
484484
# Tiling here is about alleviating register pressure for the UxT
485-
function solve_unroll(X, R, u₁max, u₂max, u₁L, u₂L, u₁step, u₂step, atleast32registers::Bool)
485+
function solve_unroll(X, R, u₁max, u₂max, u₁L, u₂L, u₁step, u₂step, atleast31registers::Bool)
486486
# iszero(first(R)) && return -1,-1,Inf #solve_smalltilesize(X, R, u₁max, u₂max)
487-
u₁, u₂, cost = solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step, u₂step, atleast32registers)
487+
u₁, u₂, cost = solve_unroll_lagrange(X, R, u₁L, u₂L, u₁step, u₂step, atleast31registers)
488488
# u₂ -= u₂ & 1
489489
# u₁ = min(u₁, u₂)
490490
u₁_too_large = u₁ > u₁max
@@ -539,7 +539,7 @@ function solve_unroll(
539539
u₁loop = getloop(ls, u₁loopsym)
540540
u₂loop = getloop(ls, u₂loopsym)
541541
solve_unroll(
542-
u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vloopsym, u₁loop, u₂loop, u₁step, u₂step, reg_count(ls) 32
542+
u₁loopsym, u₂loopsym, cost_vec, reg_pressure, W, vloopsym, u₁loop, u₂loop, u₁step, u₂step, reg_count(ls) 31
543543
)
544544
end
545545

@@ -550,9 +550,9 @@ function solve_unroll(
550550
W::Int, vloopsym::Symbol,
551551
u₁loop::Loop, u₂loop::Loop,
552552
u₁step::Int, u₂step::Int,
553-
atleast32registers::Bool
553+
atleast31registers::Bool
554554
)
555-
maxu₂base = maxu₁base = atleast32registers ? 10 : 6#8
555+
maxu₂base = maxu₁base = atleast31registers ? 10 : 6#8
556556
maxu₂ = maxu₂base#8
557557
maxu₁ = maxu₁base#8
558558
u₁L = length(u₁loop)
@@ -593,7 +593,7 @@ function solve_unroll(
593593
else
594594
u₂Lf = Float64(u₂L)
595595
end
596-
u₁, u₂, cost = solve_unroll(cost_vec, reg_pressure, maxu₁, maxu₂, u₁Lf, u₂Lf, u₁step, u₂step, atleast32registers)
596+
u₁, u₂, cost = solve_unroll(cost_vec, reg_pressure, maxu₁, maxu₂, u₁Lf, u₂Lf, u₁step, u₂step, atleast31registers)
597597
# heuristic to more evenly divide small numbers of iterations
598598
if isstaticloop(u₂loop)
599599
u₂ = maybedemotesize(u₂, length(u₂loop), u₁, u₁loop, maxu₂base)

test/gemm.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@
360360
if LoopVectorization.cache_linesize() == LoopVectorization.register_size()
361361
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :m, :n, :m, 3, 7)
362362
else
363-
@test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :n, :m, :m, 5, 4)
363+
@test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 3, 7)
364364
end
365365
elseif LoopVectorization.register_count() == 16
366366
# @test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :m, :n, :m, 1, 6)
@@ -637,7 +637,7 @@
637637
@test LoopVectorization.choose_order(lsAtmulBt8) == ([:n, :m, :k], :m, :n, :m, 1, 8)
638638
# @test LoopVectorization.choose_order(lsAtmulBt8) == ([:n, :m, :k], :k, :n, :m, 1, 8)
639639
elseif LoopVectorization.register_size() == 16
640-
@test LoopVectorization.choose_order(lsAtmulBt8) == ([:n, :m, :k], :m, :n, :m, 4, 4)
640+
@test LoopVectorization.choose_order(lsAtmulBt8) == ([:n, :m, :k], :m, :n, :m, 2, 8)
641641
end
642642
elseif LoopVectorization.register_count() == 16
643643
@test LoopVectorization.choose_order(lsAtmulBt8) == ([:n, :m, :k], :m, :n, :m, 2, 4)

test/miscellaneous.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ using Test
99
s += x[m] * A[m,n] * y[n]
1010
end);
1111
lsdot3 = LoopVectorization.loopset(dot3q);
12-
if LoopVectorization.register_count() == 32
13-
# @test LoopVectorization.choose_order(lsdot3) == ([:n, :m], :m, :n, :m, Unum, Tnum)#&-2
12+
if LoopVectorization.register_count() 32
13+
@test LoopVectorization.choose_order(lsdot3) == ([:n, :m], :n, :m, :m, 2, 6)
14+
elseif Bool(LoopVectorization.has_opmask_registers())
1415
@test LoopVectorization.choose_order(lsdot3) == ([:n, :m], :n, Symbol("##undefined##"), :m, 4, -1)
1516
else
16-
@test LoopVectorization.choose_order(lsdot3) == ([:n, :m], :n, :m, :m, 2, 6)
17+
@test LoopVectorization.choose_order(lsdot3) == ([:n, :m], :n, :m, :m, 2, 8)
1718
end
1819

1920
@static if VERSION < v"1.4"
@@ -71,7 +72,7 @@ using Test
7172
lssubcol = LoopVectorization.loopset(subcolq);
7273
# @test LoopVectorization.choose_order(lssubcol) == (Symbol[:i,:j], :i, Symbol("##undefined##"), :j, 1, -1)
7374
# @test LoopVectorization.choose_order(lssubcol) == (Symbol[:i,:j], :j, :i, :j, 1, 8)
74-
@test LoopVectorization.choose_order(lssubcol) == (Symbol[:i,:j], :j, :i, :j, 1, ifelse(LoopVectorization.register_count() == 32, 8, 6))
75+
@test LoopVectorization.choose_order(lssubcol) == (Symbol[:i,:j], :j, :i, :j, 1, ifelse((LoopVectorization.register_count() == 32), 8, 6))
7576

7677
# if LoopVectorization.register_count() != 8
7778
# # @test LoopVectorization.choose_order(lssubcol) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)

0 commit comments

Comments
 (0)