3
3
4
4
# Trait query: every `DenseArray` subtype is reported as dense.
function isdense(::Type{<:DenseArray})
    return true
end
5
5
6
"""
    ShortVector{T} <: DenseVector{T}

Thin wrapper around a `Vector{T}` that uses its own `hash` method, intended to be faster
for short vectors so that they can serve as the keys of a `Dict`.

The hash folds the hash of every element, so its cost is O(N) in the length of the
vector; it is slow for long vectors.
"""
struct ShortVector{T} <: DenseVector{T}
    data::Vector{T}
end

# Indexing forwards to the wrapped vector. `@propagate_inbounds` lets a caller's
# `@inbounds` apply to the access into `x.data`.
Base.@propagate_inbounds Base.getindex(x::ShortVector, I...) = x.data[I...]
Base.@propagate_inbounds Base.setindex!(x::ShortVector, v, I...) = x.data[I...] = v

# Plain forwarding methods. They perform no indexing, so no `@inbounds` annotation is
# needed (placing `@inbounds` in front of a method definition does not affect its body).
Base.length(x::ShortVector) = length(x.data)
Base.size(x::ShortVector) = size(x.data)
Base.strides(x::ShortVector) = strides(x.data)
Base.push!(x::ShortVector, v) = push!(x.data, v)
Base.append!(x::ShortVector, v) = append!(x.data, v)

# Fold the element hashes, in order, into the seed `h`.
function Base.hash(x::ShortVector, h::UInt)
    for v ∈ x.data
        h = hash(v, h)
    end
    return h
end
26
+
27
+
28
+
6
29
# Kinds of node an `Operation` can be: a memory load, a memory store, or a computation.
@enum NodeType memload memstore compute
12
34
@@ -15,61 +37,62 @@ struct Operation
15
37
elementbytes:: Int
16
38
instruction:: Symbol
17
39
node_type:: NodeType
40
+ # dependencies::ShortVector{Symbol}
41
+ dependencies:: Set{Symbol}
42
+ # dependencies::Set{Symbol}
18
43
parents:: Vector{Operation}
19
44
children:: Vector{Operation}
20
- metadata:: Vector{Float64}
45
+ numerical_metadata:: Vector{Float64}
46
+ symbolic_metadata:: Vector{Symbol}
21
47
function Operation(elementbytes, instruction, node_type)
        # A new operation starts with an empty dependency set, no parent/child
        # operations, and empty numerical and symbolic metadata.
        dependencies = Set{Symbol}()
        parents = Operation[]
        children = Operation[]
        numerical_metadata = Float64[]
        symbolic_metadata = Symbol[]
        return new(
            elementbytes, instruction, node_type,
            dependencies, parents, children, numerical_metadata, symbolic_metadata
        )
    end
27
53
end
28
54
29
- isreduction (op:: Operation ) = op. node_type == reduction
55
# An operation counts as a reduction when it is a store whose symbolic metadata is a
# strictly shorter collection than its dependencies and is contained in them.
function isreduction(op::Operation)
    op.node_type == memstore || return false
    length(op.symbolic_metadata) < length(op.dependencies) || return false
    return issubset(op.symbolic_metadata, op.dependencies)
end
30
58
# `true` when `op` is a memory load.
function isload(op::Operation)
    return op.node_type == memload
end

# `true` when `op` is a memory store.
function isstore(op::Operation)
    return op.node_type == memstore
end

# `true` when `op` either loads from or stores to memory.
function accesses_memory(op::Operation)
    return isload(op) || isstore(op)
end
33
- Base. eltype (var:: Operation ) = op. outtype
34
-
35
- """
36
- ShortVector{T} simply wraps a Vector{T}, but uses a different hash function that is faster for short vectors to support using it as the keys of a Dict.
37
- This hash function scales O(N) with length of the vectors, so it is slow for long vectors.
38
- """
39
- struct ShortVector{T} <: DenseVector{T}
40
- data:: Vector{T}
41
- end
42
- Base. @propagate_inbounds Base. getindex (x:: ShortVector , I... ) = x. data[I... ]
43
- Base. @propagate_inbounds Base. setindex! (x:: ShortVector , v, I... ) = x. data[I... ] = v
44
- @inbounds Base. length (x:: ShortVector ) = length (x. data)
45
- @inbounds Base. size (x:: ShortVector ) = size (x. data)
46
- @inbounds Base. strides (x:: ShortVector ) = strides (x. data)
47
- @inbounds Base. push! (x:: ShortVector , v) = push! (x. data, v)
48
- @inbounds Base. append! (x:: ShortVector , v) = append! (x. data, v)
49
- function Base. hash (x:: ShortVector , h:: UInt )
50
- @inbounds for n ∈ eachindex (x)
51
- h = hash (x[n], h)
52
- end
53
- h
54
- end
61
# Value of the operation's `elementbytes` field.
function elsize(op::Operation)
    return op.elementbytes
