Skip to content

Commit 753aa3d

Browse files
committed
A few tweeks, most importantly use check_args for broadcasts.
1 parent 27079c4 commit 753aa3d

File tree

10 files changed

+58
-44
lines changed

10 files changed

+58
-44
lines changed

src/broadcast.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ function add_broadcast!(
6666
K = gensym(:K)
6767
mA = gensym(:Aₘₖ)
6868
mB = gensym(:Bₖₙ)
69-
pushpreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
70-
pushpreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
71-
pushpreamble!(ls, Expr(:(=), K, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:ref, Expr(:call, :size, mB), 1))))
69+
pushprepreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a))))
70+
pushprepreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b))))
71+
pushprepreamble!(ls, Expr(:(=), K, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:ref, Expr(:call, :size, mB), 1))))
7272
k = gensym(:k)
7373
add_loop!(ls, Loop(k, 1, K), k)
7474
m = loopsyms[1];
@@ -139,7 +139,7 @@ end
139139
function extract_all_1_array!(ls::LoopSet, bcname::Symbol, N::Int, elementbytes::Int)
140140
refextract = gensym(bcname)
141141
ref = Expr(:ref, bcname); append!(ref.args, [1 for n 1:N])
142-
pushpreamble!(ls, Expr(:(=), refextract, ref))
142+
pushprepreamble!(ls, Expr(:(=), refextract, ref))
143143
return add_constant!(ls, refextract, elementbytes) # or replace elementbytes with sizeof(T) ? u
144144
end
145145
function add_broadcast!(
@@ -159,7 +159,7 @@ function add_broadcast_adjoint_array!(
159159
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{A}, elementbytes::Int
160160
) where {T,N,A<:AbstractArray{T,N}}
161161
parent = gensym(:parent)
162-
pushpreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
162+
pushprepreamble!(ls, Expr(:(=), parent, Expr(:call, :parent, bcname)))
163163
# isone(length(loopsyms)) && return extract_all_1_array!(ls, bcname, N, elementbytes)
164164
ref = ArrayReference(parent, Symbol[loopsyms[N + 1 - n] for n 1:N])
165165
add_simple_load!( ls, destname, ref, elementbytes, true, true )::Operation
@@ -198,7 +198,7 @@ function add_broadcast!(
198198
ls::LoopSet, ::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, ::Type{Base.RefValue{T}}, elementbytes::Int
199199
) where {T}
200200
refextract = gensym(bcname)
201-
pushpreamble!(ls, Expr(:(=), refextract, Expr(:ref, bcname)))
201+
pushprepreamble!(ls, Expr(:(=), refextract, Expr(:ref, bcname)))
202202
add_constant!(ls, refextract, elementbytes) # or replace elementbytes with sizeof(T) ? u
203203
end
204204
function add_broadcast!(
@@ -210,7 +210,7 @@ function add_broadcast!(
210210
inds[2:end] .= @view(loopsyms[1:N])
211211
add_simple_load!(ls, destname, ArrayReference(bcname, inds), elementbytes, true, true)
212212
end
213-
BroadcastedArray{S<:Broadcast.AbstractArrayStyle,F,A} = Broadcasted{S,Nothing,F,A}
213+
const BroadcastedArray{S<:Broadcast.AbstractArrayStyle,F,A} = Broadcasted{S,Nothing,F,A}
214214
function add_broadcast!(
215215
ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol},
216216
@nospecialize(B::Type{<:BroadcastedArray}),
@@ -219,7 +219,7 @@ function add_broadcast!(
219219
S,_,F,A = B.parameters
220220
instr = get(FUNCTIONSYMBOLS, F) do
221221
f = gensym(:func)
222-
pushpreamble!(ls, Expr(:(=), f, Expr(:(.), bcname, QuoteNode(:f))))
222+
pushprepreamble!(ls, Expr(:(=), f, Expr(:(.), bcname, QuoteNode(:f))))
223223
Instruction(bcname, f)
224224
end
225225
args = A.parameters
@@ -231,7 +231,7 @@ function add_broadcast!(
231231
# reduceddeps = Symbol[]
232232
for (i,arg) enumerate(args)
233233
argname = gensym(:arg)
234-
pushpreamble!(ls, Expr(:(=), argname, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:ref, bcargs, i))))
234+
pushprepreamble!(ls, Expr(:(=), argname, Expr(:macrocall, Symbol("@inbounds"), LineNumberNode(@__LINE__,Symbol(@__FILE__)), Expr(:ref, bcargs, i))))
235235
# dynamic dispatch
236236
parent = add_broadcast!(ls, gensym(:temp), argname, loopsyms, arg, elementbytes)::Operation
237237
push!(parents, parent)
@@ -272,8 +272,10 @@ end
272272
# return ls
273273
q = lower(ls)
274274
push!(q.args, :dest)
275-
pushfirst!(q.args, Expr(:meta,:inline))
276275
# @show q
276+
# q
277+
q = Expr(:block, ls.prepreamble, Expr(:if, check_args_call(ls), q, :(Base.Broadcast.materialize!(dest, bc))))
278+
isone(N) && pushfirst!(q.args, Expr(:meta,:inline))
277279
q
278280
# ls
279281
end
@@ -285,7 +287,7 @@ end
285287
loopsyms = [gensym(:n) for n 1:N]
286288
ls = LoopSet(Mod)
287289
ls.isbroadcast[] = true
288-
pushpreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
290+
pushprepreamble!(ls, Expr(:(=), :dest, Expr(:call, :parent, :dest′)))
289291
sizes = Expr(:tuple)
290292
for (n,itersym) enumerate(loopsyms)
291293
Nsym = gensym(:N)
@@ -299,7 +301,8 @@ end
299301
resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product
300302
q = lower(ls)
301303
push!(q.args, :dest′)
302-
pushfirst!(q.args, Expr(:meta,:inline))
304+
q = Expr(:block, ls.prepreamble, Expr(:if, check_args_call(ls), q, :(Base.Broadcast.materialize!(dest′, bc))))
305+
isone(N) && pushfirst!(q.args, Expr(:meta,:inline))
303306
q
304307
# ls
305308
end
@@ -329,4 +332,4 @@ end
329332
end
330333

331334
vmaterialize!(dest, bc, ::Val{mod}) where {mod} = Base.Broadcast.materialize!(dest, bc)
332-
335+

src/costs.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,8 @@ const COST = Dict{Symbol,InstructionCost}(
231231
:sinpi_fast => InstructionCost(18,15.0,68.0,23),
232232
:cospi_fast => InstructionCost(18,15.0,68.0,26),
233233
:sincospi_fast => InstructionCost(25,22.0,70.0,26),
234+
:tanh => InstructionCost(40,40.0,40.0,26), # FIXME
235+
# :tanh_fast => InstructionCost(25,22.0,70.0,26), # FIXME
234236
:identity => InstructionCost(0,0.0,0.0,0),
235237
:adjoint => InstructionCost(0,0.0,0.0,0),
236238
:conj => InstructionCost(0,0.0,0.0,0),
@@ -442,6 +444,8 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
442444
typeof(sincos) => :sincos,
443445
typeof(Base.FastMath.sincos_fast) => :sincos,
444446
typeof(SLEEFPirates.sincos) => :sincos,
447+
typeof(Base.tanh) => :tanh,
448+
# typeof(SLEEFPirates.tanh_fast) => :tanh_fast,
445449
typeof(max) => :max,
446450
typeof(min) => :min,
447451
typeof(<<) => :<<,

src/determinestrategy.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,15 @@ function cost(ls::LoopSet, op::Operation, vectorized::Symbol, Wshift::Int, size_
7272
r = (1 << Wshift)
7373
srt *= r
7474
sl *= r
75-
# else # vmov(a/u)pd
75+
elseif isload(op) & length(loopdependencies(op)) > 1# vmov(a/u)pd
76+
# penalize vectorized loads with more than 1 loopdep
77+
# heuristic; more than 1 loopdep means that many loads will not be aligned
78+
# Roughly corresponds to double-counting loads crossing cacheline boundaries
79+
# TODO: apparently the new ARM A64FX CPU (with 512 bit vectors) is NOT penalized for unaligned loads
80+
# would be nice to add a check for this CPU, to see if such a penalty is still appropriate.
81+
# Also, once more SVE (scalable vector extension) CPUs are released, would be nice to know if
82+
# this feature is common to all of them.
83+
srt += 0.5VectorizationBase.REGISTER_SIZE / VectorizationBase.CACHELINE_SIZE
7684
end
7785
elseif instr === :setindex! # broadcast or reductionstore; if store we want to penalize reduction
7886
srt *= 3
@@ -857,12 +865,14 @@ function evaluate_cost_tile(
857865
elseif load_elimination_cost_factor!(cost_vec, reg_pressure, choose_to_inline, ls, op, iters[id], unrollsyms, Wshift, size_T)
858866
continue
859867
end
860-
elseif isconstant(op)
868+
#elseif isconstant(op)
861869
end
862870
rt, lat, rp = cost(ls, op, vectorized, Wshift, size_T)
863-
if isload(op) && !iszero(prefetchisagoodidea(ls, op, UnrollArgs(4, unrollsyms, 4, 0)))
864-
rt += 0.5VectorizationBase.REGISTER_SIZE / VectorizationBase.CACHELINE_SIZE
865-
prefetch_good_idea = true
871+
if isload(op)
872+
if !iszero(prefetchisagoodidea(ls, op, UnrollArgs(4, unrollsyms, 4, 0)))
873+
# rt += 0.5VectorizationBase.REGISTER_SIZE / VectorizationBase.CACHELINE_SIZE
874+
prefetch_good_idea = true
875+
end
866876
end
867877
# rp = (opisininnerloop && !(loadintostore(ls, op))) ? rp : zero(rp) # we only care about register pressure within the inner most loop
868878
rp = opisininnerloop ? rp : zero(rp) # we only care about register pressure within the inner most loop
@@ -871,9 +881,11 @@ function evaluate_cost_tile(
871881
if isstore(op) & (!u₁reducesrt) & (!u₂reducesrt)
872882
irreducible_storecosts += rt
873883
end
884+
# @show u₁reducesrt, u₂reducesrt, op, rt, rto, rp
874885
update_costs!(cost_vec, rt, u₁reducesrt, u₂reducesrt)
875886
update_costs!(reg_pressure, rp, u₁reducesrp, u₂reducesrp)
876887
end
888+
# reg_pressure[1] = max(reg_pressure[1], length(ls.outer_reductions))
877889
# @inbounds ((cost_vec[4] > 0) || ((cost_vec[2] > 0) & (cost_vec[3] > 0))) || return 0,0,Inf,false
878890
costpenalty = (sum(reg_pressure) > REGISTER_COUNT) ? 2 : 1
879891
u₁v = vectorized === u₁loopsym; u₂v = vectorized === u₂loopsym
@@ -886,10 +898,12 @@ function evaluate_cost_tile(
886898
end
887899
outer_reduct_penalty = length(ls.outer_reductions) * (u₁ + isodd(u₁))
888900
favor_bigger_u₂ = u₁ - u₂
889-
favor_smaller_vectorized = u₁v ? ( u₁ - u₂ ) : (u₂v ? ( u₂ - u₁ ) : 0 )
901+
# favor_smaller_vectorized = (u₁v ? u₁ : -u₁) + (u₂v ? u₂ : -u₂)
902+
favor_smaller_vectorized = (u₁v u₂v) ? (u₁v ? u₁ - u₂ : u₂ - u₁) : 0
890903
favor_u₁_vectorized = -0.2u₁v
891904
favoring_heuristics = favor_bigger_u₂ + 0.5favor_smaller_vectorized + favor_u₁_vectorized
892-
u₁, u₂, costpenalty * ucost + stride_penalty(ls, order) + outer_reduct_penalty + favoring_heuristics, choose_to_inline[]
905+
costpenalty = costpenalty * ucost + stride_penalty(ls, order) + outer_reduct_penalty + favoring_heuristics
906+
u₁, u₂, costpenalty, choose_to_inline[]
893907
end
894908

895909

src/lower_compute.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
function load_constrained(op, u₁loop, u₂loop, forprefetch = false)
2+
function load_constrained(op, u₁loop, u₂loop, innermost_loop, forprefetch = false)
33
loopdeps = loopdependencies(op)
44
dependsonu₁ = u₁loop loopdeps
55
if u₂loop === Symbol("##undefined##")
@@ -17,7 +17,10 @@ function load_constrained(op, u₁loop, u₂loop, forprefetch = false)
1717
unrolleddeps = Symbol[]
1818
dependsonu₁ && push!(unrolleddeps, u₁loop)
1919
dependsonu₂ && push!(unrolleddeps, u₂loop)
20-
any(opp -> isload(opp) && all(in(loopdependencies(opp)), unrolleddeps), parents(op))
20+
forprefetch && push!(unrolleddeps, innermost_loop)
21+
any(parents(op)) do opp
22+
isload(opp) && all(in(loopdependencies(opp)), unrolleddeps)
23+
end
2124
end
2225
function check_if_remfirst(ls, ua)
2326
usorig = ls.unrollspecification[]
@@ -34,7 +37,7 @@ function check_if_remfirst(ls, ua)
3437
end
3538
function sub_fmas(ls::LoopSet, op::Operation, ua::UnrollArgs)
3639
@unpack u₁, u₁loopsym, u₂loopsym, u₂max = ua
37-
!(load_constrained(op, u₁loopsym, u₂loopsym) || check_if_remfirst(ls, ua))
40+
!(load_constrained(op, u₁loopsym, u₂loopsym, first(names(ls))) || check_if_remfirst(ls, ua))
3841
end
3942

4043
struct FalseCollection end
@@ -212,7 +215,8 @@ function lower_compute!(
212215
for u 0:Uiter
213216
instrcall = callexpr(instr)
214217
varsym = if tiledouterreduction > 0 # then suffix !== nothing
215-
modsuffix = ((u + suffix*(Uiter + 1)) & 3)
218+
modsuffix = ((u + suffix*(Uiter + 1)) & 7)
219+
# modsuffix = suffix::Int#((u + suffix*(Uiter + 1)) & 7)
216220
# modsuffix = u
217221
# modsuffix = suffix # (suffix & 3)
218222
Symbol(mangledvar(op), modsuffix)

src/lower_load.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ function prefetchisagoodidea(ls::LoopSet, op::Operation, td::UnrollArgs)
9090
if prod(s -> Float64(length(getloop(ls, s))), @view(indices[1:innermostloopind-1])) 120.0 && length(getloop(ls, innermostloopsym)) 120
9191
if op.ref.ref.offsets[innermostloopind] < 120
9292
for opp operations(ls)
93-
iscompute(opp) && (innermostloopsym loopdependencies(opp)) && load_constrained(opp, u₁loopsym, u₂loopsym, true) && return 0
93+
iscompute(opp) && (innermostloopsym loopdependencies(opp)) && load_constrained(opp, u₁loopsym, u₂loopsym, innermostloopsym, true) && return 0
9494
end
9595
return innermostloopind
9696
end

src/lowering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ function calc_Ureduct(ls::LoopSet, us::UnrollSpecification)
643643
elseif u₂ == -1
644644
min(u₁, 4)
645645
else
646-
4#u₁
646+
8#u₂#u₁
647647
# elseif num_loops(ls) == u₁loopnum
648648
# min(u₁, 4)
649649
# else

test/fallback.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,9 @@
3636
@test msdavx(FallbackArrayWrapper(x)) == 1e18
3737
@test msd(x) == msdavx(FallbackArrayWrapper(x))
3838
@test msdavx(x) != msdavx(FallbackArrayWrapper(x))
39+
40+
x = rand(1000); # should be long enough to make zero differences incredibly unlikely
41+
@test exp.(x) != (@avx exp.(x))
42+
@test exp.(x) == (@avx exp.(FallbackArrayWrapper(x)))
3943
end
4044

test/gemm.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@
336336
end);
337337
lsr2amb = LoopVectorization.LoopSet(r2ambq);
338338
if LoopVectorization.REGISTER_COUNT == 32
339-
@test LoopVectorization.choose_order(lsr2amb) == ([:n, :m, :k], :n, :m, :m, 7, 3)
339+
@test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :n, :m, :m, 7, 3)
340340
elseif LoopVectorization.REGISTER_COUNT == 16
341341
@test LoopVectorization.choose_order(lsr2amb) == ([:m, :n, :k], :n, :m, :m, 4, 2)
342342
end

test/ifelsemasks.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -165,21 +165,6 @@ T = Float32
165165
a[i] > b[i] || (c[i] = a[i] ^ b[i])
166166
end
167167
end
168-
function maybewriteor!(c, a, b)
169-
@inbounds for i eachindex(c,a,b)
170-
a[i] > b[i] || (c[i] = a[i] ^ b[i])
171-
end
172-
end
173-
function maybewriteor_avx!(c, a, b)
174-
@_avx for i eachindex(c,a,b)
175-
a[i] > b[i] || (c[i] = a[i] ^ b[i])
176-
end
177-
end
178-
function maybewriteoravx!(c, a, b)
179-
@avx for i eachindex(c,a,b)
180-
a[i] > b[i] || (c[i] = a[i] ^ b[i])
181-
end
182-
end
183168
function maybewriteor!(c::AbstractVector{<:Integer}, a, b)
184169
@inbounds for i eachindex(c,a,b)
185170
a[i] > b[i] || (c[i] = a[i] & b[i])

test/miscellaneous.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ using Test
108108
# # @test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)
109109
# @test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, :i, :j, 1, 1)
110110
# end
111-
@test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, Symbol("##undefined##"), :j, 8, -1)
111+
@test LoopVectorization.choose_order(lscolsum) == (Symbol[:j,:i], :j, Symbol("##undefined##"), :j, 4, -1)
112112
# my colsum is wrong (by 0.25), but slightly more interesting
113113
function mycolsum!(x, A)
114114
@. x = 0
@@ -144,7 +144,7 @@ using Test
144144
lsvar = LoopVectorization.LoopSet(varq);
145145
# LoopVectorization.choose_order(lsvar)
146146
# @test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, :i, :j, Unum, Tnum)
147-
@test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, Symbol("##undefined##"), :j, 8, -1)
147+
@test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, Symbol("##undefined##"), :j, 4, -1)
148148
# if LoopVectorization.REGISTER_COUNT == 32
149149
# @test LoopVectorization.choose_order(lsvar) == (Symbol[:j,:i], :j, :i, :j, 2, 10)
150150
# elseif LoopVectorization.REGISTER_COUNT == 16

0 commit comments

Comments
 (0)