@@ -28,9 +28,8 @@ isdense(::Type{<:DenseArray}) = true
28
28
29
29
@enum OperationType begin
30
30
memload
31
+ compute
31
32
memstore
32
- compute_new
33
- compute_update
34
33
# accumulator
35
34
end
36
35
@@ -55,25 +54,23 @@ struct Operation
55
54
elementbytes:: Int
56
55
instruction:: Symbol
57
56
node_type:: OperationType
58
- # dependencies::Vector{Symbol}
59
57
dependencies:: Set{Symbol}
60
58
reduced_deps:: Set{Symbol}
61
- # dependencies::Set{Symbol}
62
59
parents:: Vector{Operation}
63
- children:: Vector{Operation}
60
+ # children::Vector{Operation}
64
61
numerical_metadata:: Vector{Int} # stride of -1 indicates dynamic
65
62
symbolic_metadata:: Vector{Symbol}
66
- # strides::Dict{Symbol,Union{Symbol,Int}}
67
63
function Operation (
68
64
identifier,
65
+ variable,
69
66
elementbytes,
70
67
instruction,
71
68
node_type,
72
69
variable = gensym ()
73
70
)
74
71
new (
75
72
identifier, variable, elementbytes, instruction, node_type,
76
- Set {Symbol} (), Operation[] , Operation[], Int[], Symbol[]# , Dict{Symbol,Union{Symbol,Int}}()
73
+ Set {Symbol} (), Set {Symbol} () , Operation[], Int[], Symbol[]
77
74
)
78
75
end
79
76
end
@@ -84,12 +81,13 @@ function isreduction(op::Operation)
84
81
(op. node_type == memstore) && (length (op. symbolic_metadata) < length (op. dependencies))# && issubset(op.symbolic_metadata, op.dependencies)
85
82
end
86
83
isload (op:: Operation ) = op. node_type == memload
84
+ iscompute (op:: Operation ) = op. node_type == compute
87
85
isstore (op:: Operation ) = op. node_type == memstore
88
86
accesses_memory (op:: Operation ) = isload (op) | isstore (op)
89
87
elsize (op:: Operation ) = op. elementbytes
90
88
dependson (op:: Operation , sym:: Symbol ) = sym ∈ op. dependencies
91
89
parents (op:: Operation ) = op. parents
92
- children (op:: Operation ) = op. children
90
+ # children(op::Operation) = op.children
93
91
loopdependencies (op:: Operation ) = op. dependencies
94
92
reduceddependencies (op:: Operation ) = op. reduced_deps
95
93
identifier (op:: Operation ) = op. identifier
@@ -159,29 +157,57 @@ end
159
157
160
158
# load/compute/store × isunroled × istiled × pre/post loop × Loop number
161
159
struct LoopOrder <: AbstractArray{Vector{Operation},5}
162
- oporder:: Array {Vector{Operation},5 }
160
+ oporder:: Vector {Vector{Operation}}
163
161
loopnames:: Vector{Symbol}
164
162
end
165
163
function LoopOrder (N:: Int )
166
- LoopOrder ( [ Operation[] for i ∈ 1 : 3 , j ∈ 1 : 2 , k ∈ 1 : 2 , l ∈ 1 : 2 , n ∈ 1 : N ], Vector {Symbol} (undef, N) )
164
+ LoopOrder ( [ Operation[] for i ∈ 1 : 24 N ], Vector {Symbol} (undef, N) )
167
165
end
166
+ LoopOrder () = LoopOrder (Vector{Operation}[])
168
167
Base. empty! (lo:: LoopOrder ) = foreach (empty!, lo. oporder)
169
- Base. size (lo:: LoopOrder ) = (3 ,2 ,2 ,2 ,size (lo. oporder,5 ))
170
- Base. @propagate_inbounds Base. getindex (lo:: LoopOrder , i... ) = lo. oporder[i... ]
168
+ function Base. resize! (lo:: LoopOrder , N:: Int )
169
+ Nold = length (lo. loopnames)
170
+ resize! (lo. oporder, 24 N)
171
+ for n ∈ 24 Nold+ 1 : 24 N
172
+ lo. oporder[n] = Operation[]
173
+ end
174
+ resize! (lo. loopnames, N)
175
+ lo
176
+ end
177
+ Base. size (lo:: LoopOrder ) = (3 ,2 ,2 ,2 ,length (lo. loopnames))
178
+ Base. @propagate_inbounds Base. getindex (lo:: LoopOrder , i:: Int ) = lo. oporder[i]
179
+ Base. @propagate_inbounds Base. getindex (lo:: LoopOrder , i... ) = lo. oporder[LinearIndices (size (lo))[i... ]]
171
180
172
181
# Must make it easy to iterate
173
182
struct LoopSet
174
183
loops:: Dict{Symbol,Loop} # sym === loops[sym].itersymbol
175
- # operations::Vector{Operation}
176
- loadops:: Vector{Operation} # Split them to make it easier to iterate over just a subset
177
- computeops:: Vector{Operation}
178
- storeops:: Vector{Operation}
179
- inner_reductions:: Set{UInt} # IDs of reduction operations nested within loops and stored.
184
+ opdict:: Dict{Symbol,Operation}
185
+ operations:: Vector{Operation} # Split them to make it easier to iterate over just a subset
186
+ # computeops::Vector{Operation}
187
+ # storeops::Vector{Operation}
180
188
outer_reductions:: Set{UInt} # IDs of reduction operations that need to be reduced at end.
181
189
loop_order:: LoopOrder
182
- # strideset::Vector{}
190
+ preamble:: Expr # TODO : add preamble to lowering
191
+ end
192
+ function LoopSet ()
193
+ LoopSet (
194
+ Dict {Symbol,Loop} (),
195
+ Dict {Symbol,Operation} (),
196
+ Operation[],
197
+ # Operation[],
198
+ # Operation[],
199
+ # Set{UInt}(),
200
+ Set {UInt} (),
201
+ LoopOrder (),
202
+ Expr (:block ,)
203
+ )
183
204
end
184
205
num_loops (ls:: LoopSet ) = length (ls. loops)
206
+ function oporder (ls:: LoopSet )
207
+ N = length (ls. loop_order. loopnames)
208
+ reshape (ls. loop_order. oporder, (3 ,2 ,2 ,2 ,N))
209
+ end
210
+ names (ls:: LoopSet ) = ls. loop_order. loopnames
185
211
isstaticloop (ls:: LoopSet , s:: Symbol ) = ls. loops[s]. hintexact
186
212
looprangehint (ls:: LoopSets , s:: Symbol ) = ls. loops[s]. rangehint
187
213
looprangesym (ls:: LoopSets , s:: Symbol ) = ls. loops[s]. rangesym
@@ -198,15 +224,71 @@ end
198
224
function Base. length (ls:: LoopSet , is:: Symbol )
199
225
ls. loops[is]. rangehint
200
226
end
201
- load_operations (ls:: LoopSet ) = ls. loadops
202
- compute_operations (ls:: LoopSet ) = ls. computeops
203
- store_operations (ls:: LoopSet ) = ls. storeops
204
- function operations (ls:: LoopSet )
205
- Base. Iterators. flatten ((
206
- load_operations (ls),
207
- compute_operations (ls),
208
- store_operations (ls)
209
- ))
227
+ # load_operations(ls::LoopSet) = ls.loadops
228
+ # compute_operations(ls::LoopSet) = ls.computeops
229
+ # store_operations(ls::LoopSet) = ls.storeops
230
+ # function operations(ls::LoopSet)
231
+ # Base.Iterators.flatten((
232
+ # load_operations(ls),
233
+ # compute_operations(ls),
234
+ # store_operations(ls)
235
+ # ))
236
+ # end
237
+ operations (ls:: LoopSet ) = ls. operations
238
+ function add_loop! (ls:: LoopSet , looprange:: Expr )
239
+ itersym = (looprange. args[1 ]):: Symbol
240
+ r = (looprange. args[2 ]):: Expr
241
+ @assert r. head === :call
242
+ f = first (r. args)
243
+ loop:: Loop = if f === :(:)
244
+ lower = r. args[2 ]
245
+ upper = r. args[3 ]
246
+ lii:: Bool = lower isa Integer
247
+ uii:: Bool = upper isa Integer
248
+ if lii & uii
249
+ Loop (itersym, 1 + convert (Int,upper) - convert (Int,lower))
250
+ else
251
+ N = gensym (:loop , itersym)
252
+ ex = if lii
253
+ Expr (:call , :- , upper, lower - 1 )
254
+ elseif uii
255
+ Expr (:call , :- , upper + 1 , lower)
256
+ else
257
+ Expr (:call , :- , Expr (:call , :+ , upper, 1 ), lower)
258
+ end
259
+ push! (ls. preamble. args, Expr (:(= ), N, ex))
260
+ Loop (itersym, N)
261
+ end
262
+ elseif f === :eachindex
263
+ N = gensym (:loop , itersym)
264
+ push! (ls. preamble. args, Expr (:(= ), N, Expr (:call , :length , r. args[2 ])))
265
+ Loop (itersym, N)
266
+ else
267
+ throw (" Unrecognized loop range type: $r ." )
268
+ end
269
+ ls. loops[itersym] = loop
270
+ nothing
271
+ end
272
+ function add_load! (ls:: LoopSet , indexed:: Symbol , indices:: AbstractVector )
273
+ Ninds = length (indices)
274
+
275
+
276
+
277
+ end
278
+ function add_load_getindex! (ls:: LoopSet , ex:: Expr )
279
+ add_load! (ls, ex. args[2 ], @view (ex. args[3 : end ]))
280
+ end
281
+ function add_load_ref! (ls:: LoopSet , ex:: Expr )
282
+ add_load! (ls, ex. args[1 ], @view (ex. args[2 : end ]))
283
+ end
284
+ function add_compute! (ls:: LoopSet , ex:: Expr )
285
+
286
+ end
287
+ function add_store! (ls:: LoopSet , ex:: Expr )
288
+
289
+ end
290
+ function Base. push! (ls:: LoopSet , ex:: Expr )
291
+
210
292
end
211
293
212
294
function fillorder! (ls:: LoopSet , order:: Vector{Symbol} , loopistiled:: Bool )
@@ -233,13 +315,7 @@ function fillorder!(ls::LoopSet, order::Vector{Symbol}, loopistiled::Bool)
233
315
included_vars[id] = true
234
316
isunrolled = (unrolled ∈ loopdependencies (op)) + 1
235
317
istiled = (loopistiled ? false : (tiled ∈ loopdependencies (op))) + 1
236
- optype = if isload (op)
237
- 1
238
- elseif isstore (op)
239
- 3
240
- else # if compute
241
- 2
242
- end
318
+ optype = Int (op. node_type)
243
319
after_loop = (length (reduceddependencies (op)) > 0 ) + 1
244
320
push! (lo[optype,isunrolled,istiled,after_loop,_n], op)
245
321
end
0 commit comments