end

# `true` when `sym` is one of the symbols in `op.dependencies`.
function dependson(op::Operation, sym::Symbol)
    return in(sym, op.dependencies)
end
55
63
56
64
# Placeholder for querying the stride of a memory-accessing operation with respect to
# `sym`. Currently it only checks that `op` accesses memory and returns `nothing`.
function stride(op::Operation, sym::Symbol)
    @assert accesses_memory(op) " This operation does not access memory!"
    # TODO: access stride info?
    return nothing
end
60
- function
68
+ # function
61
69
62
70
# Wrapper holding a single `DataType`.
struct Node
    type::DataType
end
65
73
74
"""
    Loop(itersymbol::Symbol, rangehint::Int)
    Loop(itersymbol::Symbol, rangesym::Symbol, rangehint::Int = 1_000_000)

Description of a single loop.

- `itersymbol`: symbol of the iteration variable.
- `rangehint`: integer hint for the loop length.
- `rangesym`: symbol naming the loop length.
- `hintexact`: if `true`, `rangesym` is ignored and `rangehint` is used for final lowering.

The two-argument `(Symbol, Int)` form marks the hint as exact and stores `:undef` as
`rangesym`; the form taking a `rangesym` marks the hint as inexact.
"""
struct Loop
    itersymbol::Symbol
    rangehint::Int
    rangesym::Symbol
    hintexact::Bool
end

# Length known exactly: no range symbol is needed.
Loop(itersymbol::Symbol, rangehint::Int) = Loop(itersymbol, rangehint, :undef, true)

# Length given symbolically; `rangehint` is only a hint.
function Loop(itersymbol::Symbol, rangesym::Symbol, rangehint::Int = 1_000_000)
    return Loop(itersymbol, rangehint, rangesym, false)
end
86
+
66
87
# Must make it easy to iterate
67
88
# Collection of loops together with the operations evaluated inside them.
struct LoopSet
    # Keyed by iteration symbol: `sym === loops[sym].itersymbol`.
    loops::Dict{Symbol,Loop}
    operations::Vector{Operation}
end
70
93
71
94
# Length hint (`rangehint`) of the loop whose iteration symbol is `is`.
function Base.length(ls::LoopSet, is::Symbol)
    loop = ls.loops[is]
    return loop.rangehint
end
74
97
function variables (ls:: LoopSet )
75
98
@@ -78,7 +101,7 @@ function loopdependencies(var::Operation)
78
101
79
102
end
80
103
# Placeholder: not implemented yet; currently returns `nothing`.
function sym(var::Operation)
    return nothing
end
83
106
function instruction (var:: Operation )
84
107
89
112
# Placeholder: not implemented yet; currently returns `nothing`.
# NOTE(review): this has the same signature as the earlier `stride(op::Operation, sym::Symbol)`
# method in this file, so this later definition replaces it — confirm which one is intended.
function stride(var::Operation, sym::Symbol)
    return nothing
end
115
# Accessor for the operations stored in the loop set.
function operations(ls::LoopSet)
    return ls.operations
end
92
116
function cost (var:: Operation , unrolled:: Symbol , dim:: Int )
93
117
c = cost (instruction (var), Wshift, T):: Int
94
118
if accesses_memory (var)
@@ -108,31 +132,31 @@ end
108
132
# Base._return_type()
109
133
110
134
# Largest `elsize` among the operations of `ls`.
function biggest_type(ls::LoopSet)
    return maximum(elsize(op) for op in ls.operations)
