Skip to content

Commit 8aec893

Browse files
committed
Make a few cost updates
1 parent cd1c612 commit 8aec893

File tree

4 files changed

+142
-21
lines changed

4 files changed

+142
-21
lines changed

src/LoopVectorization.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ using Base.Meta: isexpr
3232
using DocStringExtensions
3333
import LinearAlgebra # for check_args
3434

35-
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast, inv_fast, abs2_fast, rem_fast, max_fast, min_fast
35+
using Base.FastMath: add_fast, sub_fast, mul_fast, div_fast, inv_fast, abs2_fast, rem_fast, max_fast, min_fast, log_fast, log2_fast, log10_fast
3636

3737

3838
using ArrayInterface

src/modeling/costs.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,9 @@ function vector_cost(ic::InstructionCost, Wshift, sizeof_T)
8383
extra_latency = sl - srt
8484
srt *= W
8585
sl = round(Int, srt + extra_latency)
86-
else # we assume custom cost, and that latency == recip_throughput
87-
scaling = ic.scaling
88-
sl, srt = round(Int,scaling), scaling
86+
# else # we assume custom cost, and that latency == recip_throughput
87+
# scaling = ic.scaling
88+
# sl, srt = round(Int,scaling), scaling
8989
end
9090
srt, sl, srp
9191
end
@@ -224,28 +224,32 @@ const COST = Dict{Symbol,InstructionCost}(
224224
# :vdivlog10add! =>InstructionCost(13,4.0,-2.0),
225225
:sqrt => InstructionCost(15,4.0,-2.0),
226226
:sqrt_fast => InstructionCost(15,4.0,-2.0),
227-
:log => InstructionCost(20,20.0,20.0,20),
228-
:log1p => InstructionCost(20,25.0,25.0,20), # FIXME
229-
:exp => InstructionCost(20,20.0,20.0,18),
230-
:expm1 => InstructionCost(20,25.0,25.0,18), # FIXME
231-
:(^) => InstructionCost(40,40.0,40.0,26), # FIXME
232-
:sin => InstructionCost(18,15.0,68.0,23),
233-
:cos => InstructionCost(18,15.0,68.0,26),
234-
:sincos => InstructionCost(25,22.0,70.0,26),
227+
:log => InstructionCost(-3.0, 15, 30, 11),
228+
:log2 => InstructionCost(-3.0, 15, 30, 11),
229+
:log10 => InstructionCost(-3.0, 15, 30, 11),
230+
:log1p => InstructionCost(-3.0, 15, 30, 11),
231+
:exp => InstructionCost(-3.0,13.0,26.0,14),
232+
:exp2 => InstructionCost(-3.0,10.0,40.0,14),
233+
:exp10 => InstructionCost(-3.0,13.0,26.0,14),
234+
:expm1 => InstructionCost(-3.0,30.0,60.0,19),
235+
:(^) => InstructionCost(-3.0,200.0,400.0,26), # FIXME
236+
:sin => InstructionCost(-3,30.0,60.0,23),
237+
:cos => InstructionCost(-3,27.0,60.0,26),
238+
:sincos => InstructionCost(-3,37.0,85.0,26),
235239
:sinpi => InstructionCost(18,15.0,68.0,23),
236240
:cospi => InstructionCost(18,15.0,68.0,26),
237-
:sincospi => InstructionCost(25,22.0,70.0,26),
241+
:sincospi => InstructionCost(25,37.0,70.0,26),
238242
:log_fast => InstructionCost(20,20.0,40.0,20),
239243
:exp_fast => InstructionCost(20,20.0,20.0,18),
240244
:sin_fast => InstructionCost(18,15.0,68.0,23),
241245
:cos_fast => InstructionCost(18,15.0,68.0,26),
242-
:sincos_fast => InstructionCost(25,22.0,70.0,26),
246+
:sincos => InstructionCost(-3,37.0,85.0,26),
243247
:sinpi_fast => InstructionCost(18,15.0,68.0,23),
244248
:cospi_fast => InstructionCost(18,15.0,68.0,26),
245249
:sincospi_fast => InstructionCost(25,22.0,70.0,26),
246-
:tanh => InstructionCost(40,40.0,40.0,26), # FIXME
247-
:tanh_fast => InstructionCost(25,22.0,70.0,26), # FIXME
248-
:sigmoid_fast => InstructionCost(25,22.0,70.0,26), # FIXME
250+
:tanh => InstructionCost(-3.0,80.0,160.0,26), # FIXME
251+
:tanh_fast => InstructionCost(-3.0,30.0,60.0,20), # FIXME
252+
:sigmoid_fast => InstructionCost(-3.0,16.0,66.0,15), # FIXME
249253
:identity => InstructionCost(0,0.0,0.0,0),
250254
:adjoint => InstructionCost(0,0.0,0.0,0),
251255
:conj => InstructionCost(0,0.0,0.0,0),
@@ -548,11 +552,15 @@ const FUNCTIONSYMBOLS = IdDict{Type{<:Function},Instruction}(
548552
typeof(Base.FastMath.sqrt_fast) => :sqrt,
549553
# typeof(VectorizationBase.vsqrt) => :sqrt,
550554
typeof(log) => :log,
555+
typeof(log2) => :log2,
556+
typeof(log10) => :log10,
551557
typeof(Base.FastMath.log_fast) => :log,
552558
typeof(log1p) => :log1p,
553559
# typeof(VectorizationBase.vlog) => :log,
554560
typeof(SLEEFPirates.log) => :log,
555561
typeof(exp) => :exp,
562+
typeof(exp2) => :exp2,
563+
typeof(exp10) => :exp10,
556564
typeof(Base.FastMath.exp_fast) => :exp,
557565
typeof(expm1) => :expm1,
558566
# typeof(VectorizationBase.vexp) => :exp,

src/modeling/determinestrategy.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,28 +254,38 @@ function unroll_no_reductions(ls, order, vloopsym)
254254
unrolled = order[end-1]
255255
end
256256
# latency not a concern, because no depchains
257+
compute_l = 0.0
258+
# rp = 0
257259
for op operations(ls)
258260
dependson(op, unrolled) || continue
261+
rt, sl, rpop = cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T)
262+
# rp += rpop
259263
if iscompute(op)
260-
compute_rt += first(cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T))
264+
compute_rt += rt
265+
compute_l += sl
261266
elseif isload(op)
262-
load_rt += first(cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T))
267+
load_rt += rt
263268
elseif isstore(op)
264-
store_rt += first(cost(ls, op, (unrolled,Symbol("")), vloopsym, Wshift, size_T))
269+
store_rt += rt
265270
end
266271
end
267272
# heuristic guess
268273
# roundpow2(min(4, round(Int, (compute_rt + load_rt + 1) / compute_rt)))
269274
memory_rt = load_rt + store_rt
275+
@show memory_rt, load_rt, store_rt, compute_rt, compute_l
276+
270277
u = if compute_rt > memory_rt
271-
max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / compute_rt) ) ))
278+
@show clamp(round(Int, compute_l / compute_rt), 1, 4)
279+
# max(1, VectorizationBase.nextpow2( min( 4, round(Int, 8 / compute_rt) ) ))
272280
elseif iszero(compute_rt)
273281
4
274282
elseif iszero(load_rt)
275283
iszero(store_rt) ? 4 : max(1, min(4, round(Int, 2compute_rt / store_rt)))
276284
else
277285
max(1, min(4, round(Int, 2compute_rt / load_rt)))
278286
end
287+
# u = min(u, max(1, (reg_count(ls) ÷ max(1,round(Int,rp)))))
288+
# @show u
279289
# commented out here is to decide to align loops
280290
# if memory_rt > compute_rt && isone(u) && (length(order) > 1) && (last(order) === vloopsym) && length(getloop(ls, last(order))) > 8W
281291
# ls.align_loops[] = findfirst(operations(ls)) do op

