Skip to content

Commit 270b27e

Browse files
committed
Minor progress. Figuring out how to handle unrolling, and will also consider tiling.
1 parent 57c1b4d commit 270b27e

File tree

4 files changed

+166
-30
lines changed

4 files changed

+166
-30
lines changed

src/LoopVectorization.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ const SLEEFPiratesDict = Dict{Symbol,Tuple{Symbol,Symbol}}(
1515
:cos => (:SLEEFPirates, :cos_fast),
1616
:cospi => (:SLEEFPirates, :cospi),
1717
:tan => (:SLEEFPirates, :tan_fast),
18-
:log => (:SLEEFPirates, :log_fast),
18+
# :log => (:SLEEFPirates, :log_fast),
19+
:log => (:SIMDPirates, :vlog),
1920
:log10 => (:SLEEFPirates, :log10),
2021
:log2 => (:SLEEFPirates, :log2),
2122
:log1p => (:SLEEFPirates, :log1p),
22-
:exp => (:SLEEFPirates, :exp),
23+
# :exp => (:SLEEFPirates, :exp),
24+
:exp => (:SIMDPirates, :vexp),
2325
:exp2 => (:SLEEFPirates, :exp2),
2426
:exp10 => (:SLEEFPirates, :exp10),
2527
:expm1 => (:SLEEFPirates, :expm1),
@@ -53,7 +55,8 @@ const SLEEFPiratesDict = Dict{Symbol,Tuple{Symbol,Symbol}}(
5355
:mod => (:SLEEFPirates, :mod),
5456
# :copysign => :copysign
5557
:one => (:SIMDPirates, :vone),
56-
:zero => (:SIMDPirates, :vzero)
58+
:zero => (:SIMDPirates, :vzero),
59+
:erf => (:SIMDPirates, :verf)
5760
)
5861

5962

src/contract_pass.jl

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
2+
3+
4+
# Fallback for anything that is not an `Expr` (typically a bare symbol):
# there is nothing to contract, so the input is handed back unchanged.
function contract_pass(x)
    return x
end
5+
# `true` when `f` names addition, plain or the `Base.FastMath` variant.
_contract_isadd(f) = f === :(+) || f == :(Base.FastMath.add_fast)
# `true` when `f` names subtraction/negation, plain or the `Base.FastMath` variant.
_contract_issub(f) = f === :(-) || f == :(Base.FastMath.sub_fast)
# `true` when `f` names multiplication, plain or the `Base.FastMath` variant.
_contract_ismul(f) = f === :(*) || f == :(Base.FastMath.mul_fast)

# Split a unary negation `-x` into `(true, x)`; anything else yields `(false, ex)`.
# The arity check keeps binary subtraction `p - q` from being mistaken for a negation.
function _contract_split_negation(ex)
    if ex isa Expr && ex.head === :call && length(ex.args) == 2 &&
       _contract_issub(first(ex.args))
        return true, ex.args[2]
    end
    return false, ex
end

# Rewrite the updating assignment `lhs op= rhs` as `lhs = lhs op rhs`.
# Any other non-call expression is returned untouched.
function _contract_expand_update(ex::Expr)
    op = ex.head === :(+=) ? :(+) :
         ex.head === :(-=) ? :(-) :
         ex.head === :(*=) ? :(*) :
         ex.head === :(/=) ? :(/) : nothing
    op === nothing && return ex
    return Expr(:(=), first(ex.args), Expr(:call, op, ex.args...))
end

"""
    contract_pass(expr::Expr)::Expr

Walk `expr` with `prewalk` and contract multiply-add patterns into fused calls.

Two rewrites are applied to every node:

* `lhs += rhs`, `-=`, `*=` and `/=` are expanded to `lhs = lhs op rhs`, so the
  resulting binary call can be contracted when the walk descends into it.
* A two-argument `+`/`-` call (or its `Base.FastMath` equivalent) with a
  two-argument product as one operand is replaced by a fused call:

  | pattern        | emitted call        |
  |:-------------- |:------------------- |
  | `a*b + c`      | `vmuladd(a, b, c)`  |
  | `(-a)*b + c`   | `vfnmadd(a, b, c)`  |
  | `a*b - c`      | `vfmsub(a, b, c)`   |
  | `(-a)*b - c`   | `vfnmsub(a, b, c)`  |
  | `c - a*b`      | `vfnmadd(a, b, c)`  |
  | `c - (-a)*b`   | `vmuladd(a, b, c)`  |

  `c + a*b` is treated like `a*b + c`.

The table assumes the usual FMA naming (`vfnmadd(a, b, c) == c - a*b`,
`vfmsub(a, b, c) == a*b - c`, `vfnmsub(a, b, c) == -(a*b) - c`) -- TODO confirm
against the SIMDPirates definitions. The emitted names match the keys used in
the `COST` table.

Every other expression is returned as is.
"""
function contract_pass(expr::Expr)::Expr
    prewalk(expr) do ex
        ex isa Expr || return ex
        ex.head === :call || return _contract_expand_update(ex)

        # Subtraction is not commutative, so remember which operand is the product.
        # Checking `g` inside each condition lets the second pattern be tried when
        # the first one matches structurally but its inner call is not a product.
        if @capture(ex, f_(g_(a_, b_), c_)) && _contract_ismul(g)
            product_first = true
        elseif @capture(ex, f_(c_, g_(a_, b_))) && _contract_ismul(g)
            product_first = false
        else
            return ex
        end

        # Strip a leading unary minus from `a`; the sign moves into the instruction.
        negated, a = _contract_split_negation(a)

        if _contract_isadd(f)
            # ±(a*b) + c, in either operand order
            return Expr(:call, negated ? :vfnmadd : :vmuladd, a, b, c)
        elseif _contract_issub(f)
            if product_first
                # ±(a*b) - c
                return Expr(:call, negated ? :vfnmsub : :vfmsub, a, b, c)
            else
                # c ∓ a*b
                return Expr(:call, negated ? :vmuladd : :vfnmadd, a, b, c)
            end
        end
        return ex
    end
end
50+
51+
52+
using MLStyle
53+
# One step of a tree walk.
# Leaves (anything that is not an `Expr`) are passed directly to `outer`.
walk(x, inner, outer) = outer(x)

# For an `Expr`, `inner` is mapped over the arguments, the node is rebuilt with
# the same head, and `outer` is applied to the rebuilt node.
function walk(x::Expr, inner, outer)
    new_args = map(inner, x.args)
    rebuilt = Expr(x.head, new_args...)
    return outer(rebuilt)
end
55+
56+
"""
57+
postwalk(f, expr)
58+
Applies `f` to each node in the given expression tree, returning the result.
59+
`f` sees expressions *after* they have been transformed by the walk. See also
60+
`prewalk`.
61+
"""
62+
function postwalk(f, x)
    # Children are walked first; `f` then sees the node rebuilt from them.
    descend = child -> postwalk(f, child)
    return walk(x, descend, f)
end
63+
64+
"""
65+
prewalk(f, expr)
66+
Applies `f` to each node in the given expression tree, returning the result.
67+
`f` sees expressions *before* they have been transformed by the walk, and the
68+
walk will be applied to whatever `f` returns.
69+
This makes `prewalk` somewhat prone to infinite loops; you probably want to try
70+
`postwalk` first.
71+
"""
72+
function prewalk(f, x)
    # `f` is applied to the node first; the walk then descends into its result.
    descend = child -> prewalk(f, child)
    return walk(f(x), descend, identity)
end
73+
74+
75+
function contract(expr)
76+
@match expr begin
77+
quote
78+
$a * $b + $c
79+
end
80+
81+
function contract_pass(expr)
82+
prewalk(expr) do ex
83+
84+
end
85+
@match expr begin
86+
quote
87+
struct $name{$tvar}
88+
$f1 :: $t1
89+
$f2 :: $t2
90+
end
91+
end =>
92+
quote
93+
struct $name{$tvar}
94+
$f1 :: $t2
95+
$f2 :: $t1
96+
end
97+
end |> rmlines
98+
end
99+
100+
101+
end
102+
103+

src/costs.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ struct InstructionCost
1212
scaling::Float64 # sentinel values: -3 == no scaling; -2 == offset_scaling, -1 == linear scaling, >0 -> == latency == reciprical throughput
1313
register_pressure::Int
1414
end
15-
InstructionCost(sl, srt, scaling = -3.0) = InstructionCost(sl, srt, scaling, 0)
15+
# Convenience constructor: `scaling` defaults to the -3.0 "no scaling" sentinel
# and the register pressure is fixed at 1.
function InstructionCost(sl, srt, scaling = -3.0)
    return InstructionCost(sl, srt, scaling, 1)
end
1616

1717
function scalar_cost(instruction::InstructionCost)#, ::Type{T} = Float64) where {T}
1818
instruction.scalar_latency, instruction.scalar_reciprical_throughput
@@ -42,7 +42,7 @@ function cost(instruction::InstructionCost, Wshift, ::Type{T}) where {T}
4242
end
4343

4444
# Just a semi-reasonable assumption; should not be that sensitive to anything other than loads
45-
const OPAQUE_INSTRUCTION = InstructionSet(50.0, 50.0, -1.0, 32)
45+
# Built with `InstructionCost`, the cost struct of this file: latency 50, reciprocal
# throughput 50, the -1 "linear scaling" sentinel, and a register pressure equal to
# `VectorizationBase.REGISTER_COUNT`.
const OPAQUE_INSTRUCTION = InstructionCost(50.0, 50.0, -1.0, VectorizationBase.REGISTER_COUNT)
4646

4747
const COST = Dict{Symbol,InstructionCost}(
4848
:getindex => InstructionCost(3,0.5),
@@ -59,14 +59,21 @@ const COST = Dict{Symbol,InstructionCost}(
5959
:< => InstructionCost(1, 0.5),
6060
:>= => InstructionCost(1, 0.5),
6161
:<= => InstructionCost(1, 0.5),
62-
:inv => InstructionCost(13,4.0,-2.0,1),
63-
:muladd => InstructionCost(0.5,4), # + and * will fuse into this, so much of the time they're not twice as expensive
62+
:inv => InstructionCost(13,4.0,-2.0,2),
63+
:muladd => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
64+
:fma => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
65+
:vmuladd => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
66+
:vfma => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
67+
:vfmadd => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
68+
:vfmsub => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
69+
:vfnmadd => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
70+
:vfnmsub => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
6471
:sqrt => InstructionCost(15,4.0,-2.0),
65-
:log => InstructionCost(20,20.0,40.0,20),
66-
:exp => InstructionCost(20,20.0,20.0,18),
67-
:sin => InstructionCost(18,15.0,68.0,23),
68-
:cos => InstructionCost(18,15.0,68.0,26),
69-
:sincos => InstructionCost(25,22.0,70.0,26)
72+
:log => InstructionCost(20,20.0,40.0,21),
73+
:exp => InstructionCost(20,20.0,20.0,19),
74+
:sin => InstructionCost(18,15.0,68.0,24),
75+
:cos => InstructionCost(18,15.0,68.0,27),
76+
:sincos => InstructionCost(25,22.0,70.0,27)
7077
)
7178

7279
function sum_simd(x)

src/graphs.jl

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
11
# using LightGraphs
22

33

4+
# Predicate on array *types*: `true` for subtypes of `DenseArray`.
isdense(::Type{<:DenseArray}) = true
# Every other type is reported as not dense instead of raising a `MethodError`.
isdense(::Type) = false
5+
6+
# Tag stored in `Operation.node_type`; the values are numbered 0, 1, 2 in this order.
@enum NodeType memload memstore reduction
11+
12+
13+
"""
    Operation

Immutable record with three fields: an output type, an instruction name and a
`NodeType` tag. The tag is what `isload`, `isstore` and `isreduction` compare
against.
"""
struct Operation
    outtype::DataType       # output type of the operation
    instruction::Symbol     # name of the instruction
    node_type::NodeType     # one of `memload`, `memstore`, `reduction`
end
18+
19+
# True when `op` carries the `reduction` node type.
function isreduction(op::Operation)
    return op.node_type === reduction
end
20+
# True when `op` carries the `memload` node type.
function isload(op::Operation)
    return op.node_type === memload
end
21+
# True when `op` carries the `memstore` node type.
function isstore(op::Operation)
    return op.node_type === memstore
end
22+
# An operation touches memory when it is either a load or a store.
function accesses_memory(op::Operation)
    return isload(op) || isstore(op)
end
23+
# The element type of an operation is its stored `outtype`.
# The field is read from the `var` parameter; no other name is in scope here.
Base.eltype(var::Operation) = var.outtype
24+
425
"""
526
ShortVector{T} simply wraps a Vector{T}, but uses a different hash function that is faster for short vectors to support using it as the keys of a Dict.
627
This hash function scales O(N) with length of the vectors, so it is slow for long vectors.
@@ -22,11 +43,6 @@ function Base.hash(x::ShortVector, h::UInt)
2243
h
2344
end
2445

25-
@enum NodeType begin
26-
input
27-
store
28-
reduction
29-
end
3046

3147
struct Node
3248
type::DataType
@@ -43,27 +59,27 @@ end
4359
function variables(ls::LoopSet)
4460

4561
end
46-
function loopdependencies(var::Variable)
62+
function loopdependencies(var::Operation)
4763

4864
end
49-
function sym(var::Variable)
65+
function sym(var::Operation)
5066

5167
end
52-
function instruction(var::Variable)
68+
function instruction(var::Operation)
5369

5470
end
55-
function accesses_memory(var::Variable)
71+
# Report whether `var` is a memory load or a memory store.
# This method has the same signature as the one-line `accesses_memory` defined
# next to `isload`/`isstore`, so it replaces that method. An empty body would
# make the predicate return `nothing` and break `if accesses_memory(var)` in `cost`.
function accesses_memory(var::Operation)
    return isload(var) || isstore(var)
end
58-
function stride(var::Variable, sym::Symbol)
74+
function stride(var::Operation, sym::Symbol)
5975

6076
end
61-
function cost(var::Variable, unrolled::Symbol, dim::Int)
77+
function cost(var::Operation, unrolled::Symbol, dim::Int)
6278
c = cost(instruction(var), Wshift, T)::Int
6379
if accesses_memory(var)
6480
# either vbroadcast/reductionstore, vmov(a/u)pd, or gather/scatter
6581
if (unrolled loopdependencies(var))
66-
if (stride(var, unrolled) != 1) # gather/scatter
82+
if (stride(var, unrolled) != 1) || !isdense(var) # need gather/scatter
6783
c *= W
6884
# else # vmov(a/u)pd
6985
end
@@ -73,22 +89,24 @@ function cost(var::Variable, unrolled::Symbol, dim::Int)
7389
end
7490
c
7591
end
76-
function Base.eltype(var::Variable)
77-
Base._return_type()
78-
end
92+
93+
# Base._return_type()
94+
7995
function biggest_type(ls::LoopSet)
8096

8197
end
8298

99+
100+
83101
# evaluates cost of evaluating loop in given order
84-
function evaluate_cost(
85-
ls::LoopSet, order::ShortVector{Symbol}, max_cost = typemax(Int)
102+
function evaluate_cost_unroll(
103+
ls::LoopSet, order::ShortVector{Symbol}, unrolled::Symbol, max_cost = typemax(Int)
86104
)
87105
included_vars = Set{Symbol}()
88106
nested_loop_syms = Set{Symbol}()
89107
total_cost = 0.0
90108
iter = 1.0
91-
unrolled = last(order)
109+
# Need to check if fusion is possible
92110
W, Wshift = VectorizationBase.pick_vector_width_shift(length(ls, unrolled), biggest_type(ls))::Tuple{Int,Int}
93111

94112
fused_with_previous = fill(false, length(order))
@@ -118,6 +136,11 @@ function evaluate_cost(
118136
end
119137
end
120138
end
139+
function evaluate_cost_tile(
140+
ls::LoopSet, order::ShortVector{Symbol}, tiler, tilec, max_cost = typemax(Int)
141+
)
142+
143+
end
121144

122145
struct LoopOrders
123146
syms::Vector{Symbol}

0 commit comments

Comments
 (0)