Skip to content

Commit aa61df8

Browse files
committed
WIP; added precompile statements, working on writing unroll-lowering of a loop set; still need tiled lowering, and the creating of loop sets from loop expressions, broadcasting expressions, and tensor notation.
1 parent bc6a887 commit aa61df8

File tree

6 files changed

+358
-140
lines changed

6 files changed

+358
-140
lines changed

Manifest.toml

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
# This file is machine-generated - editing it directly is not advised
22

3-
[[ArnoldiMethod]]
4-
deps = ["DelimitedFiles", "LinearAlgebra", "Random", "SparseArrays", "StaticArrays", "Test"]
5-
git-tree-sha1 = "2b6845cea546604fb4dca4e31414a6a59d39ddcd"
6-
uuid = "ec485272-7323-5ecc-a04f-4719b315124d"
7-
version = "0.0.4"
8-
93
[[Base64]]
104
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
115

@@ -23,9 +17,9 @@ version = "0.2.2"
2317

2418
[[DataStructures]]
2519
deps = ["InteractiveUtils", "OrderedCollections"]
26-
git-tree-sha1 = "1fe8fad5fc84686dcbc674aa255bc867a64f8132"
20+
git-tree-sha1 = "a1b652fb77ae8ca7ea328fa7ba5aa151036e5c10"
2721
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
28-
version = "0.17.5"
22+
version = "0.17.6"
2923

3024
[[Dates]]
3125
deps = ["Printf"]
@@ -39,12 +33,6 @@ uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
3933
deps = ["Random", "Serialization", "Sockets"]
4034
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
4135

42-
[[Inflate]]
43-
deps = ["Pkg", "Printf", "Random", "Test"]
44-
git-tree-sha1 = "b7ec91c153cf8bff9aff58b39497925d133ef7fd"
45-
uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9"
46-
version = "0.1.1"
47-
4836
[[InteractiveUtils]]
4937
deps = ["Markdown"]
5038
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
@@ -56,12 +44,6 @@ uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
5644
[[Libdl]]
5745
uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
5846

59-
[[LightGraphs]]
60-
deps = ["ArnoldiMethod", "DataStructures", "Distributed", "Inflate", "LinearAlgebra", "Random", "SharedArrays", "SimpleTraits", "SparseArrays", "Statistics"]
61-
git-tree-sha1 = "a0d4bcea4b9c056da143a5ded3c2b7f7740c2d41"
62-
uuid = "093fc24a-ae57-5d10-9952-331d41423f4d"
63-
version = "1.3.0"
64-
6547
[[LinearAlgebra]]
6648
deps = ["Libdl"]
6749
uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -88,6 +70,12 @@ git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
8870
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
8971
version = "1.1.0"
9072

73+
[[Parameters]]
74+
deps = ["OrderedCollections"]
75+
git-tree-sha1 = "b62b2558efb1eef1fa44e4be5ff58a515c287e38"
76+
uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a"
77+
version = "0.12.0"
78+
9179
[[Pkg]]
9280
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
9381
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
@@ -126,25 +114,13 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
126114
deps = ["Distributed", "Mmap", "Random", "Serialization"]
127115
uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
128116

129-
[[SimpleTraits]]
130-
deps = ["InteractiveUtils", "MacroTools"]
131-
git-tree-sha1 = "2bdf3b6300a9d66fe29ee8bb51ba100c4df9ecbc"
132-
uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
133-
version = "0.9.1"
134-
135117
[[Sockets]]
136118
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
137119

138120
[[SparseArrays]]
139121
deps = ["LinearAlgebra", "Random"]
140122
uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
141123