utils/generate_costs.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
using VectorizationBase, LoopVectorization
2+
using VectorizationBase: data
3+
4+
# @generated to use VectorizationBase's API for supporting 1.5 and 1.6+
5+
"""
    readcyclecounter()

Return the result of calling the `llvm.readcyclecounter` intrinsic, as an `Int64`.

The method body is not written by hand: it is the expression returned by
`VectorizationBase.llvmcall_expr` for the declaration and instruction strings below,
which is why this is a `@generated` function.
"""
@generated function readcyclecounter()
    # LLVM declaration of the intrinsic, followed by the IR that calls it and
    # returns the 64-bit counter value.
    declaration = "declare i64 @llvm.readcyclecounter()"
    ir_body = "%res = call i64 @llvm.readcyclecounter()\nret i64 %res"
    # No arguments are passed to the intrinsic, hence the empty `Tuple{}` and the
    # empty argument-type / argument-name vectors.
    return VectorizationBase.llvmcall_expr(
        declaration, ir_body, :Int64, :(Tuple{}), "i64", String[], Symbol[]
    )
end
10+
11+
"""
    volatile(x::Vec{W,T})

Pass `x` through an empty inline-assembly statement flagged `sideeffect` and return the
value it yields, wrapped in a `Vec` again.

The assembly template is the empty string; the `"=v,v"` constraint string ties one
output and one input operand of LLVM type `<W x T>` to the statement.
NOTE(review): this looks intended as an optimisation barrier so that the benchmark
loop in this file cannot have its work elided or hoisted — confirm.
"""
@generated function volatile(x::Vec{W,T}) where {W,T}
    # LLVM spelling of the element type, then of the whole `W`-wide vector type.
    typ = VectorizationBase.LLVM_TYPES[T]
    vtyp = "<$W x $typ>"

    sideeffect_str = """%res = call $vtyp asm sideeffect "", "=v,v"($vtyp %0)
    ret $vtyp %res"""
    # Julia-side storage type matching the LLVM vector type above; used both as the
    # return type and as the single argument type of the `llvmcall`.
    storage = :(NTuple{$W,Core.VecElement{$T}})
    return quote
        $(Expr(:meta, :inline))
        Vec(Base.llvmcall($sideeffect_str, $storage, Tuple{$storage}, VectorizationBase.data(x)))
    end
