Skip to content

Commit bc6a887

Browse files
committed
Minor progress.
1 parent fb39997 commit bc6a887

File tree

3 files changed

+138
-74
lines changed

3 files changed

+138
-74
lines changed

src/costs.jl

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,14 @@ InstructionCost(sl, srt, scaling = -3.0) = InstructionCost(sl, srt, scaling, 1)
1717
function scalar_cost(instruction::InstructionCost)#, ::Type{T} = Float64) where {T}
1818
instruction.scalar_latency, instruction.scalar_reciprical_throughput
1919
end
20-
function vector_cost(instruction::InstructionCost, Wshift, ::Type{T} = Float64) where {T}
20+
function vector_cost(instruction::InstructionCost, Wshift, sizeof_T)
2121
sl, srt = scalar_cost(instruction)
2222
scaling = instruction.scaling
2323
if scaling == -3.0 || Wshift == 0
2424
return sl, srt
2525
elseif scaling == -2.0
26-
srt *= 1 << (Wshift + VectorizationBase.intlog2(sizeof(T)) - 4)
27-
if (sizeof(T) << Wshift) == VectorizationBase.REGISTER_SIZE # These instructions experience double latency with zmm
26+
srt *= 1 << (Wshift + VectorizationBase.intlog2(sizeof_T) - 4)
27+
if (sizeof_T << Wshift) == 64 # VectorizationBase.REGISTER_SIZE # These instructions experience double latency with zmm
2828
sl += sl
2929
end
3030
elseif scaling == -1.0
@@ -37,28 +37,35 @@ function vector_cost(instruction::InstructionCost, Wshift, ::Type{T} = Float64)
3737
end
3838
sl, srt
3939
end
40-
function cost(instruction::InstructionCost, Wshift, ::Type{T}) where {T}
41-
Wshift == 0 ? scalar_cost(instruction) : vector_cost(instruction, Wshift, T)
40+
function cost(instruction::InstructionCost, Wshift, sizeof_T)
41+
Wshift == 0 ? scalar_cost(instruction) : vector_cost(instruction, Wshift, sizeof_T)
42+
end
43+
44+
function cost(instruction::Symbol, Wshift, sizeof_T)
45+
cost(
46+
get(COST, instruction, OPAQUE_INSTRUCTION),
47+
Wshift, sizeof_T
48+
)
4249
end
4350

4451
# Just a semi-reasonable assumption; should not be that sensitive to anything other than loads
45-
const OPAQUE_INSTRUCTION = InstructionSet(50.0, 50.0, -1.0, VectorizationBase.REGISTER_COUNT)
52+
const OPAQUE_INSTRUCTION = InstructionCost(50, 50.0, -1.0, VectorizationBase.REGISTER_COUNT)
4653