142-
[[StaticArrays]]
143-
deps = ["LinearAlgebra", "Random", "Statistics"]
144-
git-tree-sha1 = "5a3bcb6233adabde68ebc97be66e95dcb787424c"
145-
uuid = "90137ffa-7385-5640-81b9-e52037218182"
146-
version = "0.12.1"
147-
148124
[[Statistics]]
149125
deps = ["LinearAlgebra", "SparseArrays"]
150126
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ authors = ["Chris Elrod <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7-
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
87
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
8+
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
99
SIMDPirates = "21efa798-c60a-11e8-04d3-e1a92915a26a"
1010
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
1111
VectorizationBase = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f"

src/LoopVectorization.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module LoopVectorization
22

3-
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools
3+
using VectorizationBase, SIMDPirates, SLEEFPirates, MacroTools, Parameters
44
using VectorizationBase: REGISTER_SIZE, extract_data, num_vector_load_expr
55
using SIMDPirates: VECTOR_SYMBOLS, evadd, evmul
66
using MacroTools: @capture, prewalk, postwalk
@@ -693,4 +693,8 @@ for vec ∈ (false,true)
693693
end
694694
end
695695

696+
include("precompile.jl")
697+
_precompile_()
698+
699+
696700
end # module

src/costs.jl

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,35 +7,36 @@
77
# end
88

99
struct InstructionCost
10-
scalar_latency::Int
11-
scalar_reciprical_throughput::Float64
1210
scaling::Float64 # sentinel values: -3 == no scaling; -2 == offset_scaling, -1 == linear scaling, >0 -> == latency == reciprical throughput
11+
scalar_reciprical_throughput::Float64
12+
scalar_latency::Int
1313
register_pressure::Int
1414
end
15-
InstructionCost(sl, srt, scaling = -3.0) = InstructionCost(sl, srt, scaling, 1)
15+
InstructionCost(sl::Int, srt::Float64, scaling::Float64 = -3.0) = InstructionCost(scaling, srt, sl, srt, 0)
1616

1717
function scalar_cost(instruction::InstructionCost)#, ::Type{T} = Float64) where {T}
18-
instruction.scalar_latency, instruction.scalar_reciprical_throughput
18+
@unpack scalar_reciprical_throughput, scalar_latency, register_pressure = instruction
19+
scalar_reciprical_throughput, scalar_latency, register_pressure
1920
end
2021
function vector_cost(instruction::InstructionCost, Wshift, sizeof_T)
21-
sl, srt = scalar_cost(instruction)
22+
srt, sl, srp = scalar_cost(instruction)
2223
scaling = instruction.scaling
23-
if scaling == -3.0 || Wshift == 0
24-
return sl, srt
25-
elseif scaling == -2.0
24+
if scaling == -3.0 || Wshift == 0 # No scaling
25+
return srt, sl, srp
26+
elseif scaling == -2.0 # offset scaling
2627
srt *= 1 << (Wshift + VectorizationBase.intlog2(sizeof_T) - 4)
2728
if (sizeof_T << Wshift) == 64 # VectorizationBase.REGISTER_SIZE # These instructions experience double latency with zmm
2829
sl += sl
2930
end
30-
elseif scaling == -1.0
31+
elseif scaling == -1.0 # linear scaling
3132
W = 1 << Wshift
3233
extra_latency = sl - srt
3334
srt *= W
3435
sl = srt + extra_latency
35-
else
36+
else # we assume custom cost, and that latency == recip_throughput
3637
sl, srt = scaling, scaling
3738
end
38-
sl, srt
39+
srt, sl, srp
3940
end
4041
function cost(instruction::InstructionCost, Wshift, sizeof_T)
4142
Wshift == 0 ? scalar_cost(instruction) : vector_cost(instruction, Wshift, sizeof_T)
@@ -48,12 +49,19 @@ function cost(instruction::Symbol, Wshift, sizeof_T)
4849
)
4950
end
5051

52+
5153
# Just a semi-reasonable assumption; should not be that sensitive to anything other than loads
5254
const OPAQUE_INSTRUCTION = InstructionCost(50, 50.0, -1.0, VectorizationBase.REGISTER_COUNT)
5355

56+
# Comments on setindex!
57+
# 1. Not a part of dependency chains, so not really twice as expensive as getindex?
58+
# 2. getindex loads a register, not setindex!, but we place cost on setindex!
59+
# as a heuristic means of approximating register pressure, since many loads can be
60+
# consolidated into a single register. The number of LICM-ed setindex!, on the other
61+
# hand, should indicate how many registers we're keeping live for the sake of eventually storing.
5462
const COST = Dict{Symbol,InstructionCost}(
55-
:getindex => InstructionCost(3,0.5),
56-
:setindex! => InstructionCost(3,1.0), # but not a part of dependency chains, so not really twice as expensive?
63+
:getindex => InstructionCost(3,0.5,-3.0,0),
64+
:setindex! => InstructionCost(3,1.0,-3.0,1),
5765
:(+) => InstructionCost(4,0.5),
5866
:(-) => InstructionCost(4,0.5),
5967
:(*) => InstructionCost(4,0.5),
@@ -66,7 +74,7 @@ const COST = Dict{Symbol,InstructionCost}(
6674
:(<) => InstructionCost(1, 0.5),
6775
:(>=) => InstructionCost(1, 0.5),
6876
:(<=) => InstructionCost(1, 0.5),
69-
:inv => InstructionCost(13,4.0,-2.0,2),
77+
:inv => InstructionCost(13,4.0,-2.0,1),
7078
:muladd => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
7179
:fma => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
7280
:vmuladd => InstructionCost(4,0.5), # + and * will fuse into this, so much of the time they're not twice as expensive
@@ -76,12 +84,15 @@ const COST = Dict{Symbol,InstructionCost}(
7684
:vfnmadd => InstructionCost(4,0.5), # + and -* will fuse into this, so much of the time they're not twice as expensive
7785
:vfnmsub => InstructionCost(4,0.5), # - and -* will fuse into this, so much of the time they're not twice as expensive
7886
:sqrt => InstructionCost(15,4.0,-2.0),
79-
:log => InstructionCost(20,20.0,40.0,21),
80-
:exp => InstructionCost(20,20.0,20.0,19),
81-
:sin => InstructionCost(18,15.0,68.0,24),
82-
:cos => InstructionCost(18,15.0,68.0,27),
83-
:sincos => InstructionCost(25,22.0,70.0,27)
87+
:log => InstructionCost(20,20.0,40.0,20),
88+
:exp => InstructionCost(20,20.0,20.0,18),
89+
:sin => InstructionCost(18,15.0,68.0,23),
90+
:cos => InstructionCost(18,15.0,68.0,26),
91+
:sincos => InstructionCost(25,22.0,70.0,26)
8492
)
93+
for (k, v) COST # so we can look up Symbol(typeof(function))
94+
COST[Symbol("typeof(", k, ")")] = v
95+
end
8596

8697

8798
# const SIMDPIRATES_COST = Dict{Symbol,InstructionCost}()

0 commit comments

Comments
 (0)