end
23+
# `VecUnroll` method: map `volatile` over the wrapped data with
# `VectorizationBase.fmap`, then rebuild the `VecUnroll` from the result.
@inline function volatile(x::VecUnroll)
    return VecUnroll(VectorizationBase.fmap(volatile, data(x)))
end
# `Tuple` method: apply `volatile` to every element of the tuple.
@inline function volatile(x::Tuple)
    return map(volatile, x)
end
25+
# @generated function volatile(x::Vec{W,T}, x::Vec{W,T}) where {W,T}
26+
# typ = VectorizationBase.LLVM_TYPES[T]
27+
# vtyp = "<$W x $typ>"
28+
29+
# suffix = T == Float32 ? "ps" : "pd"
30+
# sideeffect_str = """%res = call <$W x $(typ)> asm sideeffect "", "=v,v"(<$W x $(typ)> %0, <$W x $(typ)> %1)
31+
# ret <$W x $(typ)> %res"""
32+
# quote
33+
# $(Expr(:meta, :inline))
34+
# Vec(Base.llvmcall($sideeffect_str, NTuple{$W,Core.VecElement{$T}}, Tuple{NTuple{$W,Core.VecElement{$T}}}, VectorizationBase.data(x)))
35+
# end
36+
# end
37+
38+
"""
    num_vectors(x)

Return the vector count used to normalise timings: `N + 1` for a `VecUnroll{N}` and
`1` for a plain `Vec`.
"""
function num_vectors(::VecUnroll{N}) where {N}
    return N + 1
end
function num_vectors(::Vec)
    return 1
