Skip to content

Commit 248ad9c

Browse files
committed
Apply masks to more selfop-funs
1 parent fc0b45a commit 248ad9c

File tree

5 files changed

+58
-8
lines changed

5 files changed

+58
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,5 @@ Static = "0.2, 0.3"
3030
StrideArraysCore = "0.1.12"
3131
ThreadingUtilities = "0.4.5"
3232
UnPack = "1"
33-
VectorizationBase = "0.20.27"
33+
VectorizationBase = "0.20.31"
3434
julia = "1.5"

src/codegen/lower_compute.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ end
234234
push!(q.args, :($gf(vargs, $k, false)))
235235
end
236236
return Expr(:block, Expr(:meta, :inline), q)
237+
# return Expr(:block, Expr(:meta, :inline), :(@show($q)))
237238
end
238239
if Sreduced
239240
M = N
@@ -272,6 +273,7 @@ end
272273
push!(t.args, :($gf(dd, $m, false)))
273274
end
274275
push!(q.args, :(VecUnroll($t)))
276+
# push!(q.args, :(@show(VecUnroll($t))))
275277
q
276278
end
277279

@@ -560,7 +562,8 @@ function lower_compute!(
560562
end
561563
end
562564
end
563-
selfdepreduce = ifelse(((!u₁unrolledsym) & isu₁unrolled(op)) & (u₁ > 1), selfdep, 0)
565+
selfdepreduce = ifelse(((!u₁unrolledsym) & isu₁unrolled(op)) & (u₁ > 1), selfdep, 0)
566+
# @show selfdepreduce, selfdep, maskreduct, op
564567
if maskreduct
565568
ifelsefunc = if us.u₁ == 1
566569
:ifelse # don't need to be fancy
@@ -577,7 +580,7 @@ function lower_compute!(
577580
insert!(instrcall.args, 4, staticexpr(u₁))
578581
insert!(instrcall.args, 5, staticexpr(selfdepreduce))
579582
end
580-
elseif all(in(loopdependencies(op)), reduceddeps) || any(opp -> mangledvar(opp) === mangledvar(op), parents_op)
583+
elseif all(in(loopdependencies(op)), reduceddeps) || selfdep ≠ 0#any(opp -> mangledvar(opp) === mangledvar(op), parents_op)
581584
# Here, we are evaluating the function, and then `ifelse`-ing it with `hasf == false`.
582585
# That means we still need to adjust the `instrcall` in case we're reducing/accumulating across the unroll
583586
if ifelsefunc ≢ :ifelse # ifelse means it's unrolled by 1, no need
@@ -593,8 +596,8 @@ function lower_compute!(
593596
push!(q.args, Expr(:(=), varsym, Expr(:call, lv(:ifelse), MASKSYMBOL, instrcall, selfopname)))
594597
end
595598
return
596-
elseif selfdep != 0
597-
make_partial_map!(instrcall, selfopname, u₁, selfdepreduce)
599+
# elseif selfdep != 0
600+
# make_partial_map!(instrcall, selfopname, u₁, selfdepreduce)
598601
end
599602
elseif selfdep != 0 && (dopartialmap ||
600603
(isouterreduct && (opunrolled) && (u₁ < us.u₁)) ||

src/modeling/costs.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,9 @@ const COST = Dict{Symbol,InstructionCost}(
269269
:convert => InstructionCost(4,0.5),
270270
:vpermilps177 => InstructionCost(1, 1.0),
271271
:vmovsldup => InstructionCost(1, 1.0),
272-
:vmovshdup => InstructionCost(1, 1.0)
272+
:vmovshdup => InstructionCost(1, 1.0),
273+
:exponent => InstructionCost(8, 1.0),
274+
:significand => InstructionCost(8, 1.0)
273275
)
274276

275277
# # @inline prefetch0(x::Ptr, i) = VectorizationBase.prefetch(x, Val{3}(), Val{0}())
@@ -596,7 +598,7 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
596598
typeof(ifelse) => :ifelse,
597599
typeof(identity) => :identity,
598600
typeof(conj) => :identity,#conj,
599-
typeof(÷) => :div_fast
601+
typeof(÷) => :vdiv_fast
600602
# typeof(zero) => :zero,
601603
# typeof(one) => :one,
602604
# typeof(axes) => :axes,

src/reconstruct_loopset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ Execute an `@turbo` block. The block's code is represented via the arguments:
711711
@aggressive_constprop @generated function _turbo_!(
712712
::Val{var"#UNROLL#"}, ::Val{var"#OPS#"}, ::Val{var"#ARF#"}, ::Val{var"#AM#"}, ::Val{var"#LPSYM#"}, ::Val{Tuple{var"#LB#",var"#V#"}}, var"#flattened#var#arguments#"::Vararg{Any,var"#num#vargs#"}
713713
) where {var"#UNROLL#", var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#", var"#V#", var"#num#vargs#"}
714-
# 1 + 1 # Irrelevant line you can comment out/in to force recompilation...
714+
1 + 1 # Irrelevant line you can comment out/in to force recompilation...
715715
ls = _turbo_loopset(var"#OPS#", var"#ARF#", var"#AM#", var"#LPSYM#", var"#LB#".parameters, var"#V#".parameters, var"#UNROLL#")
716716
pushfirst!(ls.preamble.args, :(var"#lv#tuple#args#" = reassemble_tuple(Tuple{var"#LB#",var"#V#"}, var"#flattened#var#arguments#")))
717717
# return @show avx_body(ls, var"#UNROLL#")

test/outer_reductions.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,48 @@ function test_awmean(::Type{T}) where {T}
5454
end
5555
end
5656

57+
function logℒ_fast(α, β, t, c, x)
58+
eα = abs(α)
59+
n, k = size(x)
60+
61+
(n == length(t) == length(c) && length(β) == k + 1) || throw(DimensionMismatch())
62+
s = zero(typeof(α))
63+
@inbounds for i in 1:n
64+
xb = sum(@inbounds(x[i, j] * β[j+1]) for j in 1:k) + β[1]
65+
ti = t[i]
66+
s += (1 - (c[i] == ti)) * (log(eα) + (eα - 1) * log(ti) + xb) - ti^eα * exp(xb)
67+
end
68+
return s
69+
end
70+
function logℒ_fast_turbo(α, β, t, c, x)
71+
eα = abs(α)
72+
n, k = size(x)
73+
74+
(n == length(t) == length(c) && length(β) == k + 1) || throw(DimensionMismatch())
75+
s = zero(typeof(α))
76+
@turbo for i in 1:n
77+
xb0 = 0.0
78+
for j in 1:k
79+
xb0 += x[i,j] * β[j+1]
80+
end
81+
xb = xb0 + β[1]
82+
ti = t[i]
83+
s += (1 - (c[i] == ti)) * (log(eα) + (eα - 1) * log(ti) + xb) - ti^eα * exp(xb)
84+
end
85+
return s
86+
end
87+
88+
function test_logℒ(n, k)
89+
t = rand(n)
90+
c = copy(t);
91+
b = rand(n) .> 0.5
92+
c[b] .= rand.();
93+
x = rand(n,k)
94+
α = 2.85
95+
β = rand(k+1)
96+
@test logℒ_fast(α, β, t, c, x) ≈ logℒ_fast_turbo(α, β, t, c, x)
97+
end
98+
5799
function not_an_outer_reduct!(r, N::Int, x = 2.0, y= nothing) # there was a bug where this was classified as one
58100
@turbo for i ∈ eachindex(r)
59101
acc = y === nothing ? x : r[i]
@@ -69,5 +111,8 @@ end
69111
test_awmean(T)
70112
end
71113
@test all(==(7.4), not_an_outer_reduct!(Vector{Float64}(undef, 5), 17, 7.4))
114+
for n ∈ 1:20, k ∈ 1:5
115+
test_logℒ(n,k)
116+
end
72117
end
73118

0 commit comments

Comments
 (0)