4754
const COST = Dict{Symbol,InstructionCost}(
4855
:getindex => InstructionCost(3,0.5),
4956
:setindex! => InstructionCost(3,1.0), # but not a part of dependency chains, so not really twice as expensive?
50-
:+ => InstructionCost(4,0.5),
51-
:- => InstructionCost(4,0.5),
52-
:* => InstructionCost(4,0.5),
53-
:/ => InstructionCost(13,4.0,-2.0),
54-
:== => InstructionCost(1, 0.5),
57+
:(+) => InstructionCost(4,0.5),
58+
:(-) => InstructionCost(4,0.5),
59+
:(*) => InstructionCost(4,0.5),
60+
:(/) => InstructionCost(13,4.0,-2.0),
61+
:(==) => InstructionCost(1, 0.5),
5562
:isequal => InstructionCost(1, 0.5),
56-
:& => InstructionCost(1, 0.5),
57-
:| => InstructionCost(1, 0.5),
58-
:> => InstructionCost(1, 0.5),
59-
:< => InstructionCost(1, 0.5),
60-
:>= => InstructionCost(1, 0.5),
61-
:<= => InstructionCost(1, 0.5),
63+
:(&) => InstructionCost(1, 0.5),
64+
:(|) => InstructionCost(1, 0.5),
65+
:(>) => InstructionCost(1, 0.5),
66+
:(<) => InstructionCost(1, 0.5),
67+
:(>=) => InstructionCost(1, 0.5),
68+
:(<=) => InstructionCost(1, 0.5),
6269
:inv => InstructionCost(13,4.0,-2.0,2),
6370
:muladd => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
6471
:fma => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
@@ -76,26 +83,6 @@ const COST = Dict{Symbol,InstructionCost}(
7683
:sincos => InstructionCost(25,22.0,70.0,27)
7784
)
7885

79-
function sum_simd(x)
80-
s = zero(eltype(x))
81-
@simd for xᵢ ∈ x
82-
s += xᵢ
83-
end
84-
s
85-
end
86-
using LoopVectorization, BenchmarkTools
87-
function sum_loopvec(x::AbstractVector{Float64})
88-
s = 0.0
89-
@vvectorize 4 for i ∈ eachindex(x)
90-
s += x[i]
91-
end
92-
s
93-
end
94-
x = rand(111);
95-
@btime sum($x)
96-
@btime sum_simd($x)
97-
@btime sum_loopvec($x)
98-
9986

10087
# const SIMDPIRATES_COST = Dict{Symbol,InstructionCost}()
10188
# const SLEEFPIRATES_COST = Dict{Symbol,InstructionCost}()

src/graphs.jl

Lines changed: 95 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,32 @@
33

44
isdense(::Type{<:DenseArray}) = true
55

6+
"""
7+
ShortVector{T} simply wraps a Vector{T}, but uses a different hash function that is faster for short vectors to support using it as the keys of a Dict.
8+
This hash function scales O(N) with length of the vectors, so it is slow for long vectors.
9+
"""
10+
struct ShortVector{T} <: DenseVector{T}
11+
data::Vector{T}
12+
end
13+
Base.@propagate_inbounds Base.getindex(x::ShortVector, I...) = x.data[I...]
14+
Base.@propagate_inbounds Base.setindex!(x::ShortVector, v, I...) = x.data[I...] = v
15+
@inbounds Base.length(x::ShortVector) = length(x.data)
16+
@inbounds Base.size(x::ShortVector) = size(x.data)
17+
@inbounds Base.strides(x::ShortVector) = strides(x.data)
18+
@inbounds Base.push!(x::ShortVector, v) = push!(x.data, v)
19+
@inbounds Base.append!(x::ShortVector, v) = append!(x.data, v)
20+
function Base.hash(x::ShortVector, h::UInt)
21+
@inbounds for n ∈ eachindex(x)
22+
h = hash(x[n], h)
23+
end
24+
h
25+
end
26+
27+
28+
629
@enum NodeType begin
730
memload
831
memstore
9-
reduction
1032
compute
1133
end
1234

@@ -15,61 +37,62 @@ struct Operation
1537
elementbytes::Int
1638
instruction::Symbol
1739
node_type::NodeType
40+
# dependencies::ShortVector{Symbol}
41+
dependencies::Set{Symbol}
42+
# dependencies::Set{Symbol}
1843
parents::Vector{Operation}
1944
children::Vector{Operation}
20-
metadata::Vector{Float64}
45+
numerical_metadata::Vector{Float64}
46+
symbolic_metadata::Vector{Symbol}
2147
function Operation(elementbytes, instruction, node_type)
2248
new(
2349
elementbytes, instruction, node_type,
24-
Operation[], Operation[], Float64[]
50+
Set{Symbol}(), Operation[], Operation[], Float64[], Symbol[]
2551
)
2652
end
2753
end
2854

29-
isreduction(op::Operation) = op.node_type == reduction
55+
function isreduction(op::Operation)
56+
(op.node_type == memstore) && (length(op.symbolic_metadata) < length(op.dependencies)) && issubset(op.symbolic_metadata, op.dependencies)
57+
end
3058
isload(op::Operation) = op.node_type == memload
3159
isstore(op::Operation) = op.node_type == memstore
3260
accesses_memory(op::Operation) = isload(op) | isstore(op)
33-
Base.eltype(var::Operation) = op.outtype
34-
35-
"""
36-
ShortVector{T} simply wraps a Vector{T}, but uses a different hash function that is faster for short vectors to support using it as the keys of a Dict.
37-
This hash function scales O(N) with length of the vectors, so it is slow for long vectors.
38-
"""
39-
struct ShortVector{T} <: DenseVector{T}
40-
data::Vector{T}
41-
end
42-
Base.@propagate_inbounds Base.getindex(x::ShortVector, I...) = x.data[I...]
43-
Base.@propagate_inbounds Base.setindex!(x::ShortVector, v, I...) = x.data[I...] = v
44-
@inbounds Base.length(x::ShortVector) = length(x.data)
45-
@inbounds Base.size(x::ShortVector) = size(x.data)
46-
@inbounds Base.strides(x::ShortVector) = strides(x.data)
47-
@inbounds Base.push!(x::ShortVector, v) = push!(x.data, v)
48-
@inbounds Base.append!(x::ShortVector, v) = append!(x.data, v)
49-
function Base.hash(x::ShortVector, h::UInt)
50-
@inbounds for n ∈ eachindex(x)
51-
h = hash(x[n], h)
52-
end
53-
h
54-
end
61+
elsize(op::Operation) = op.elementbytes
62+
dependson(op::Operation, sym::Symbol) = sym ∈ op.dependencies
5563

5664
function stride(op::Operation, sym::Symbol)
5765
@assert accesses_memory(op) "This operation does not access memory!"
5866
# access stride info?
5967
end
60-
function
68+
# function
6169

6270
struct Node
6371
type::DataType
6472
end
6573

74+
struct Loop
75+
itersymbol::Symbol
76+
rangehint::Int
77+
rangesym::Symbol
78+
hintexact::Bool # if true, rangesym ignored and rangehint used for final lowering
79+
end
80+
function Loop(itersymbol::Symbol, rangehint::Int)
81+
Loop( itersymbol, rangehint, :undef, true )
82+
end
83+
function Loop(itersymbol::Symbol, rangesym::Symbol, rangehint::Int = 1_000_000)
84+
Loop( itersymbol, rangehint, rangesym, false )
85+
end
86+
6687
# Must make it easy to iterate
6788
struct LoopSet
89+
loops::Dict{Symbol,Loop} # sym === loops[sym].itersymbol
90+
operations::Vector{Operation}
6891

6992
end
7093

7194
function Base.length(ls::LoopSet, is::Symbol)
72-
95+
ls.loops[is].rangehint
7396
end
7497
function variables(ls::LoopSet)
7598

@@ -78,7 +101,7 @@ function loopdependencies(var::Operation)
78101

79102
end
80103
function sym(var::Operation)
81-
104+
82105
end
83106
function instruction(var::Operation)
84107

@@ -89,6 +112,7 @@ end
89112
function stride(var::Operation, sym::Symbol)
90113

91114
end
115+
operations(ls::LoopSet) = ls.operations
92116
function cost(var::Operation, unrolled::Symbol, dim::Int)
93117
c = cost(instruction(var), Wshift, T)::Int
94118
if accesses_memory(var)
@@ -108,31 +132,31 @@ end
108132
# Base._return_type()
109133

110134
function biggest_type(ls::LoopSet)
111-
135+
maximum(elsize, ls.operations)
112136
end
113137

114138

115139

116140
# evaluates cost of evaluating loop in given order
117141
function evaluate_cost_unroll(
118-
ls::LoopSet, order::ShortVector{Symbol}, unrolled::Symbol, max_cost = typemax(Int)
142+
ls::LoopSet, order::ShortVector{Symbol}, unrolled::Symbol, max_cost = typemax(Float64)
119143
)
120144
included_vars = Set{Symbol}()
121145
nested_loop_syms = Set{Symbol}()
122146
total_cost = 0.0
123147
iter = 1.0
124148
# Need to check if fusion is possible
125-
# W, Wshift = VectorizationBase.pick_vector_width_shift(length(ls, unrolled), biggest_type(ls))::Tuple{Int,Int}
149+
W, Wshift = VectorizationBase.pick_vector_width_shift(length(ls, unrolled), biggest_type(ls))::Tuple{Int,Int}
126150
for itersym ∈ order
127151
# Add to set of defined symbols
128152
push!(nested_loop_syms, itersym)
129-
liter = length(ls, itersym)
153+
liter = Float64(length(ls, itersym))
130154
if itersym == unrolled
131155
liter /= W
132156
end
133157
iter *= liter
134158
# check which vars we can define at this level of loop nest
135-
for var ∈ variables(ls)
159+
for var ∈ operations(ls)
136160
# won't define if already defined...
137161
sym(var) ∈ included_vars && continue
138162
# it must also be a subset of defined symbols
@@ -141,14 +165,48 @@ function evaluate_cost_unroll(
141165
push!(included_vars, sym(var))
142166

143167
total_cost += iter * cost(var, W, Wshift, unrolled, liter)
144-
total_cost > max_cost && return total_cost # abort
168+
total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
145169
end
146170
end
171+
total_cost
147172
end
148-
function evaluate_cost_tile(
149-
ls::LoopSet, order::ShortVector{Symbol}, tiler, tilec, max_cost = typemax(Int)
173+
174+
# only covers unrolled ops; everything else considered lifted?
175+
function depchain_cost!(
176+
skip::Set{Symbol}, ls::LoopSet, op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int
177+
)
178+
179+
end
180+
181+
function determine_unroll_factor(
182+
ls::LoopSet, order::ShortVector{Symbol}, unrolled::Symbol, Wshift::Int, size_T::Int
150183
)
184+
# The strategy is to use an unroll factor of 1, unless there appear to be loop-carried dependencies (i.e., num_reductions > 0)
185+
# The assumption here is that unrolling provides no real benefit, unless it is needed to enable OOO execution by breaking up these dependency chains
186+
num_reductions = sum(isreduction, operations(ls))
187+
iszero(num_reductions) && return 1
188+
# So if num_reductions > 0, we set the unroll factor to be high enough so that the CPU can be kept busy
189+
# if there are, U = max(1, round(Int, max(latency) * throughput / num_reductions)) = max(1, round(Int, latency / (recip_throughput * num_reductions)))
190+
latency = 0
191+
recip_throughput = 0.0
192+
visited_nodes = Set{Symbol}()
193+
for op ∈ operations(ls)
194+
if isreduction(op) && dependson(op, unrolled)
195+
l, rt = cost_of_chain()
196+
num_reductions += 1
197+
sl, rt = cost(instruction(op), Wshift, size_T)
198+
latency = max(sl, latency)
199+
recip_throughput += rt
200+
end
201+
end
202+
151203

204+
205+
end
206+
function evaluate_cost_tile(
207+
ls::LoopSet, order::ShortVector{Symbol}, tiler, tilec, max_cost = typemax(Float64)
208+
)
209+
152210
end
153211

154212
struct LoopOrders

test/runtests.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
11
using LoopVectorization
22
using Test
33

4+
pkgdir(pkg::String) = abspath(joinpath(dirname(Base.find_package(pkg)), ".."))
5+
using VectorizationBase, SIMDPirates, SLEEFPirates
6+
# includet(joinpath(pkgdir("LoopVectorization"), "src/costs.jl"))
7+
# includet(joinpath(pkgdir("LoopVectorization"), "src/graphs.jl"))
8+
include(joinpath(pkgdir("LoopVectorization"), "src/costs.jl"))
9+
include(joinpath(pkgdir("LoopVectorization"), "src/graphs.jl"))
10+
11+
# loop is gemv!
12+
for c ∈ 1:C
13+
for r ∈ 1:R
14+
y[r] += A[r,c] * x[c]
15+
# translates to
16+
# y[r] = vmuladd(A[r,c], x[c], y[r])
17+
end
18+
end
19+
20+
21+
22+
423
using CpuId, VectorizationBase, SIMDPirates, SLEEFPirates, VectorizedRNG
524

625
@generated function estimate_cost_onearg_serial(f::F, N::Int = 512, K = 1_000, ::Type{T} = Float64, ::Val{U} = Val(4)) where {F,T,U}

0 commit comments

Comments
 (0)