end
40+
"""
    unrolltest(f, vs...; num_iter = 4_194_304)

Estimate the cost of `f(vs...)` in cycle-counter ticks per vector.

`f` is called `num_iter` times. On every iteration each argument, and the result, is
passed through `volatile`. The difference between the `readcyclecounter()` readings
taken before and after the loop is divided by `num_vectors(first(vs)) * num_iter`.

# Arguments
- `f`: the function to time.
- `vs...`: the arguments passed to `f`; at least one is required, because the vector
  count is taken from the first one.

# Keywords
- `num_iter`: number of timed calls. Defaults to `4_194_304`; a smaller value such as
  `1_048_576` gives a shorter run.

# Returns
The elapsed counter ticks divided by the total number of vectors processed.
"""
function unrolltest(f::F, vs::Vararg{Any,K}; num_iter::Integer = 4_194_304) where {F,K}
    cc = readcyclecounter()
    for _ ∈ 1:num_iter
        # `volatile` on both the inputs and the output of every call.
        volatile(f(map(volatile, vs)...))
    end
    cycles = readcyclecounter() - cc
    return cycles / (num_vectors(first(vs)) * num_iter)
end
50+
51+
# @generated function vapply!(f::F, y, x, ::Val{U}) where {F,U}
52+
# quote
53+
# @avx unroll=$U for j ∈ 1:1024
54+
# y[j] = f(x[j])
55+
# end
56+
# end
57+
# end
58+
59+
# vector_init(::Val{N}, ::Type{T}) where {N,T} = VectorizationBase.zero_vecunroll(StaticInt(N), pick_vector_width(T), T, VectorizationBase.register_size())
60+
# vector_init(::Val{1}, ::Type{T}) where {T} = VectorizationBase.vzero(pick_vector_width(T), T)
61+
62+
# @generated function unrolltest(f::F, x::AbstractVector{T}, ::Val{U}) where {F,U,T}
63+
# quote
64+
# cc = readcyclecounter()
65+
# for i ∈ 1:8192
66+
# s = vector_init(Val{$U}(), $T)
67+
# @avx unroll=$U for j ∈ 1:512
68+
# s += f(x[j])
69+
# end
70+
# volatile(s)
71+
# end
72+
# cycles = readcyclecounter() - cc
73+
# pick_vector_width(T) * cycles / (512 * 8192)
74+
# end
75+
# end
76+
77+
78+
# @generated function unrolltest!(f::F, y::AbstractVector{T}, x::AbstractVector{T}, ::Val{U}) where {F,U,T}
79+
# quote
80+
# cc = readcyclecounter()
81+
# for i ∈ 1:8192
82+
# @avx unroll=$U for j ∈ 1:512
83+
# y[j] = f(x[j])
84+
# end
85+
# end
86+
# cycles = readcyclecounter() - cc
87+
# pick_vector_width(T) * cycles / (512 * 8192)
88+
# end
89+
# end
90+
91+
# Driver: time a set of unary functions on a single `Vec` and on `VecUnroll`s built
# from 2, 4 and 8 vectors, and report the per-vector figures from `unrolltest`.
let
    # One `Float64` vector of `pick_vector_width(Float64)` lanes filled with `10randn()`.
    randvec() = Vec(ntuple(_ -> 10randn(), pick_vector_width(Float64))...)

    vx = randvec()
    vu2 = VectorizationBase.VecUnroll(ntuple(_ -> randvec(), Val(2)))
    vu4 = VectorizationBase.VecUnroll(ntuple(_ -> randvec(), Val(4)))
    vu8 = VectorizationBase.VecUnroll(ntuple(_ -> randvec(), Val(8)))
    # The loop variable must be `f`: that is the name the body passes to `unrolltest`.
    for f ∈ [log, log2, log10, log1p, exp, exp2, exp10, expm1, sin, cos]
        rt1 = unrolltest(f, vx)
        rt2 = unrolltest(f, vu2)
        rt4 = unrolltest(f, vu4)
        rt8 = unrolltest(f, vu8)
        # Report the measurements; without this they are computed and then dropped.
        @info "cycles per vector" f rt1 rt2 rt4 rt8
    end
end
103+

0 commit comments

Comments
 (0)