end
113
137
114
138
115
139
116
140
# evaluates cost of evaluating loop in given order
117
141
function evaluate_cost_unroll (
118
- ls:: LoopSet , order:: ShortVector{Symbol} , unrolled:: Symbol , max_cost = typemax (Int )
142
+ ls:: LoopSet , order:: ShortVector{Symbol} , unrolled:: Symbol , max_cost = typemax (Float64 )
119
143
)
120
144
included_vars = Set {Symbol} ()
121
145
nested_loop_syms = Set {Symbol} ()
122
146
total_cost = 0.0
123
147
iter = 1.0
124
148
# Need to check if fusion is possible
125
- # W, Wshift = VectorizationBase.pick_vector_width_shift(length(ls, unrolled), biggest_type(ls))::Tuple{Int,Int}
149
+ W, Wshift = VectorizationBase. pick_vector_width_shift (length (ls, unrolled), biggest_type (ls)):: Tuple{Int,Int}
126
150
for itersym ∈ order
127
151
# Add to set of defined symbols
128
152
push! (nested_loop_syms, itersym)
129
- liter = length (ls, itersym)
153
+ liter = Float64 ( length (ls, itersym) )
130
154
if itersym == unrolled
131
155
liter /= W
132
156
end
133
157
iter *= liter
134
158
# check which vars we can define at this level of loop nest
135
- for var ∈ variables (ls)
159
+ for var ∈ operations (ls)
136
160
# won't define if already defined...
137
161
sym (var) ∈ included_vars && continue
138
162
# it must also be a subset of defined symbols
@@ -141,14 +165,48 @@ function evaluate_cost_unroll(
141
165
push! (included_vars, sym (var))
142
166
143
167
total_cost += iter * cost (var, W, Wshift, unrolled, liter)
144
- total_cost > max_cost && return total_cost # abort
168
+ total_cost > max_cost && return total_cost # abort if more expensive; we only want to know the cheapest
145
169
end
146
170
end
171
+ total_cost
147
172
end
148
- function evaluate_cost_tile (
149
- ls:: LoopSet , order:: ShortVector{Symbol} , tiler, tilec, max_cost = typemax (Int)
173
+
174
+ # only covers unrolled ops; everything else considered lifted?
175
# Placeholder: not implemented yet; currently returns `nothing`.
# Only covers unrolled ops; everything else considered lifted?
function depchain_cost!(
    skip::Set{Symbol}, ls::LoopSet, op::Operation, unrolled::Symbol, Wshift::Int, size_T::Int
)
    return nothing
end
180
+
181
# Choose an unroll factor for the loop `unrolled`.
#
# The strategy is to use an unroll factor of 1 unless there appear to be loop-carried
# dependencies (i.e. reductions). The assumption is that unrolling provides no real
# benefit unless it is needed to enable out-of-order execution by breaking up these
# dependency chains. When reductions are present, the factor is made high enough to keep
# the CPU busy:
#
#     U = max(1, round(Int, max(latency) / sum(recip_throughput)))
#
# where the maximum and the sum run over the reductions that depend on `unrolled`. Summing
# the reciprocal throughputs already scales the denominator with the number of reductions.
#
# NOTE(review): assumes `cost(instruction(op), Wshift, size_T)` returns a
# `(latency, reciprocal_throughput)` pair — confirm against the `cost` definition.
# `order` is currently unused.
function determine_unroll_factor(
    ls::LoopSet, order::ShortVector{Symbol}, unrolled::Symbol, Wshift::Int, size_T::Int
)
    ops = operations(ls)
    # `count` is defined for an empty collection, unlike `sum` with a predicate.
    iszero(count(isreduction, ops)) && return 1
    latency = 0
    recip_throughput = 0.0
    for op ∈ ops
        if isreduction(op) && dependson(op, unrolled)
            sl, rt = cost(instruction(op), Wshift, size_T)
            latency = max(sl, latency)
            recip_throughput += rt
        end
    end
    # No reduction is carried by the unrolled loop, so there is no chain to break up
    # (this also avoids dividing by zero below).
    iszero(recip_throughput) && return 1
    return max(1, round(Int, latency / recip_throughput))
end
206
# Placeholder for the tiled counterpart of the cost evaluation: not implemented yet;
# currently returns `nothing`.
function evaluate_cost_tile(
    ls::LoopSet, order::ShortVector{Symbol}, tiler, tilec, max_cost = typemax(Float64)
)
    return nothing
end
153
211
154
212
struct LoopOrders
0 commit comments